@@ -1132,15 +1132,12 @@ def gen_mma_tests():
11321132
11331133
11341134def get_mma_block_scale_ops ():
1135- return (
1136- make_mma_ops (["m16n8k64" ], ["e2m1" ], [], ["f32" ], [])
1137- + make_mma_ops (
1138- ["m16n8k32" ],
1139- ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1140- ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1141- ["f32" ],
1142- [],
1143- )
1135+ return make_mma_ops (["m16n8k64" ], ["e2m1" ], [], ["f32" ], []) + make_mma_ops (
1136+ ["m16n8k32" ],
1137+ ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1138+ ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1139+ ["f32" ],
1140+ [],
11441141 )
11451142
11461143
@@ -1196,7 +1193,9 @@ def is_mma_block_scale_variant_supported(op, kind, scale_vec_size, stype):
11961193 return False
11971194
11981195
1199- def common_mma_block_scale_test_gen (params , op , intrinsic_template , instruction_template ):
1196+ def common_mma_block_scale_test_gen (
1197+ params , op , intrinsic_template , instruction_template
1198+ ):
12001199 mma_block_scale_template = """
12011200declare ${ret_ty} @${intrinsic}(
12021201 ${args});
@@ -1250,12 +1249,8 @@ def gen_mma_block_scale_tests():
12501249 if not (ptx_version >= 87 and gpu_arch >= 120 and aa ):
12511250 return []
12521251
1253- mma_block_scale_intrinsic_template = (
1254- "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
1255- )
1256- mma_block_scale_instruction_template = (
1257- "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
1258- )
1252+ mma_block_scale_intrinsic_template = "llvm.nvvm.mma.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
1253+ mma_block_scale_instruction_template = "mma.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
12591254
12601255 generated_items = []
12611256
@@ -1282,7 +1277,9 @@ def gen_mma_block_scale_tests():
12821277 instruction_template = mma_block_scale_instruction_template
12831278
12841279 generated_items .append (
1285- common_mma_block_scale_test_gen (params , op , intrinsic_template , instruction_template )
1280+ common_mma_block_scale_test_gen (
1281+ params , op , intrinsic_template , instruction_template
1282+ )
12861283 )
12871284
12881285 return generated_items
@@ -1381,7 +1378,7 @@ def is_mma_sp_variant_supported(op, metadata, kind, satf):
13811378 return True
13821379
13831380
1384- def sp_selector_gen (op , block_scale = False ):
1381+ def sp_selector_gen (op , block_scale = False ):
13851382 if block_scale :
13861383 # PTX ISA 9.0 has the sparsity selector equal to 0 only
13871384 return range (1 )
@@ -1517,16 +1514,13 @@ def gen_mma_sp_tests():
15171514
15181515
15191516def get_mma_sp_block_scale_ops ():
1520- return (
1521- make_mma_ops (["m16n8k128" ], ["e2m1" ], [], ["f32" ], [], True )
1522- + make_mma_ops (
1523- ["m16n8k64" ],
1524- ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1525- ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1526- ["f32" ],
1527- [],
1528- True ,
1529- )
1517+ return make_mma_ops (["m16n8k128" ], ["e2m1" ], [], ["f32" ], [], True ) + make_mma_ops (
1518+ ["m16n8k64" ],
1519+ ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1520+ ["e4m3" , "e5m2" , "e3m2" , "e2m3" , "e2m1" ],
1521+ ["f32" ],
1522+ [],
1523+ True ,
15301524 )
15311525
15321526
@@ -1582,7 +1576,9 @@ def is_mma_sp_block_scale_variant_supported(op, kind, scale_vec_size, stype):
15821576 return False
15831577
15841578
1585- def common_mma_sp_block_scale_test_gen (params , op , intrinsic_template , instruction_template ):
1579+ def common_mma_sp_block_scale_test_gen (
1580+ params , op , intrinsic_template , instruction_template
1581+ ):
15861582 mma_sp_block_scale_decl_template = """
15871583declare ${ret_ty} @${intrinsic}(
15881584 ${args});
@@ -1653,12 +1649,8 @@ def gen_mma_sp_block_scale_tests():
16531649 if not (ptx_version >= 87 and gpu_arch >= 120 and aa ):
16541650 return []
16551651
1656- mma_sp_block_scale_intrinsic_template = (
1657- "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
1658- )
1659- mma_sp_block_scale_instruction_template = (
1660- "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
1661- )
1652+ mma_sp_block_scale_intrinsic_template = "llvm.nvvm.mma.sp.ordered.metadata.block.scale.${geom}.row.col.${kind}${scale}.${intrinsic_signature}.${stype}"
1653+ mma_sp_block_scale_instruction_template = "mma.sp::ordered_metadata.sync.aligned.${geom}.row.col.kind::${kind}.block_scale${scale_vec_size}.${ptx_signature}.${stype}"
16621654
16631655 generated_items = []
16641656
@@ -1685,7 +1677,9 @@ def gen_mma_sp_block_scale_tests():
16851677 instruction_template = mma_sp_block_scale_instruction_template
16861678
16871679 generated_items .append (
1688- common_mma_sp_block_scale_test_gen (params , op , intrinsic_template , instruction_template )
1680+ common_mma_sp_block_scale_test_gen (
1681+ params , op , intrinsic_template , instruction_template
1682+ )
16891683 )
16901684
16911685 return generated_items
0 commit comments