
Commit d016691

[AMD][GLUON] Expose MFMA layout (#7653)
This PR exposes AMDMFMALayout in Gluon so kernel authors can use it for better performance on AMD.

Co-authored-by: peterbell10 <[email protected]>
1 parent 7d0efaa commit d016691
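
For orientation, here is a minimal sketch of how a kernel author might pick up the newly exposed layout. It mirrors the tests added in this commit; the kernel name, tensor shape, and layout parameters are illustrative rather than part of the change, and the top-level imports are assumed to follow the conventions of the existing Gluon tests.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd import AMDMFMALayout


@gluon.jit
def mfma_example_kernel():
    # Illustrative MFMA layout for a CDNA3-style 32x32 intrinsic with a transposed result.
    mfma_layout: ttgl.constexpr = AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
                                                warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
                                                ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
    # Materialize a tensor in a blocked layout, then convert it into the MFMA layout.
    x = ttgl.full([128, 32], 0, ttgl.float32, blocked)
    ttgl.convert_layout(x, mfma_layout)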

File tree

6 files changed: +212 -0 lines changed


.github/workflows/integration-tests-amd.yml

Lines changed: 4 additions & 0 deletions
@@ -115,6 +115,10 @@ jobs:
           if [ ! -d "${INSTRUMENTATION_LIB_DIR}" ]; then
             echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
           fi
+
+          # Test gluon
+          pytest --capture=tee-sys -rfs -n 8 python/test/gluon/
+
           pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
           pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py
           TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py

python/src/gluon_ir.cc

Lines changed: 39 additions & 0 deletions
@@ -94,10 +94,14 @@ struct GluonLayouts {
   py::handle NVMMADistributedLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
+  py::handle AMDMFMALayout;
+  py::handle GluonDType;

   GluonLayouts() {
     auto layouts =
         py::module::import("triton.experimental.gluon.language._layouts");
+    auto amdLayouts =
+        py::module::import("triton.experimental.gluon.language.amd._layouts");
     AutoLayout = py::object(layouts.attr("AutoLayout")).release();
     BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
     SliceLayout = py::object(layouts.attr("SliceLayout")).release();
@@ -109,6 +113,10 @@ struct GluonLayouts {
     NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
+    AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
+
+    auto core = py::module::import("triton.language.core");
+    GluonDType = py::object(core.attr("dtype")).release();
   }
 };

@@ -186,7 +194,22 @@ py::object layoutToGluon(Attribute layout) {
                           toStdVector(ctaLayout.getCTAOrder()));
   } else if (auto autoEnc = dyn_cast<gluon::AutoEncodingAttr>(layout)) {
     return layouts.AutoLayout();
+  } else if (auto amdMfma = dyn_cast<ttg::AMDMfmaEncodingAttr>(layout)) {
+    auto ctaLayout = amdMfma.getCTALayout();
+    std::vector<unsigned> instrShape{amdMfma.getMDim(), amdMfma.getNDim()};
+    auto isFP32 = !amdMfma.getElementType().has_value() ||
+                  amdMfma.getElementType().value().isF32();
+
+    return layouts.AMDMFMALayout(amdMfma.getVersion(), instrShape,
+                                 amdMfma.getIsTransposed(),
+                                 toStdVector(amdMfma.getWarpsPerCTA()),
+                                 toStdVector(amdMfma.getTilesPerWarp()),
+                                 layouts.GluonDType(isFP32 ? "fp32" : "fp64"),
+                                 toStdVector(ctaLayout.getCTAsPerCGA()),
+                                 toStdVector(ctaLayout.getCTASplitNum()),
+                                 toStdVector(ctaLayout.getCTAOrder()));
   }
+
   throw py::value_error("Unhandled encoding encountered");
 }

@@ -284,6 +307,22 @@ void init_gluon_ir(py::module &&m) {
                  ctx, version[0], version[1], warpsPerCta, ctaLayout,
                  instrShape);
            })
+      .def("get_amd_mfma_layout",
+           [](GluonOpBuilder &self, unsigned version,
+              std::vector<unsigned> &tilesPerWarp,
+              std::vector<unsigned> &warpsPerCta,
+              std::vector<unsigned> &ctasPerCga,
+              std::vector<unsigned> &ctaSplitNum,
+              std::vector<unsigned> &ctaOrder,
+              std::vector<unsigned> &instrShape, bool transposed,
+              mlir::Type elemType) -> Attribute {
+             auto ctx = self.getContext();
+             auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
+                 ctx, ctasPerCga, ctaSplitNum, ctaOrder);
+             return ttg::AMDMfmaEncodingAttr::get(
+                 ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
+                 instrShape[1], transposed, ctaLayout, elemType);
+           })
       .def("get_nvmma_shared_layout",
            [](GluonOpBuilder &self, unsigned swizzleByteWidth,
               unsigned elementBitwidth, bool transposed, bool fp4Padded,

python/test/gluon/test_frontend.py

Lines changed: 91 additions & 0 deletions
@@ -10,6 +10,7 @@
 from triton.experimental.gluon.language.nvidia import hopper
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
+from triton.experimental.gluon.language.amd import _layouts as amd_layouts
 from triton._filecheck import filecheck_test, run_parser
 from triton.runtime.jit import MockTensor
 import triton.language as tl
@@ -23,6 +24,8 @@
 HOPPER_TARGET = GPUTarget("cuda", 90, 32)
 AMPERE_TARGET = GPUTarget("cuda", 80, 32)
 HIP_TARGET = GPUTarget("hip", "gfx1200", 32)
+HIP_TARGET_CDNA3 = GPUTarget("hip", "gfx942", 64)
+HIP_TARGET_CDNA4 = GPUTarget("hip", "gfx950", 64)

 ALL_TARGETS = [AMPERE_TARGET, HOPPER_TARGET, BLACKWELL_TARGET, HIP_TARGET]

@@ -1338,3 +1341,91 @@ def test_auto_layout_broadcast():
     # CHECK: [[XBCAST2:%.*]] = tt.broadcast [[XCVT2]]
     # CHECK: arith.muli [[YBCAST2]], [[XBCAST2]] : tensor<16x16xi32, [[BLOCKED]]>
     _ = y * x
+
+
+@gluon.jit
+def amd_mfma_layout_kernel():
+    mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
+                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
+                                                                 ctas_per_cga=[1, 1], cta_split_num=[1, 1],
+                                                                 cta_order=[1, 0])
+
+    mfma_layout_fp64: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
+                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
+                                                                 elem_type=ttgl.float64, ctas_per_cga=[1, 1],
+                                                                 cta_split_num=[1, 1], cta_order=[1, 0])
+
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+
+    x_fp32 = ttgl.full([128, 32], 0, ttgl.float32, layout)
+    x_fp64 = ttgl.full([128, 32], 0, ttgl.float64, layout)
+
+    ttgl.convert_layout(x_fp32, mfma_layout_fp32)
+    ttgl.convert_layout(x_fp64, mfma_layout_fp64)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_amd_mfma_layout(target):
+
+    module = run_parser(amd_mfma_layout_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @amd_mfma_layout_kernel() attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
+    %cst_1 = arith.constant 0.000000e+00 : f64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #blocked>
+    %0 = ttg.convert_layout %cst_0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #mma>
+    %1 = ttg.convert_layout %cst_2 : tensor<128x32xf64, #blocked> -> tensor<128x32xf64, #mma1>
+    tt.return
+  }
+}
+""")
+
+
+@gluon.jit
+def add_fp(a, b):
+    return a + b
+
+
+@gluon.jit
+def infer_layout_for_amd_mfma_kernel():
+    layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
+                                                       warps_per_cta=[4, 1], tiles_per_warp=[4, 1], ctas_per_cga=[1, 1],
+                                                       cta_split_num=[1, 1], cta_order=[1, 0])
+    a = ttgl.full([128, 32], 1.0, ttgl.float32, layout)
+    b = ttgl.reduce(a, 1, add_fp)
+    ttgl.static_assert(b.type.layout == ttgl.SliceLayout(1, layout))
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_infer_layout_for_amd_mfma(target):
+    module = run_parser(infer_layout_for_amd_mfma_kernel, target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @infer_layout_for_amd_mfma_kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<128x32xf32, #mma>
+    %0 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
+    ^bb0(%arg0: f32, %arg1: f32):
+      %1 = tt.call @test_frontend.add_fp__fp32_fp32__(%arg0, %arg1) : (f32, f32) -> f32
+      tt.reduce.return %1 : f32
+    }) : (tensor<128x32xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>>
+    tt.return
+  }
+  tt.func private @test_frontend.add_fp__fp32_fp32__(%arg0: f32, %arg1: f32) -> f32 attributes {noinline = false} {
+    %0 = arith.addf %arg0, %arg1 : f32
+    tt.return %0 : f32
+  ^bb1:  // no predecessors
+    %1 = ub.poison : f32
+    tt.return %1 : f32
+  }
+}
+""")

python/triton/experimental/gluon/language/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -8,11 +8,13 @@
 from ._standard import __all__ as __standard_all

 from . import nvidia
+from . import amd

 __all__ = [
     *__core_all,
     *__layouts_all,
     *__math_all,
     *__standard_all,
     "nvidia",
+    "amd",
 ]
python/triton/experimental/gluon/language/amd/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+from ._layouts import AMDMFMALayout
+
+__all__ = ["AMDMFMALayout"]
python/triton/experimental/gluon/language/amd/_layouts.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import List
+from triton.language.core import _unwrap_if_constexpr
+
+from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
+from triton.experimental.gluon import language as ttgl
+
+__all__ = [
+    "AMDMFMALayout",
+]
+
+
+@dataclass(frozen=True)
+class AMDMFMALayout(DistributedLayout):
+    """
+    Represents a layout for AMD MFMA (matrix core) operations.
+
+    Args:
+        version (int): Major and minor identifier for the MFMA instruction.
+        instr_shape (List[int]): (M, N) dimensions of the intrinsic shape.
+        transposed (bool): Indicates that the result tensor is transposed so that each thread holds consecutive elements in the same row instead of the same column, which benefits chained dots and global writes.
+        warps_per_cta (List[int]): Number of warps per CTA.
+        tiles_per_warp (List[int]): Number of tiles per warp.
+        elem_type (ttgl.dtype): Element type, fp32 or fp64.
+        ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
+        cta_split_num (Optional[List[int]]): Split factors for CTAs.
+        cta_order (Optional[List[int]]): CTA ordering.
+    """
+    version: int
+    instr_shape: List[int]
+    transposed: bool
+    warps_per_cta: List[int]
+    tiles_per_warp: List[int]
+    elem_type: ttgl.dtype = ttgl.float32
+    ctas_per_cga: List[int] | None = None
+    cta_split_num: List[int] | None = None
+    cta_order: List[int] | None = None
+
+    def __post_init__(self):
+        super().__setattr__("version", _unwrap_if_constexpr(self.version))
+        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
+        super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("tiles_per_warp", _unwrap_if_constexpr(self.tiles_per_warp))
+        super().__setattr__("elem_type", _unwrap_if_constexpr(self.elem_type))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        assert self.elem_type.is_fp32() or self.elem_type.is_fp64(
+        ), "The element type in AMDMFMALayout should be float32 or float64"
+
+        rank = len(self.cta_order)
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def _to_ir(self, builder):
+        type = builder.get_float_ty() if self.elem_type is ttgl.float32 else builder.get_double_ty()
+        return builder.get_amd_mfma_layout(self.version, self.tiles_per_warp, self.warps_per_cta, self.ctas_per_cga,
+                                           self.cta_split_num, self.cta_order, self.instr_shape, self.transposed, type)
+
+    def mangle(self) -> str:
+
+        def stringify(x):
+            if x is None:
+                return ""
+            return "_".join(map(str, x))
+
+        return f"MFMA_{self.version}_{stringify(self.instr_shape)}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.tiles_per_warp)}_{self.elem_type}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_MFMA"
