@@ -86,7 +86,7 @@ def create_barrier_flags(m, n, l, mma_tiler_mn, cluster_shape_mn, sm_count):
86
86
barrier_size = Sm100BlockScaledPersistentDenseGemmKernel .compute_barrier_flag_size (
87
87
m , n , l , mma_tiler_mn , cluster_shape_mn , sm_count
88
88
)
89
- print ("LOOK HERE" , (barrier_size ,))
89
+ # print("LOOK HERE", (barrier_size,))
90
90
# NOTE: use_2cta_instrs from blockedscaled_gemm logic
91
91
92
92
# use_2cta_instrs = mma_tiler_mn[0] == 256
@@ -481,23 +481,23 @@ def multi_process_parallel(
481
481
@pytest .mark .parametrize (
482
482
"ab_dtype,sf_dtype,c_dtype,sf_vec_size" ,
483
483
[
484
- ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 )
485
- ("float4_e2m1fn" , "float8_e8m0fnu" , "float16" , 16 ),
486
- ("float4_e2m1fn" , "float8_e8m0fnu" , "bfloat16" , 16 ),
487
- ("float4_e2m1fn" , "float8_e8m0fnu" , "float32" , 16 ),
488
- ("float4_e2m1fn" , "float8_e4m3fn" , "float16" , 16 ),
489
- ("float4_e2m1fn" , "float8_e4m3fn" , "bfloat16" , 16 ),
490
- ("float4_e2m1fn" , "float8_e4m3fn" , "float32" , 16 ),
491
- ("float8_e4m3fn" , "float8_e8m0fnu" , "bfloat16" , 32 ),
492
- ("float8_e4m3fn" , "float8_e8m0fnu" , "float16" , 32 ),
493
- ("float8_e4m3fn" , "float8_e8m0fnu" , "float32" , 32 ),
484
+ # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
485
+ # ("float4_e2m1fn", "float8_e8m0fnu", "float16", 16),
486
+ # ("float4_e2m1fn", "float8_e8m0fnu", "bfloat16", 16),
487
+ # ("float4_e2m1fn", "float8_e8m0fnu", "float32", 16),
488
+ # ("float4_e2m1fn", "float8_e4m3fn", "float16", 16),
489
+ # ("float4_e2m1fn", "float8_e4m3fn", "bfloat16", 16),
490
+ # ("float4_e2m1fn", "float8_e4m3fn", "float32", 16),
491
+ # ("float8_e4m3fn", "float8_e8m0fnu", "bfloat16", 32),
492
+ # ("float8_e4m3fn", "float8_e8m0fnu", "float16", 32),
493
+ # ("float8_e4m3fn", "float8_e8m0fnu", "float32", 32),
494
494
("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e4m3fn" , 32 ),
495
- ("float8_e4m3fn" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
496
- ("float8_e5m2" , "float8_e8m0fnu" , "bfloat16" , 32 ),
497
- ("float8_e5m2" , "float8_e8m0fnu" , "float16" , 32 ),
498
- ("float8_e5m2" , "float8_e8m0fnu" , "float32" , 32 ),
499
- ("float8_e5m2" , "float8_e8m0fnu" , "float8_e4m3fn" , 32 ),
500
- ("float8_e5m2" , "float8_e8m0fnu" , "float8_e5m2" , 32 ),
495
+ # ("float8_e4m3fn", "float8_e8m0fnu", "float8_e5m2", 32),
496
+ # ("float8_e5m2", "float8_e8m0fnu", "bfloat16", 32),
497
+ # ("float8_e5m2", "float8_e8m0fnu", "float16", 32),
498
+ # ("float8_e5m2", "float8_e8m0fnu", "float32", 32),
499
+ # ("float8_e5m2", "float8_e8m0fnu", "float8_e4m3fn", 32),
500
+ # ("float8_e5m2", "float8_e8m0fnu", "float8_e5m2", 32),
501
501
],
502
502
)
503
503
@pytest .mark .parametrize ("a_major" , ["k" ])
0 commit comments