Skip to content

Commit 8e87ed6

Browse files
authored
[AMD][GLUON] Expose AMDWMMALayout (triton-lang#8090)
This PR exposes the AMD WMMA layout for RDNA-series GPUs. Updates test_lowerings.py to include the AMD WMMA layout and also changes it to cover more common layout cases. Will open a separate PR for the WMMA op.
1 parent fca399f commit 8e87ed6

File tree

5 files changed

+224
-57
lines changed

5 files changed

+224
-57
lines changed

python/src/gluon_ir.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ struct GluonLayouts {
9797
py::handle NVMMASharedLayout;
9898
py::handle SwizzledSharedLayout;
9999
py::handle AMDMFMALayout;
100+
py::handle AMDWMMALayout;
100101
py::handle PaddedSharedLayout;
101102
py::handle GluonDType;
102103

@@ -117,6 +118,7 @@ struct GluonLayouts {
117118
SwizzledSharedLayout =
118119
py::object(layouts.attr("SwizzledSharedLayout")).release();
119120
AMDMFMALayout = py::object(amdLayouts.attr("AMDMFMALayout")).release();
121+
AMDWMMALayout = py::object(amdLayouts.attr("AMDWMMALayout")).release();
120122
PaddedSharedLayout =
121123
py::object(layouts.attr("PaddedSharedLayout")).release();
122124

@@ -226,6 +228,14 @@ py::object layoutToGluon(Attribute layout) {
226228
toStdVector(ctaLayout.getCTAsPerCGA()),
227229
toStdVector(ctaLayout.getCTASplitNum()),
228230
toStdVector(ctaLayout.getCTAOrder()));
231+
} else if (auto amdWmma = dyn_cast<ttg::AMDWmmaEncodingAttr>(layout)) {
232+
auto ctaLayout = amdWmma.getCTALayout();
233+
return layouts.AMDWMMALayout(amdWmma.getVersion(),
234+
amdWmma.getIsTransposed(),
235+
toStdVector(amdWmma.getWarpsPerCTA()),
236+
toStdVector(ctaLayout.getCTAsPerCGA()),
237+
toStdVector(ctaLayout.getCTASplitNum()),
238+
toStdVector(ctaLayout.getCTAOrder()));
229239
} else if (auto paddedShared =
230240
dyn_cast<ttg::PaddedSharedEncodingAttr>(layout)) {
231241
auto *ctx = paddedShared.getContext();
@@ -357,6 +367,18 @@ void init_gluon_ir(py::module &&m) {
357367
ctx, version, warpsPerCta, tilesPerWarp, instrShape[0],
358368
instrShape[1], transposed, ctaLayout, elemType);
359369
})
370+
.def("get_amd_wmma_layout",
371+
[](GluonOpBuilder &self, unsigned version, bool transposed,
372+
std::vector<unsigned> &warpsPerCta,
373+
std::vector<unsigned> &ctasPerCga,
374+
std::vector<unsigned> &ctaSplitNum,
375+
std::vector<unsigned> &ctaOrder) -> Attribute {
376+
auto ctx = self.getContext();
377+
auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
378+
ctx, ctasPerCga, ctaSplitNum, ctaOrder);
379+
return ttg::AMDWmmaEncodingAttr::get(ctx, version, transposed,
380+
warpsPerCta, ctaLayout);
381+
})
360382
.def("get_padded_shared_layout",
361383
[](GluonOpBuilder &self, std::vector<unsigned> &intervals,
362384
std::vector<unsigned> &paddings,

python/test/gluon/test_frontend.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@
2626
BLACKWELL_TARGET = GPUTarget("cuda", 100, 32)
2727
HOPPER_TARGET = GPUTarget("cuda", 90, 32)
2828
AMPERE_TARGET = GPUTarget("cuda", 80, 32)
29-
HIP_TARGET = GPUTarget("hip", "gfx1200", 32)
29+
HIP_TARGET_RDNA4 = GPUTarget("hip", "gfx1200", 32)
3030
HIP_TARGET_CDNA3 = GPUTarget("hip", "gfx942", 64)
3131
HIP_TARGET_CDNA4 = GPUTarget("hip", "gfx950", 64)
3232

33-
ALL_TARGETS = [AMPERE_TARGET, HOPPER_TARGET, BLACKWELL_TARGET, HIP_TARGET]
33+
ALL_TARGETS = [AMPERE_TARGET, HOPPER_TARGET, BLACKWELL_TARGET, HIP_TARGET_RDNA4]
3434

3535

3636
def anonymize_ir(ir):
@@ -1705,6 +1705,79 @@ def test_infer_layout_for_amd_mfma(target):
17051705
""")
17061706

17071707

1708+
@gluon.jit
1709+
def amd_wmma_layout_kernel():
1710+
ttgl.full([64, 64], 0, ttgl.float16, layout=amd_layouts.AMDWMMALayout(version=2, transposed=True,
1711+
warps_per_cta=[1, 4]))
1712+
ttgl.full([64, 64], 0, ttgl.float16, layout=amd_layouts.AMDWMMALayout(version=2, transposed=True,
1713+
warps_per_cta=[2, 2]))
1714+
ttgl.full([64, 64], 0, ttgl.float16, layout=amd_layouts.AMDWMMALayout(version=2, transposed=False,
1715+
warps_per_cta=[1, 4]))
1716+
ttgl.full([64, 64], 0, ttgl.float16, layout=amd_layouts.AMDWMMALayout(version=2, transposed=False,
1717+
warps_per_cta=[2, 2]))
1718+
1719+
1720+
@pytest.mark.parametrize("target", [HIP_TARGET_RDNA4])
1721+
def test_amd_wmma_layout(target):
1722+
module = run_parser(amd_wmma_layout_kernel, target=target)
1723+
expecttest.assert_expected_inline(
1724+
anonymize_ir(module.str_nodebug()), """\
1725+
#mma = #ttg.amd_wmma<{version = 2, isTranspose = true, warpsPerCTA = [1, 4]}>
1726+
#mma1 = #ttg.amd_wmma<{version = 2, isTranspose = true, warpsPerCTA = [2, 2]}>
1727+
#mma2 = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [1, 4]}>
1728+
#mma3 = #ttg.amd_wmma<{version = 2, isTranspose = false, warpsPerCTA = [2, 2]}>
1729+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
1730+
tt.func public @amd_wmma_layout_kernel() attributes {noinline = false} {
1731+
%cst = arith.constant 0.000000e+00 : f16
1732+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #mma>
1733+
%cst_1 = arith.constant 0.000000e+00 : f16
1734+
%cst_2 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #mma1>
1735+
%cst_3 = arith.constant 0.000000e+00 : f16
1736+
%cst_4 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #mma2>
1737+
%cst_5 = arith.constant 0.000000e+00 : f16
1738+
%cst_6 = arith.constant dense<0.000000e+00> : tensor<64x64xf16, #mma3>
1739+
tt.return
1740+
}
1741+
}
1742+
""")
1743+
1744+
1745+
@gluon.jit
1746+
def infer_layout_for_amd_wmma_kernel():
1747+
layout: ttgl.constexpr = amd_layouts.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[4, 1])
1748+
a = ttgl.full([128, 32], 1, ttgl.float16, layout)
1749+
b = ttgl.reduce(a, 1, add_int)
1750+
ttgl.static_assert(b.type.layout == ttgl.SliceLayout(1, layout))
1751+
1752+
1753+
@pytest.mark.parametrize("target", [HIP_TARGET_RDNA4])
1754+
def test_infer_layout_for_amd_wmma(target):
1755+
module = run_parser(infer_layout_for_amd_wmma_kernel, target=target)
1756+
expecttest.assert_expected_inline(
1757+
anonymize_ir(module.str_nodebug()), """\
1758+
#mma = #ttg.amd_wmma<{version = 2, isTranspose = true, warpsPerCTA = [4, 1]}>
1759+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
1760+
tt.func public @infer_layout_for_amd_wmma_kernel() attributes {noinline = false} {
1761+
%cst = arith.constant 1.000000e+00 : f16
1762+
%cst_0 = arith.constant dense<1.000000e+00> : tensor<128x32xf16, #mma>
1763+
%0 = "tt.reduce"(%cst_0) <{axis = 1 : i32}> ({
1764+
^bb0(%arg0: f16, %arg1: f16):
1765+
%1 = tt.call @test_frontend.add_int__fp16_fp16__(%arg0, %arg1) : (f16, f16) -> f16
1766+
tt.reduce.return %1 : f16
1767+
}) : (tensor<128x32xf16, #mma>) -> tensor<128xf16, #ttg.slice<{dim = 1, parent = #mma}>>
1768+
tt.return
1769+
}
1770+
tt.func private @test_frontend.add_int__fp16_fp16__(%arg0: f16, %arg1: f16) -> f16 attributes {noinline = false} {
1771+
%0 = arith.addf %arg0, %arg1 : f16
1772+
tt.return %0 : f16
1773+
^bb1: // no predecessors
1774+
%1 = ub.poison : f16
1775+
tt.return %1 : f16
1776+
}
1777+
}
1778+
""")
1779+
1780+
17081781
@gluon.jit
17091782
def amd_async_wait():
17101783
cdna4_async_copy.async_wait(0)

python/test/gluon/test_lowerings.py

Lines changed: 47 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,16 @@ def _reduce_layouts():
119119
cta_order=[0, 1], instr_shape=[16, 8]),
120120
ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1],
121121
cta_order=[1, 0], instr_shape=[16, 16, 16]),
122-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
123-
transposed=False),
124-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
125-
transposed=False),
126-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
122+
ttgl.amd.AMDMFMALayout(version=1, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
127123
transposed=True),
128124
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
129125
transposed=True),
130-
# TODO: AMDWMMA layouts
126+
ttgl.amd.AMDMFMALayout(version=3, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
127+
transposed=True),
128+
ttgl.amd.AMDMFMALayout(version=4, warps_per_cta=[1, 4], tiles_per_warp=[1, 1], instr_shape=[32, 32],
129+
transposed=True),
130+
ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[1, 4]),
131+
ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[1, 4]),
131132
ttgl.DotOperandLayout(
132133
parent=ttgl.NVMMADistributedLayout(version=[2, 0], warps_per_cta=[2, 4], ctas_per_cga=[1, 1],
133134
cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]),
@@ -515,72 +516,68 @@ def kernel(x_ptr, y_ptr, M: ttgl.constexpr, N: ttgl.constexpr, src_layout: ttgl.
515516
ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1],
516517
cta_order=[0, 1], instr_shape=[16, 64, 16]),
517518
],
518-
# AMD MFMA layouts
519+
# AMD MFMA v1 layouts
519520
[
520-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
521-
transposed=False),
522-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
523-
transposed=False),
521+
ttgl.amd.AMDMFMALayout(version=1, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
522+
transposed=True),
523+
ttgl.amd.AMDMFMALayout(version=1, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
524+
transposed=True),
524525
],
525526
[
526-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
527-
transposed=False),
528-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
529-
transposed=False),
527+
ttgl.amd.AMDMFMALayout(version=1, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
528+
transposed=True),
529+
ttgl.amd.AMDMFMALayout(version=1, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
530+
transposed=True),
530531
],
532+
# AMD MFMA v2 layouts
531533
[
532534
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
533-
transposed=False),
534-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
535535
transposed=True),
536-
],
537-
[
538536
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
539-
transposed=False),
540-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
541537
transposed=True),
542538
],
543539
[
544540
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
545-
transposed=False),
541+
transposed=True),
546542
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
547-
transposed=False),
543+
transposed=True),
548544
],
545+
# AMD MFMA v3 layouts
549546
[
550-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
551-
transposed=False),
552-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
553-
transposed=False),
547+
ttgl.amd.AMDMFMALayout(version=3, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
548+
transposed=True),
549+
ttgl.amd.AMDMFMALayout(version=3, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
550+
transposed=True),
554551
],
555552
[
556-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
557-
transposed=False),
558-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
553+
ttgl.amd.AMDMFMALayout(version=3, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
554+
transposed=True),
555+
ttgl.amd.AMDMFMALayout(version=3, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
559556
transposed=True),
560557
],
558+
# AMD MFMA v4 layouts
561559
[
562-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
563-
transposed=False),
564-
ttgl.amd.AMDMFMALayout(version=2, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
560+
ttgl.amd.AMDMFMALayout(version=4, warps_per_cta=[2, 2], tiles_per_warp=[1, 1], instr_shape=[32, 32],
565561
transposed=True),
562+
ttgl.amd.AMDMFMALayout(version=4, warps_per_cta=[4, 1], tiles_per_warp=[1, 1], instr_shape=[32, 32],
563+
transposed=True),
564+
],
565+
[
566+
ttgl.amd.AMDMFMALayout(version=4, warps_per_cta=[4, 4], tiles_per_warp=[1, 1], instr_shape=[16, 16],
567+
transposed=True),
568+
ttgl.amd.AMDMFMALayout(version=4, warps_per_cta=[16, 1], tiles_per_warp=[1, 1], instr_shape=[16, 16],
569+
transposed=True),
570+
],
571+
# AMD WMMA v1 layouts
572+
[
573+
ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[4, 4]),
574+
ttgl.amd.AMDWMMALayout(version=1, transposed=True, warps_per_cta=[16, 1]),
575+
],
576+
# AMD WMMA v2 layouts
577+
[
578+
ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[4, 4]),
579+
ttgl.amd.AMDWMMALayout(version=2, transposed=True, warps_per_cta=[16, 1]),
566580
],
567-
# TODO: AMD WMMA layouts
568-
#[
569-
# WmmaLayout(1, [4, 4]),
570-
# WmmaLayout(1, [16, 1]),
571-
#],
572-
#[
573-
# WmmaLayout(1, [16, 1]),
574-
# WmmaLayout(1, [4, 4]),
575-
#],
576-
#[
577-
# WmmaLayout(2, [4, 4]),
578-
# WmmaLayout(2, [16, 1]),
579-
#],
580-
#[
581-
# WmmaLayout(2, [16, 1]),
582-
# WmmaLayout(2, [4, 4]),
583-
#],
584581
]
585582

586583

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ._layouts import AMDMFMALayout
1+
from ._layouts import AMDMFMALayout, AMDWMMALayout
22
from . import cdna3, cdna4
33

4-
__all__ = ["AMDMFMALayout", "cdna3", "cdna4"]
4+
__all__ = ["AMDMFMALayout", "AMDWMMALayout", "cdna3", "cdna4"]

python/triton/experimental/gluon/language/amd/_layouts.py

Lines changed: 78 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
__all__ = [
1111
"AMDMFMALayout",
12+
"AMDWMMALayout",
1213
]
1314

1415

@@ -18,15 +19,22 @@ class AMDMFMALayout(DistributedLayout):
1819
Represents a layout for AMD MFMA (matrix core) operations.
1920
2021
Args:
21-
version (int): Major and minor identifier for the MFMA instruction.
22-
instr_shape: (M, N) dimension for the instrinsic shape.
23-
transposed (bool): indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
22+
version (int): Indicates the GPU architecture.
23+
instr_shape: (M, N) Dimension for the intrinsic shape.
24+
transposed (bool): Indicates the result tensor is transposed so that each thread holds consecutive elements in the same row instead of column, which is good for chained dot and global write.
2425
warps_per_cta (List[int]): Number of warps per CTA.
2526
elem_type Optional(ttgl.dtype): Supported types are int32, fp32 and fp64. Default is fp32.
2627
tiles_per_warp Optional(List[int]): Number of tiles per WARP. For mfma layout, if missing, use the default where we have unit tile size on all dimensions.
2728
ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
2829
cta_split_num (Optional[List[int]]): Split factors for CTAs.
2930
cta_order (Optional[List[int]]): CTA ordering.
31+
32+
Current supported versions:
33+
34+
- 1: gfx908
35+
- 2: gfx90a
36+
- 3: gfx942
37+
- 4: gfx950
3038
"""
3139
version: int
3240
instr_shape: List[int]
@@ -94,3 +102,70 @@ def __hash__(self):
94102
tuple(self.cta_split_num) if self.cta_split_num else None,
95103
tuple(self.cta_order) if self.cta_order else None,
96104
))
105+
106+
107+
@dataclass(frozen=True)
108+
class AMDWMMALayout(DistributedLayout):
109+
"""
110+
Represents a layout for AMD WMMA (matrix core) operations.
111+
112+
Args:
113+
version (int): Indicates the GPU architecture.
114+
transposed (bool): Indicates the result tensor is transposed.
115+
warps_per_cta (List[int]): Number of warps per CTA.
116+
ctas_per_cga (Optional[List[int]]): CTAs per CGA grouping.
117+
cta_split_num (Optional[List[int]]): Split factors for CTAs.
118+
cta_order (Optional[List[int]]): CTA ordering.
119+
120+
Current supported versions:
121+
122+
- 1: RDNA3; e.g., gfx1100, gfx1101
123+
- 2: RDNA4; e.g., gfx1200, gfx1201
124+
"""
125+
version: int
126+
transposed: bool
127+
warps_per_cta: List[int]
128+
ctas_per_cga: Optional[List[int]] = None
129+
cta_split_num: Optional[List[int]] = None
130+
cta_order: Optional[List[int]] = None
131+
132+
def __post_init__(self):
133+
super().__setattr__("version", _unwrap_if_constexpr(self.version))
134+
super().__setattr__("transposed", _unwrap_if_constexpr(self.transposed))
135+
super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
136+
super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
137+
super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
138+
super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
139+
self.verify()
140+
141+
def _to_ir(self, builder):
142+
return builder.get_amd_wmma_layout(self.version, self.transposed, self.warps_per_cta, self.ctas_per_cga,
143+
self.cta_split_num, self.cta_order)
144+
145+
def mangle(self) -> str:
146+
147+
def stringify(x):
148+
if x is None:
149+
return ""
150+
return "_".join(map(str, x))
151+
152+
return f"WMMA_{self.version}_{self.transposed}_{stringify(self.warps_per_cta)}_{stringify(self.ctas_per_cga)}_{stringify(self.cta_split_num)}_{stringify(self.cta_order)}_WMMA"
153+
154+
def verify(self):
155+
assert self.version >= 1 and self.version <= 2, "version must be in the [1, 2] range"
156+
157+
rank = len(self.warps_per_cta)
158+
_realize_cta_layout(self, rank)
159+
assert len(self.ctas_per_cga) == rank
160+
assert len(self.cta_split_num) == rank
161+
assert len(self.cta_order) == rank
162+
163+
def __hash__(self):
164+
return hash((
165+
self.version,
166+
self.transposed,
167+
tuple(self.warps_per_cta),
168+
tuple(self.ctas_per_cga) if self.ctas_per_cga else None,
169+
tuple(self.cta_split_num) if self.cta_split_num else None,
170+
tuple(self.cta_order) if self.cta_order else None,
171+
))

0 commit comments

Comments
 (0)