10 | 10 | from triton.experimental.gluon.language.nvidia import hopper
11 | 11 | from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
12 | 12 | from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
| 13 | +from triton.experimental.gluon.language.amd import _layouts as amd_layouts
13 | 14 | from triton._filecheck import filecheck_test, run_parser
14 | 15 | from triton.runtime.jit import MockTensor
15 | 16 | import triton.language as tl

23 | 24 | HOPPER_TARGET = GPUTarget("cuda", 90, 32)
24 | 25 | AMPERE_TARGET = GPUTarget("cuda", 80, 32)
25 | 26 | HIP_TARGET = GPUTarget("hip", "gfx1200", 32)
| 27 | +HIP_TARGET_CDNA3 = GPUTarget("hip", "gfx942", 64)
| 28 | +HIP_TARGET_CDNA4 = GPUTarget("hip", "gfx950", 64)
26 | 29 |
27 | 30 | ALL_TARGETS = [AMPERE_TARGET, HOPPER_TARGET, BLACKWELL_TARGET, HIP_TARGET]
28 | 31 |
@@ -1338,3 +1341,91 @@ def test_auto_layout_broadcast():
1338 | 1341 |     # CHECK: [[XBCAST2:%.*]] = tt.broadcast [[XCVT2]]
1339 | 1342 |     # CHECK: arith.muli [[YBCAST2]], [[XBCAST2]] : tensor<16x16xi32, [[BLOCKED]]>
1340 | 1343 |     _ = y * x
| 1344 | +
| 1345 | +
| 1346 | +@gluon.jit
| 1347 | +def amd_mfma_layout_kernel():
| 1348 | +    mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(
| 1349 | +        version=3, instr_shape=[32, 32], transposed=True,
| 1350 | +        warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
| 1351 | +        ctas_per_cga=[1, 1], cta_split_num=[1, 1],
| 1352 | +        cta_order=[1, 0])
| 1353 | +
| 1354 | +    mfma_layout_fp64: ttgl.constexpr = amd_layouts.AMDMFMALayout(
| 1355 | +        version=3, instr_shape=[16, 16], transposed=True,
| 1356 | +        warps_per_cta=[4, 1], tiles_per_warp=[4, 1], elem_type=ttgl.float64,
| 1357 | +        ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
| 1358 | +
| 1359 | +    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
| 1360 | +
| 1361 | +    x_fp32 = ttgl.full([128, 32], 0, ttgl.float32, layout)
| 1362 | +    x_fp64 = ttgl.full([128, 32], 0, ttgl.float64, layout)
| 1363 | +
| 1364 | +    ttgl.convert_layout(x_fp32, mfma_layout_fp32)
| 1365 | +    ttgl.convert_layout(x_fp64, mfma_layout_fp64)
| 1366 | +
| 1367 | +
| 1368 | +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
| 1369 | +def test_amd_mfma_layout(target):
| 1370 | +
| 1371 | +    module = run_parser(amd_mfma_layout_kernel, target=target)
| 1372 | +    expecttest.assert_expected_inline(
| 1373 | +        anonymize_ir(module.str_nodebug()), """\
| 1374 | +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
| 1375 | +#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
| 1376 | +#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
| 1377 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
| 1378 | +  tt.func public @amd_mfma_layout_kernel() attributes {noinline = false} {
| 1379 | +    %cst = arith.constant 0.000000e+00 : f32
| 1380 | +    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
| 1381 | +    %cst_1 = arith.constant 0.000000e+00 : f64
| 1382 | +    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #blocked>
| 1383 | +    %0 = ttg.convert_layout %cst_0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #mma>
| 1384 | +    %1 = ttg.convert_layout %cst_2 : tensor<128x32xf64, #blocked> -> tensor<128x32xf64, #mma1>
| 1385 | +    tt.return
| 1386 | +  }
| 1387 | +}
| 1388 | +""")
| 1389 | +
| 1390 | +
| 1391 | +@gluon.jit
| 1392 | +def add_fp(a, b):
| 1393 | +    return a + b
| 1394 | +
| 1395 | +
| 1396 | +@gluon.jit
| 1397 | +def infer_layout_for_amd_mfma_kernel():
| 1398 | +    layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(
| 1399 | +        version=3, instr_shape=[32, 32], transposed=True, warps_per_cta=[4, 1],
| 1400 | +        tiles_per_warp=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
| 1401 | +    a = ttgl.full([128, 32], 1.0, ttgl.float32, layout)
| 1402 | +    b = ttgl.reduce(a, 1, add_fp)
| 1403 | +    ttgl.static_assert(b.type.layout == ttgl.SliceLayout(1, layout))
| 1404 | +
| 1405 | +
| 1406 | +@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
| 1407 | +def test_infer_layout_for_amd_mfma(target):
| 1408 | +    module = run_parser(infer_layout_for_amd_mfma_kernel, target=target)
| 1409 | +    expecttest.assert_expected_inline(
| 1410 | +        anonymize_ir(module.str_nodebug()), """\
| 1411 | +#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
| 1412 | +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
| 1413 | +  tt.func public @infer_layout_for_amd_mfma_kernel() attributes {noinline = false} {
| 1414 | +    %cst = arith.constant 1.000000e+00 : f32
| 1415 | +    %cst_0 = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #mma>
| 1416 | +    %0 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
| 1417 | +    ^bb0(%arg0: f32, %arg1: f32):
| 1418 | +      %1 = tt.call @test_frontend.add_fp__fp32_fp32__(%arg0, %arg1) : (f32, f32) -> f32
| 1419 | +      tt.reduce.return %1 : f32
| 1420 | +    }) : (tensor<128x32xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
| 1421 | +    tt.return
| 1422 | +  }
| 1423 | +  tt.func private @test_frontend.add_fp__fp32_fp32__(%arg0: f32, %arg1: f32) -> f32 attributes {noinline = false} {
| 1424 | +    %0 = arith.addf %arg0, %arg1 : f32
| 1425 | +    tt.return %0 : f32
| 1426 | +  ^bb1:  // no predecessors
| 1427 | +    %1 = ub.poison : f32
| 1428 | +    tt.return %1 : f32
| 1429 | +  }
| 1430 | +}
| 1431 | +""")