@@ -44,7 +44,7 @@ def __init__(
4444 M : int ,
4545 N : int ,
4646 K : int ,
47- ab_type : str = "f32 " ,
47+ ab_type : str = "f16 " ,
4848 c_type : str = "f32" ,
4949 has_bias : bool = False ,
5050 has_relu : bool = False ,
@@ -55,6 +55,8 @@ def __init__(
5555 self .a_shape = (M , K )
5656 self .b_shape = (K , N )
5757 self .c_shape = (M , N )
58+ assert ab_type == "f16" , "Only f16 type is supported for A and B"
59+ assert c_type == "f32" , "Only f32 type is supported for C"
5860 self .ab_type = ab_type
5961 self .c_type = c_type
6062 type_str_to_numpy = {
@@ -282,20 +284,6 @@ def parse_cli():
282284 default = 1 ,
283285 help = "Number of initial prefetches." ,
284286 )
285- parser .add_argument (
286- "--ab-type" ,
287- type = str ,
288- choices = ["f16" , "f32" ],
289- default = "f16" ,
290- help = "Data type of A and B matrices." ,
291- )
292- parser .add_argument (
293- "--c-type" ,
294- type = str ,
295- choices = ["f16" , "f32" ],
296- default = "f32" ,
297- help = "Data type of the C matrix." ,
298- )
299287 parser .add_argument (
300288 "--nruns" ,
301289 type = int ,
@@ -359,14 +347,16 @@ def parse_cli():
359347 }
360348
361349 M , N , K = args .sizes
350+ ab_type = "f16"
351+ c_type = "f32"
362352
363353 with ir .Context (), ir .Location .unknown ():
364354 wload = XeGPUMatMul (
365355 M = M ,
366356 N = N ,
367357 K = K ,
368- ab_type = args . ab_type ,
369- c_type = args . c_type ,
358+ ab_type = ab_type ,
359+ c_type = c_type ,
370360 has_bias = False ,
371361 has_relu = args .relu ,
372362 )
@@ -397,7 +387,7 @@ def list2str(a):
397387
398388 parts = [
399389 f"sizes={ list2str (args .sizes )} " ,
400- f"dt={ args . ab_type } ,{ args . c_type } " ,
390+ f"dt={ ab_type } ,{ c_type } " ,
401391 f"wg-tile={ list2str (args .wg_tile )} " ,
402392 f"sg-tile={ list2str (args .sg_tile )} " ,
403393 f"k-tile={ args .k_tile } " ,
0 commit comments