
Commit 4f8712b

[AMD][GLUON] Turn select scale layout into constexpr function (#8673)
Following #8496, this PR changes `get_wmma_scale_layout` / `get_mfma_scale_layout` into `constexpr_function`.
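Because `get_mfma_scale_layout` is now a `constexpr_function`, it is evaluated while the kernel is being compiled instead of being traced through the builder, so it no longer takes a `_semantic` argument. A minimal sketch of how it might be used from a Gluon kernel (the kernel, the constexpr layout argument, and the 128x4 scale shape are illustrative assumptions, not part of this commit):

    from triton.experimental import gluon
    from triton.experimental.gluon import language as ttgl
    from triton.experimental.gluon.language.amd.cdna4 import get_mfma_scale_layout

    @gluon.jit
    def kernel(a_dot_layout: ttgl.constexpr):
        # Hypothetical usage: the helper runs at compile time, so the result
        # can be bound to a constexpr and used to materialize a scale tensor.
        scale_layout: ttgl.constexpr = get_mfma_scale_layout(a_dot_layout, [128, 4])
        a_scale = ttgl.full([128, 4], 0x7F, ttgl.uint8, scale_layout)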
1 parent 14b7d02 · commit 4f8712b

4 files changed: +68 −60

python/src/gluon_ir.cc

Lines changed: 34 additions & 20 deletions

@@ -406,16 +406,6 @@ void init_gluon_ir(py::module &&m) {
                  ctx, version, warpsPerCta, instrShape, transposed, ctaLayout,
                  tilesPerWarp, elementBitWidth);
            })
-      .def("get_amd_mfma_scale_layout",
-           [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
-              unsigned mfmaMDim, std::vector<unsigned> &tilesPerWarp,
-              std::vector<unsigned> &warpsPerCTA) -> py::object {
-             auto ctx = self.getContext();
-             auto ll = ttg::chooseScaledMfmaScaleLayout(
-                 ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
-             auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
-             return layoutToGluon(attr);
-           })
       .def("get_amd_wmma_layout",
            [](GluonOpBuilder &self, unsigned version, bool transposed,
               std::vector<unsigned> &warpsPerCta,
@@ -431,16 +421,6 @@ void init_gluon_ir(py::module &&m) {
                                  warpsPerCta, tilesPerWarp,
                                  ctaLayout, instrShape);
            })
-      .def("get_amd_wmma_scale_layout",
-           [](GluonOpBuilder &self, unsigned opIdx, std::vector<int64_t> &shape,
-              unsigned mfmaMDim, std::vector<unsigned> &tilesPerWarp,
-              std::vector<unsigned> &warpsPerCTA) -> py::object {
-             auto ctx = self.getContext();
-             auto ll = ttg::chooseScaledWmmaScaleLayout(
-                 ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
-             auto attr = ttg::LinearEncodingAttr::get(ctx, ll);
-             return layoutToGluon(attr);
-           })
       .def("get_padded_shared_layout",
            [](GluonOpBuilder &self, std::vector<unsigned> &intervals,
               std::vector<unsigned> &paddings,
@@ -913,6 +893,40 @@ void init_gluon_ir(py::module &&m) {
             return layoutToGluon(attr);
           });

+  m.def("get_amd_mfma_scale_layout",
+        [](unsigned opIdx, std::vector<int64_t> &shape, unsigned mfmaMDim,
+           std::vector<unsigned> &tilesPerWarp,
+           std::vector<unsigned> &warpsPerCTA) -> py::object {
+          DialectRegistry registry;
+          registry.insert<triton::TritonDialect, ttg::TritonGPUDialect,
+                          ttng::TritonNvidiaGPUDialect, gluon::GluonDialect>();
+          MLIRContext ctx(MLIRContext::Threading::DISABLED);
+          ctx.appendDialectRegistry(registry);
+          ctx.loadAllAvailableDialects();
+
+          auto ll = ttg::chooseScaledMfmaScaleLayout(
+              &ctx, opIdx, shape, mfmaMDim, tilesPerWarp, warpsPerCTA);
+          auto attr = ttg::LinearEncodingAttr::get(&ctx, ll);
+          return layoutToGluon(attr);
+        });
+
+  m.def("get_amd_wmma_scale_layout",
+        [](unsigned opIdx, std::vector<int64_t> &shape, unsigned wmmaMDim,
+           std::vector<unsigned> &tilesPerWarp,
+           std::vector<unsigned> &warpsPerCTA) -> py::object {
+          DialectRegistry registry;
+          registry.insert<triton::TritonDialect, ttg::TritonGPUDialect,
+                          ttng::TritonNvidiaGPUDialect, gluon::GluonDialect>();
+          MLIRContext ctx(MLIRContext::Threading::DISABLED);
+          ctx.appendDialectRegistry(registry);
+          ctx.loadAllAvailableDialects();
+
+          auto ll = ttg::chooseScaledWmmaScaleLayout(
+              &ctx, opIdx, shape, wmmaMDim, tilesPerWarp, warpsPerCTA);
+          auto attr = ttg::LinearEncodingAttr::get(&ctx, ll);
+          return layoutToGluon(attr);
+        });
+
   py::class_<ttg::WarpSpecializeOp, OpState>(m, "WarpSpecializeOp",
                                              py::module_local())
       .def("get_default_region", &ttg::WarpSpecializeOp::getDefaultRegion,

python/triton/experimental/gluon/language/amd/_ops.py

Lines changed: 6 additions & 4 deletions

@@ -1,3 +1,5 @@
+import math
+
 from triton import knobs
 from triton.experimental.gluon.language import _core as ttgl
 from triton.experimental.gluon.language._semantic import _check
@@ -57,13 +59,13 @@ def _create_and_broadcast_default_scale(op_idx, scale, format):
     operand = a if op_idx == 0 else b
 
     scale_shape = _get_scale_shape(op_idx, operand, format)
-    scale_layout = scale_fn(operand.type.layout, scale_shape, semantic)
-
     if isinstance(scale, ttgl.tensor) and scale.numel.value != 1:
-        assert scale.type.shape == scale_shape, \
-            f"Expect scale tensor to have shape {scale_shape}, but got {scale.type.shape}"
+        # In the case of scale pre-shuffling, the input shape is different from the default shape. We only check
+        # the number of elements here.
+        assert math.prod(scale_shape) == scale.numel.value, "Incompatible scale shape"
         return scale
 
+    scale_layout = scale_fn(operand.type.layout, scale_shape)
     scale_value = _unwrap_if_constexpr(scale)
     scale_value = 0x7F if scale_value is None else scale_value
     return semantic.full(scale_shape, scale_value, ttgl.uint8, scale_layout)
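The shape assertion is relaxed because pre-shuffled scale tensors preserve the element count but not the default 2-D shape. A small illustration of the invariant the new check enforces (both shapes below are hypothetical):

    import math

    default_scale_shape = [128, 4]  # hypothetical default scale shape
    preshuffled_shape = [1, 512]    # hypothetical pre-shuffled shape of the same data

    # The old check compared shapes and would reject the pre-shuffled tensor;
    # the new check only requires matching element counts.
    assert math.prod(preshuffled_shape) == math.prod(default_scale_shape)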

python/triton/experimental/gluon/language/amd/cdna4/__init__.py

Lines changed: 14 additions & 18 deletions

@@ -1,4 +1,7 @@
-from ..._core import builtin, _unwrap_if_constexpr
+from triton.runtime.jit import constexpr_function
+from triton._C.libtriton.gluon_ir import get_amd_mfma_scale_layout as _get_mfma_scale_layout
+
+from ..._core import builtin
 from ..._layouts import DotOperandLayout
 from .._layouts import AMDMFMALayout
 from .._ops import _mma_scaled
@@ -10,19 +13,6 @@
 __all__ = [*__cdna3_all, "async_copy", "mfma_scaled", "get_mfma_scale_layout"]
 
 
-def _get_mfma_scale_layout(dot_operand_layout, shape, semantic):
-    dot_operand_layout = _unwrap_if_constexpr(dot_operand_layout)
-    shape = _unwrap_if_constexpr(shape)
-
-    op_idx = dot_operand_layout.operand_index
-    parent = dot_operand_layout.parent
-    assert isinstance(parent, AMDMFMALayout), "Expected parent to be an instance of AMDMFMALayout"
-    mdim = parent.instr_shape[0]
-    tiles_per_warp = parent.tiles_per_warp
-    warps_per_cta = parent.warps_per_cta
-    return semantic.builder.get_amd_mfma_scale_layout(op_idx, shape, mdim, tiles_per_warp, warps_per_cta)
-
-
 @builtin
 def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None):
     """
@@ -56,11 +46,11 @@ def mfma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
     assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
     assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
 
-    return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _get_mfma_scale_layout, _semantic)
+    return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, get_mfma_scale_layout, _semantic)
 
 
-@builtin
-def get_mfma_scale_layout(dot_operand_layout, shape, _semantic=None):
+@constexpr_function
+def get_mfma_scale_layout(dot_operand_layout, shape):
     """ Get the scale layout for MFMA scaled operands.
 
     Args:
@@ -70,7 +60,13 @@ def get_mfma_scale_layout(dot_operand_layout, shape, _semantic=None):
     Return:
         layout (DistributedLinearLayout): The scale layout.
     """
-    return _get_mfma_scale_layout(dot_operand_layout, shape, _semantic)
+    op_idx = dot_operand_layout.operand_index
+    parent = dot_operand_layout.parent
+    assert isinstance(parent, AMDMFMALayout), "Expected parent to be an instance of AMDMFMALayout"
+    mdim = parent.instr_shape[0]
+    tiles_per_warp = parent.tiles_per_warp
+    warps_per_cta = parent.warps_per_cta
+    return _get_mfma_scale_layout(op_idx, shape, mdim, tiles_per_warp, warps_per_cta)
 
 
 """

python/triton/experimental/gluon/language/amd/gfx1250/__init__.py

Lines changed: 14 additions & 18 deletions

@@ -1,4 +1,7 @@
-from ..._core import builtin, _unwrap_if_constexpr
+from triton.runtime.jit import constexpr_function
+from triton._C.libtriton.gluon_ir import get_amd_wmma_scale_layout as _get_wmma_scale_layout
+
+from ..._core import builtin
 from .._ops import _wmma, _verify_wmma, _mma_scaled
 from .._layouts import AMDWMMALayout
 from ..cdna3 import buffer_load, buffer_store
@@ -8,19 +11,6 @@
 __all__ = ["async_copy", "tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]
 
 
-def _get_wmma_scale_layout(dot_operand_layout, shape, semantic):
-    dot_operand_layout = _unwrap_if_constexpr(dot_operand_layout)
-    shape = _unwrap_if_constexpr(shape)
-
-    op_idx = dot_operand_layout.operand_index
-    parent = dot_operand_layout.parent
-    assert isinstance(parent, AMDWMMALayout), "Expected parent to be an instance of AMDMFMALayout"
-    mdim = parent.instr_shape[0]
-    tiles_per_warp = parent.tiles_per_warp
-    warps_per_cta = parent.warps_per_cta
-    return semantic.builder.get_amd_wmma_scale_layout(op_idx, shape, mdim, tiles_per_warp, warps_per_cta)
-
-
 @builtin
 def wmma(a, b, acc, _semantic=None):
     """
@@ -73,11 +63,11 @@ def wmma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _semantic=None)
     assert a_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported lhs_format: {a_format.value}"
     assert b_format.value in {"e2m1", "e4m3", "e5m2"}, f"Unsupported rhs_format: {b_format.value}"
 
-    return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, _get_wmma_scale_layout, _semantic)
+    return _mma_scaled(a, a_scale, a_format, b, b_scale, b_format, acc, get_wmma_scale_layout, _semantic)
 
 
-@builtin
-def get_wmma_scale_layout(dot_operand_layout, shape, _semantic=None):
+@constexpr_function
+def get_wmma_scale_layout(dot_operand_layout, shape):
     """ Get the scale layout for WMMA scaled operands.
 
     Args:
@@ -87,4 +77,10 @@ def get_wmma_scale_layout(dot_operand_layout, shape, _semantic=None):
     Return:
         layout (DistributedLinearLayout): The scale layout.
     """
-    return _get_wmma_scale_layout(dot_operand_layout, shape, _semantic)
+    op_idx = dot_operand_layout.operand_index
+    parent = dot_operand_layout.parent
+    assert isinstance(parent, AMDWMMALayout), "Expected parent to be an instance of AMDWMMALayout"
+    mdim = parent.instr_shape[0]
+    tiles_per_warp = parent.tiles_per_warp
+    warps_per_cta = parent.warps_per_cta
+    return _get_wmma_scale_layout(op_idx, shape, mdim, tiles_per_warp, warps_per_cta)
