
Commit ff3832d

[AMD][Gluon] Update AMDMFMALayout and support mfma (#7820)
This PR mainly adds a new mfma API for CDNA3 and CDNA4; the remaining changes update AMDMFMALayout accordingly.
1 parent 69c74b2 commit ff3832d
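In short, the new surface is: AMDMFMALayout can now be built without tiles_per_warp (it defaults to one tile per dimension), elem_type moves ahead of it in the field order, and ttgl.amd.cdna3.mfma(a, b, acc) performs the MFMA dot on CDNA3/CDNA4 targets. A minimal sketch distilled from the tests added below; the kernel name and shapes are illustrative and not part of this commit:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def mfma_sketch_kernel():
    # tiles_per_warp / ctas_per_cga / cta_split_num / cta_order may now be omitted;
    # version selects CDNA3 (3) or CDNA4 (4).
    mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
                                                         warps_per_cta=[4, 1])
    # Operands carry a DotOperandLayout whose parent is the MFMA layout.
    a = ttgl.full([64, 32], 1.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout,
                                                                            k_width=8))
    b = ttgl.full([32, 64], 2.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout,
                                                                            k_width=8))
    acc = ttgl.zeros([64, 64], ttgl.float32, mfma_layout)
    acc = ttgl.amd.cdna3.mfma(a, b, acc)  # the new mfma API; available on CDNA3 and CDNA4 targets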

6 files changed (+184, -57 lines)

python/src/gluon_ir.cc

Lines changed: 5 additions & 6 deletions
@@ -217,8 +217,8 @@ py::object layoutToGluon(Attribute layout) {

     return layouts.AMDMFMALayout(
         amdMfma.getVersion(), instrShape, amdMfma.getIsTransposed(),
-        toStdVector(amdMfma.getWarpsPerCTA()),
-        toStdVector(amdMfma.getTilesPerWarp()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getWarpsPerCTA()), layouts.GluonDType(typeName),
+        toStdVector(amdMfma.getTilesPerWarp()),
         toStdVector(ctaLayout.getCTAsPerCGA()),
         toStdVector(ctaLayout.getCTASplitNum()),
         toStdVector(ctaLayout.getCTAOrder()));
@@ -325,13 +325,12 @@ void init_gluon_ir(py::module &&m) {
       })
      .def("get_amd_mfma_layout",
           [](GluonOpBuilder &self, unsigned version,
+             std::vector<unsigned> &instrShape, bool transposed,
+             std::vector<unsigned> &warpsPerCta, mlir::Type elemType,
              std::vector<unsigned> &tilesPerWarp,
-             std::vector<unsigned> &warpsPerCta,
              std::vector<unsigned> &ctasPerCga,
              std::vector<unsigned> &ctaSplitNum,
-             std::vector<unsigned> &ctaOrder,
-             std::vector<unsigned> &instrShape, bool transposed,
-             mlir::Type elemType) -> Attribute {
+             std::vector<unsigned> &ctaOrder) -> Attribute {
            auto ctx = self.getContext();
            auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
                ctx, ctasPerCga, ctaSplitNum, ctaOrder);
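The net effect of the second hunk is that the binding's parameters now follow the field order of the Python AMDMFMALayout dataclass, which _to_ir forwards verbatim (see the _layouts.py hunk below). A quick way to inspect that order, assuming ttgl.amd re-exports the dataclass directly as the tests in this commit do:

from dataclasses import fields
from triton.experimental.gluon import language as ttgl

# New binding argument order = AMDMFMALayout field order, forwarded as-is by _to_ir.
print([f.name for f in fields(ttgl.amd.AMDMFMALayout)])
# expected: ['version', 'instr_shape', 'transposed', 'warps_per_cta', 'elem_type',
#            'tiles_per_warp', 'ctas_per_cga', 'cta_split_num', 'cta_order']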

python/test/gluon/test_core.py

Lines changed: 64 additions & 1 deletion
@@ -1,7 +1,7 @@
 import torch
 import pytest

-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer, is_hopper
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hip_cdna3, is_hip_cdna4, is_hopper_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
@@ -143,3 +143,66 @@ def test_warpgroup_mma(ASYNC):
     ref = torch.matmul(a, b)

     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
+
+
+@pytest.mark.parametrize("M, N, K", [(32, 32, 16), (16, 16, 32)])
+@pytest.mark.parametrize("in_dtype", ['float16', 'bfloat16'])
+@pytest.mark.parametrize("num_warps", [4, 8])
+@pytest.mark.parametrize("cdna_version", [3, 4])
+def test_amd_mfma(M, N, K, in_dtype, num_warps, cdna_version):
+
+    @gluon.jit
+    def kernel(a_ptr, b_ptr, c_ptr, stride_am, stride_ak,  #
+               stride_bk, stride_bn,  #
+               stride_cm, stride_cn, BLOCK_SIZE_M: ttgl.constexpr, BLOCK_SIZE_N: ttgl.constexpr,
+               BLOCK_SIZE_K: ttgl.constexpr, blocked: ttgl.constexpr, mfma_layout: ttgl.constexpr):
+        dot_a_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout, k_width=8)
+        dot_b_layout: ttgl.constexpr = ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout, k_width=8)
+
+        offs_am = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_bn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+
+        offs_ak = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(0, blocked))
+        offs_bk = ttgl.arange(0, BLOCK_SIZE_K, layout=ttgl.SliceLayout(1, blocked))
+        offs_a = offs_am[:, None] * stride_am + offs_ak[None, :] * stride_ak
+        offs_b = offs_bk[:, None] * stride_bk + offs_bn[None, :] * stride_bn
+
+        a = ttgl.amd.cdna3.buffer_load(ptr=a_ptr, offsets=offs_a)
+        b = ttgl.amd.cdna3.buffer_load(ptr=b_ptr, offsets=offs_b)
+        a1 = ttgl.convert_layout(a, layout=dot_a_layout)
+        b1 = ttgl.convert_layout(b, layout=dot_b_layout)
+        acc = ttgl.zeros([BLOCK_SIZE_M, BLOCK_SIZE_N], ttgl.float32, mfma_layout)
+        c = ttgl.amd.cdna3.mfma(a1, b1, acc)
+        c = ttgl.convert_layout(c, layout=blocked)
+        c = c.to(a_ptr.dtype.element_ty)
+
+        offs_cm = ttgl.arange(0, BLOCK_SIZE_M, layout=ttgl.SliceLayout(1, blocked))
+        offs_cn = ttgl.arange(0, BLOCK_SIZE_N, layout=ttgl.SliceLayout(0, blocked))
+        offs_c = offs_cm[:, None] * stride_cm + offs_cn[None, :] * stride_cn
+        ttgl.amd.cdna3.buffer_store(stored_value=c, ptr=c_ptr, offsets=offs_c)
+
+    if not is_hip_cdna4() and not is_hip_cdna3():
+        pytest.skip("mfma requires target to be CDNA3 or CDNA4")
+
+    if is_hip_cdna3() and cdna_version != 3:
+        pytest.skip("On CDNA3 target, skip if mfma version is not 3")
+
+    if is_hip_cdna4() and cdna_version != 4:
+        pytest.skip("On CDNA4 target, skip if mfma version is not 4")
+
+    elem_type = torch.float16 if in_dtype == 'float16' else torch.bfloat16
+    a = torch.randn((M, K), device='cuda', dtype=elem_type) - 0.5
+    b = torch.randn((K, N), device='cuda', dtype=elem_type) - 0.5
+    c = torch.empty((M, N), device=a.device, dtype=elem_type)
+    nonkdim: ttgl.constexpr = 32
+    blocked: ttgl.constexpr = ttgl.BlockedLayout(size_per_thread=[4, 4], threads_per_warp=[4, 16],
+                                                 warps_per_cta=[num_warps, 1], order=[1, 0])
+    mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=cdna_version, instr_shape=[nonkdim, nonkdim],
+                                                         transposed=True, warps_per_cta=[num_warps, 1])
+
+    kernel[1, 1](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_SIZE_M=M,
+                 BLOCK_SIZE_N=N, BLOCK_SIZE_K=K, blocked=blocked, mfma_layout=mfma_layout, num_warps=num_warps)
+
+    ref = torch.matmul(a, b)
+    triton_output = c
+    torch.testing.assert_close(ref, triton_output)
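A side note on the new test's layout choices: threads_per_warp=[4, 16] covers exactly one 64-lane CDNA wavefront (the expected IR in test_frontend.py below likewise records "ttg.threads-per-warp" = 64), and warps_per_cta=[num_warps, 1] accounts for every launched warp. An illustrative sanity check of that bookkeeping, not part of the test itself:

# Illustrative only: the per-warp thread grid must cover a full 64-lane CDNA wavefront,
# and the warp grid must account for every launched warp.
threads_per_warp = [4, 16]
assert threads_per_warp[0] * threads_per_warp[1] == 64  # CDNA wavefront size

for num_warps in (4, 8):  # the values the test is parametrized over
    warps_per_cta = [num_warps, 1]
    assert warps_per_cta[0] * warps_per_cta[1] == num_warps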

python/test/gluon/test_frontend.py

Lines changed: 78 additions & 35 deletions
@@ -1413,31 +1413,27 @@ def test_atomic_cas():

 @gluon.jit
 def amd_mfma_layout_kernel():
-    mfma_layout_fp32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 ctas_per_cga=[1,
-                                                                               1], cta_split_num=[1,
-                                                                                                  1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32, layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32],
+                                                                           transposed=True, warps_per_cta=[4, 1]))

-    mfma_layout_fp64: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                 warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                 elem_type=ttgl.float64, ctas_per_cga=[1, 1],
-                                                                 cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], tiles_per_warp=[4, 1], transposed=True,
+                                               warps_per_cta=[4, 1]))

-    mfma_layout_int32: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True,
-                                                                  warps_per_cta=[4, 1], tiles_per_warp=[4, 1],
-                                                                  elem_type=ttgl.int32, ctas_per_cga=[1, 1],
-                                                                  cta_split_num=[1, 1], cta_order=[1, 0])
+    ttgl.full([128, 32], 0, ttgl.float32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True, warps_per_cta=[4, 1],
+                                               ctas_per_cga=[1, 1], tiles_per_warp=[1, 1], cta_split_num=[1, 1],
+                                               cta_order=[1, 0]))

-    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 64], [4, 1], [1, 0])
+    ttgl.full([128, 32], 0, ttgl.float64,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.float64, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1], cta_order=[1, 0]))

-    x_fp32 = ttgl.full([128, 32], 0, ttgl.float32, layout)
-    x_fp64 = ttgl.full([128, 32], 0, ttgl.float64, layout)
-    x_int32 = ttgl.full([128, 32], 0, ttgl.int32, layout)
-
-    ttgl.convert_layout(x_fp32, mfma_layout_fp32)
-    ttgl.convert_layout(x_fp64, mfma_layout_fp64)
-    ttgl.convert_layout(x_int32, mfma_layout_int32)
+    ttgl.full([128, 32], 0, ttgl.int32,
+              layout=amd_layouts.AMDMFMALayout(version=3, instr_shape=[16, 16], transposed=True, warps_per_cta=[4, 1],
+                                               elem_type=ttgl.int32, tiles_per_warp=[1, 1], ctas_per_cga=[1, 1],
+                                               cta_split_num=[1, 1]))


@@ -1446,21 +1442,22 @@ def test_amd_mfma_layout(target):
     module = run_parser(amd_mfma_layout_kernel, target=target)
     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}>
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
-#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
-#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma1 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true}>
+#mma2 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = f64}>
+#mma3 = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @amd_mfma_layout_kernel() attributes {noinline = false} {
     %cst = arith.constant 0.000000e+00 : f32
-    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #blocked>
-    %cst_1 = arith.constant 0.000000e+00 : f64
-    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #blocked>
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_1 = arith.constant 0.000000e+00 : f32
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma1>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x32xf32, #mma>
+    %cst_5 = arith.constant 0.000000e+00 : f64
+    %cst_6 = arith.constant dense<0.000000e+00> : tensor<128x32xf64, #mma2>
     %c0_i32 = arith.constant 0 : i32
-    %cst_3 = arith.constant dense<0> : tensor<128x32xi32, #blocked>
-    %0 = ttg.convert_layout %cst_0 : tensor<128x32xf32, #blocked> -> tensor<128x32xf32, #mma>
-    %1 = ttg.convert_layout %cst_2 : tensor<128x32xf64, #blocked> -> tensor<128x32xf64, #mma1>
-    %2 = ttg.convert_layout %cst_3 : tensor<128x32xi32, #blocked> -> tensor<128x32xi32, #mma2>
+    %cst_7 = arith.constant dense<0> : tensor<128x32xi32, #mma3>
     tt.return
   }
 }
@@ -1475,8 +1472,8 @@ def add_int(a, b):
 @gluon.jit
 def infer_layout_for_amd_mfma_kernel():
     layout: ttgl.constexpr = amd_layouts.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
-                                                       elem_type=ttgl.int32, warps_per_cta=[4,
-                                                                                            1], tiles_per_warp=[4, 1],
+                                                       warps_per_cta=[4,
+                                                                      1], elem_type=ttgl.int32, tiles_per_warp=[1, 1],
                                                        ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0])
     a = ttgl.full([128, 32], 1, ttgl.int32, layout)
     b = ttgl.reduce(a, 1, add_int)
@@ -1489,7 +1486,7 @@ def test_infer_layout_for_amd_mfma(target):

     expecttest.assert_expected_inline(
         anonymize_ir(module.str_nodebug()), """\
-#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], tilesPerWarp = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true, elementType = i32}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
   tt.func public @infer_layout_for_amd_mfma_kernel() attributes {noinline = false} {
     %c1_i32 = arith.constant 1 : i32
@@ -1719,3 +1716,49 @@ def test_buffer_load_store_with_broadcast(target):
 }
 }
 """)
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_CDNA3, HIP_TARGET_CDNA4])
+def test_amd_mfma(target):
+
+    @gluon.jit
+    def kernel():
+        mfma_layout: ttgl.constexpr = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True,
+                                                             warps_per_cta=[4, 1])
+
+        a = ttgl.full([64, 32], 1.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=0, parent=mfma_layout,
+                                                                                k_width=8))
+        b = ttgl.full([32, 64], 2.0, ttgl.float32, layout=ttgl.DotOperandLayout(operand_index=1, parent=mfma_layout,
+                                                                                k_width=8))
+
+        acc = ttgl.zeros([64, 64], ttgl.float32, mfma_layout)
+        acc = ttgl.amd.cdna3.mfma(a, b, acc)
+        ttgl.static_assert(isinstance(acc, ttgl.tensor))
+        ttgl.static_assert(acc.type.layout == mfma_layout)
+
+    module = run_parser(kernel, target=target)
+
+    expecttest.assert_expected_inline(
+        anonymize_ir(module.str_nodebug()), """\
+#mma = #ttg.amd_mfma<{version = 3, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @kernel() attributes {noinline = false} {
+    %cst = arith.constant 1.000000e+00 : f32
+    %cst_0 = arith.constant dense<1.000000e+00> : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>>
+    %cst_1 = arith.constant 2.000000e+00 : f32
+    %cst_2 = arith.constant dense<2.000000e+00> : tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>>
+    %0 = tt.call @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() : () -> tensor<64x64xf32, #mma>
+    %cst_3 = arith.constant 0.000000e+00 : f32
+    %1 = tt.dot %cst_0, %cst_2, %0 : tensor<64x32xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> * tensor<32x64xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>> -> tensor<64x64xf32, #mma>
+    tt.return
+  }
+  tt.func private @"triton.experimental.gluon.language._standard.zeros____(0, 0)cconstexpr_64__(0, 1)cconstexpr_64__(1,)cconstexpr_fp32__(2,)cconstexpr_AMDMFMALayout(version=3, instr_shape=(32 ,32), transposed=True, warps_per_cta=(4 ,1), elem_type=triton_d_language_d_float32, tiles_per_warp=_1, 1_, ctas_per_cga=_1, 1_, cta_split_num=_1, 1_, cta_order=_1, 0_)_"() -> tensor<64x64xf32, #mma> attributes {noinline = false} {
+    %cst = arith.constant 0.000000e+00 : f32
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
+    tt.return %cst_0 : tensor<64x64xf32, #mma>
+  ^bb1:  // no predecessors
+    %0 = ub.poison : tensor<64x64xf32, #mma>
+    tt.return %0 : tensor<64x64xf32, #mma>
+  }
+}
+""")

python/triton/experimental/gluon/language/amd/_layouts.py

Lines changed: 14 additions & 11 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import List
+from typing import List, Optional
 from triton.language.core import _unwrap_if_constexpr

 from triton.experimental.gluon.language._layouts import _realize_cta_layout, DistributedLayout
@@ -20,10 +20,10 @@ class AMDMFMALayout(DistributedLayout):
     Args:
         version (int): Major and minor identifier for the MFMA instruction.
         instr_shape: (M, N) dimension for the instrinsic shape.
-        transposed: indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
+        transposed (bool): indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
         warps_per_cta (List[int]): Number of warps per CTA.
-        tiles_per_warp: (List[int]): Number of tiles per WARP.
-        elem_type: fp32 or fp64
+        elem_type Optional(ttgl.dtype): Supported types are int32, fp32 and fp64. Default is fp32.
+        tiles_per_warp Optional(List[int]): Number of tiles per WARP. For mfma layout, if missing, use the default where we have unit tile size on all dimensions.
         ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
         cta_split_num (Optional[List[int]]): Split factors for CTAs.
         cta_order (Optional[List[int]]): CTA ordering.
@@ -32,11 +32,11 @@ class AMDMFMALayout(DistributedLayout):
     instr_shape: List[int]
     transposed: bool
     warps_per_cta: List[int]
-    tiles_per_warp: List[int]
     elem_type: ttgl.dtype = ttgl.float32
-    ctas_per_cga: List[int] | None = None
-    cta_split_num: List[int] | None = None
-    cta_order: List[int] | None = None
+    tiles_per_warp: Optional[List[int]] = None
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None

     def __post_init__(self):
         super().__setattr__("version", _unwrap_if_constexpr(self.version))
@@ -49,12 +49,15 @@ def __post_init__(self):
         super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
         super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))

+        if self.tiles_per_warp is None:
+            object.__setattr__(self, "tiles_per_warp", [1] * len(self.warps_per_cta))
+
         self.verify()

     def _to_ir(self, builder):
         type = self.elem_type.to_ir(builder)
-        return builder.get_amd_mfma_layout(self.version, self.tiles_per_warp, self.warps_per_cta, self.ctas_per_cga,
-                                           self.cta_split_num, self.cta_order, self.instr_shape, self.transposed, type)
+        return builder.get_amd_mfma_layout(self.version, self.instr_shape, self.transposed, self.warps_per_cta, type,
+                                           self.tiles_per_warp, self.ctas_per_cga, self.cta_split_num, self.cta_order)

     def mangle(self) -> str:

@@ -73,7 +76,7 @@ def verify(self):
         assert self.elem_type.is_fp32() or self.elem_type.is_fp64() \
             or self.elem_type.is_int32() , "element type must be float32, float64, or int32"

-        rank = len(self.cta_order)
+        rank = len(self.warps_per_cta)
         _realize_cta_layout(self, rank)
         assert len(self.ctas_per_cga) == rank
         assert len(self.cta_split_num) == rank
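The practical effect of the __post_init__ addition is that tiles_per_warp can simply be omitted. A small sketch of what the default amounts to, assuming the dataclass is constructed on the host exactly as in the tests above (the variable name is illustrative):

from triton.experimental.gluon import language as ttgl

# Omitting tiles_per_warp now fills it with one tile per dimension,
# i.e. [1] * len(warps_per_cta), before verify() runs.
layout = ttgl.amd.AMDMFMALayout(version=3, instr_shape=[32, 32], transposed=True, warps_per_cta=[4, 1])
assert layout.tiles_per_warp == [1, 1]
assert layout.elem_type == ttgl.float32  # unchanged default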
