Skip to content

Commit 787f9f9

Browse files
committed
remove element type options, only (f16,f16,f32) can be lowered
1 parent 42899a2 commit 787f9f9

File tree

1 file changed

+8
-18
lines changed

1 file changed

+8
-18
lines changed

python/examples/xegpu_matmul/matmul.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(
4444
M: int,
4545
N: int,
4646
K: int,
47-
ab_type: str = "f32",
47+
ab_type: str = "f16",
4848
c_type: str = "f32",
4949
has_bias: bool = False,
5050
has_relu: bool = False,
@@ -55,6 +55,8 @@ def __init__(
5555
self.a_shape = (M, K)
5656
self.b_shape = (K, N)
5757
self.c_shape = (M, N)
58+
assert ab_type == "f16", "Only f16 type is supported for A and B"
59+
assert c_type == "f32", "Only f32 type is supported for C"
5860
self.ab_type = ab_type
5961
self.c_type = c_type
6062
type_str_to_numpy = {
@@ -282,20 +284,6 @@ def parse_cli():
282284
default=1,
283285
help="Number of initial prefetches.",
284286
)
285-
parser.add_argument(
286-
"--ab-type",
287-
type=str,
288-
choices=["f16", "f32"],
289-
default="f16",
290-
help="Data type of A and B matrices.",
291-
)
292-
parser.add_argument(
293-
"--c-type",
294-
type=str,
295-
choices=["f16", "f32"],
296-
default="f32",
297-
help="Data type of the C matrix.",
298-
)
299287
parser.add_argument(
300288
"--nruns",
301289
type=int,
@@ -359,14 +347,16 @@ def parse_cli():
359347
}
360348

361349
M, N, K = args.sizes
350+
ab_type = "f16"
351+
c_type = "f32"
362352

363353
with ir.Context(), ir.Location.unknown():
364354
wload = XeGPUMatMul(
365355
M=M,
366356
N=N,
367357
K=K,
368-
ab_type=args.ab_type,
369-
c_type=args.c_type,
358+
ab_type=ab_type,
359+
c_type=c_type,
370360
has_bias=False,
371361
has_relu=args.relu,
372362
)
@@ -397,7 +387,7 @@ def list2str(a):
397387

398388
parts = [
399389
f"sizes={list2str(args.sizes)}",
400-
f"dt={args.ab_type},{args.c_type}",
390+
f"dt={ab_type},{c_type}",
401391
f"wg-tile={list2str(args.wg_tile)}",
402392
f"sg-tile={list2str(args.sg_tile)}",
403393
f"k-tile={args.k_tile}",

0 commit comments

Comments
 (0)