@@ -1413,31 +1413,27 @@ def test_atomic_cas():
 
 @gluon.jit
 def amd_mfma_layout_kernel():
-    mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 ctas_per_cga=[1,
-                                                                               1], cta_split_num=[1,
-                                                                                                  1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32],
+                                                                           transposed=True, warps_per_cta=[4, 1]))
 
-    mfma_layout_fp64: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 elem_type=ttgl.float64, ctas_per_cga=[1, 1],
-                                                                 cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], tiles_per_warp=[4, 1], transposed=True,
+                                               warps_per_cta=[4, 1]))
 
-    mfma_layout_int32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                  warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                  elem_type=ttgl.int32, ctas_per_cga=[1, 1],
-                                                                  cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True, warps_per_cta=[4, 1],
+                                               ctas_per_cga=[1, 1], tiles_per_warp=[1, 1], cta_split_num=[1, 1],
+                                               cta_order=[1, 0]))
 
-    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+    ttgl.full([128, 32], 0, ttgl.float64,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.float64, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0]))
 
-    x_fp32 = ttgl.full([128, 32], 0, ttgl.float32, layout)
-    x_fp64 = ttgl.full([128, 32], 0, ttgl.float64, layout)
-    x_int32 = ttgl.full([128, 32], 0, ttgl.int32, layout)
-
-    ttgl.convert_layout(x_fp32, mfma_layout_fp32)
-    ttgl.convert_layout(x_fp64, mfma_layout_fp64)
-    ttgl.convert_layout(x_int32, mfma_layout_int32)
+    ttgl.full([128, 32], 0, ttgl.int32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.int32, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1]))
 
 
 @pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
@@ -1446,21 +1442,22 @@ def test_amd_mfma_layout(target):
     module = run_parser(amd_mfma_layout_kernel, target=target)
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
-#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
+#mma3 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @amd_mfma_layout_kernel() attributes {noinline = false} {
     %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
-    %cst_1 = arith.constant 0.000000e+00 : f64
-    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #blocked>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_1 = arith.constant 0.000000e+00 : f32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma1>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_5 = arith.constant 0.000000e+00 : f64
+    %cst_6 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #mma2>
     %c0_i32 = arith.constant 0 : i32
-    %cst_3 = arith.constant dense<0> : tensor<128x32xi32, #blocked>
-    %0 = ttg.convert_layout %cst_0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #mma>
-    %1 = ttg.convert_layout %cst_2 : tensor<128x32xf64, #blocked> -> tensor<128x32xf64, #mma1>
-    %2 = ttg.convert_layout %cst_3 : tensor<128x32xi32, #blocked> -> tensor<128x32xi32, #mma2>
+    %cst_7 = arith.constant dense<0> : tensor<128x32xi32, #mma3>
     tt.return
   }
 }
@@ -1475,8 +1472,8 @@ def add_int(a, b):
 @gluon.jit
 def infer_layout_for_amd_mfma_kernel():
     layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                        elem_type=ttgl.int32, warps_per_cta=[4,
-                                                                                             1], tiles_per_warp=[4, 1],
+                                                        warps_per_cta=[4,
+                                                                       1], elem_type=ttgl.int32, tiles_per_warp=[1, 1],
                                                         ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
     a = ttgl.full([128, 32], 1, ttgl.int32, layout)
     b = ttgl.reduce(a, 1, add_int)
@@ -1489,7 +1486,7 @@ def test_infer_layout_for_amd_mfma(target):
 
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @infer_layout_for_amd_mfma_kernel() attributes {noinline = false} {
     %c1_i32 = arith.constant 1 : i32
@@ -1719,3 +1716,49 @@ def test_buffer_load_store_with_broadcast(target):
   }
 }
 """)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_amd_mfma(target):
+
+    @gluon.jit
+    def kernel():
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
+                                                             warps_per_cta=[4, 1])
+
+        a = ttgl.full([64, 32], 1.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout,
+                                                                                k_width=8))
+        b = ttgl.full([32, 64], 2.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout,
+                                                                                k_width=8))
+
+        acc = ttgl.zeros([64, 64], ttgl.float32, mfma_layout)
+        acc = ttgl.amd.cdna3.mfma(a, b, acc)
+        ttgl.static_assert(isinstance(acc, ttgl.tensor))
+        ttgl.static_assert(acc.type.layout == mfma_layout)
+
+    module = run_parser(kernel, target=target)
+
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_1 = arith.constant 2.000000e+00 : f32
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    %0 = tt.call @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() : () -> tensor<64x64xf32, #mma>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %1 = tt.dot %cst_0, %cst_2, %0 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+  tt.func private @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() -> tensor<64x64xf32, #mma> attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    tt.return %cst_0 : tensor<64x64xf32, #mma>
+  ^bb1:  // no predecessors
+    %0 = ub.poison : tensor<64x64xf32, #mma>
+    tt.return %0 : tensor<64x64xf32, #mma>
+  }
+}
+""")