
Commit e21efcb

[GLUON] Adding Hopper WGMMA support (#7300)
Support for async WGMMA is coming in a separate PR.
1 parent 8791ac1 · commit e21efcb

File tree: 8 files changed (+199, -14 lines)

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 25 additions & 2 deletions
@@ -59,10 +59,33 @@ LogicalResult WarpGroupDotOp::inferReturnTypes(
 }
 
 LogicalResult WarpGroupDotOp::verify() {
-  auto nvmmaEnc =
-      dyn_cast<NvidiaMmaEncodingAttr>(getD().getType().getEncoding());
+  auto resTy = getD().getType();
+  auto nvmmaEnc = dyn_cast<NvidiaMmaEncodingAttr>(resTy.getEncoding());
   if (!nvmmaEnc || !nvmmaEnc.isHopper())
     return emitOpError("WGMMA result layout must be Hopper NVMMA");
+  auto numWarps = gpu::lookupNumWarps(getOperation());
+  if (numWarps % 4)
+    return emitOpError("WGMMA requires num_warps to be divisible by 4");
+  auto retShapePerCTA = getShapePerCTA(resTy);
+  int rank = retShapePerCTA.size();
+  if (rank != 2)
+    return emitOpError("WGMMA result shape must be 2D");
+  if (retShapePerCTA[0] % 64 != 0)
+    return emitOpError("WGMMA result M dimension must be divisible by 64");
+  if (retShapePerCTA[1] % 8 != 0)
+    return emitOpError("WGMMA result N dimension must be divisible by 8");
+  auto aElemTy = getA().getType().getElementType();
+  if (!(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
+        aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
+        aElemTy.isF32()))
+    return emitOpError("WGMMA result element type must be F16, BF16, F32, "
+                       "F8E5M2, F8E4M3FN, or integer type");
+  if (getMaxNumImpreciseAcc() < 32 &&
+      (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
+      resTy.getElementType().isF32()) {
+    return emitOpError("Cannot use F32 as the accumulator element type when "
+                       "the max_num_imprecise_acc is less than 32");
+  }
   return success();
 }
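For orientation, the constraints the new verifier enforces can be summarized on the Python side as follows. This is an illustrative sketch, not code from this commit, and the helper name wgmma_constraints_ok is hypothetical.

# Hypothetical helper mirroring the WarpGroupDotOp verifier above (sketch only).
def wgmma_constraints_ok(m, n, num_warps, a_elem, acc_elem, max_num_imprecise_acc=0):
    fp8 = {"float8e5m2", "float8e4m3fn"}
    allowed_a = fp8 | {"int8", "float16", "bfloat16", "float32"}
    return (num_warps % 4 == 0                      # full warp groups only
            and m % 64 == 0 and n % 8 == 0          # WGMMA tile granularity
            and a_elem in allowed_a
            # fp8 operands with an F32 accumulator need max_num_imprecise_acc >= 32
            and not (a_elem in fp8 and acc_elem == "float32"
                     and max_num_imprecise_acc < 32))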

lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp

Lines changed: 4 additions & 1 deletion
@@ -127,7 +127,10 @@ SmallVector<int64_t> getTMABlockShape(ArrayRef<int64_t> shapePerCTA,
   // Last dim must equal the swizzle byte size
   if (swizzleBytes != 0) {
     auto contigDimSize = (8 * swizzleBytes) / elementBitWidth;
-    assert(blockShape[contigDim] >= contigDimSize);
+    if (blockShape[contigDim] < contigDimSize) {
+      llvm::reportFatalUsageError("Block shape is too small for the swizzle "
+                                  "byte size in NVMMA Shared Layout.");
+    }
     blockShape[contigDim] = contigDimSize;
   }
   if (fp4Padded && packedSize) {
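The quantity being checked is the minimum contiguous block size implied by the swizzle pattern, contigDimSize = (8 * swizzleBytes) / elementBitWidth. A small sketch of the arithmetic (illustrative only, not code from this commit):

# Mirrors the contigDimSize computation above (illustrative sketch).
def min_contig_dim(swizzle_bytes: int, element_bitwidth: int) -> int:
    return (8 * swizzle_bytes) // element_bitwidth

assert min_contig_dim(64, 16) == 32    # swizzle_byte_width=64, fp16 (test_core.py below)
assert min_contig_dim(128, 16) == 64   # swizzle_byte_width=128, fp16 (test_frontend.py below)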

python/src/gluon_ir.cc

Lines changed: 33 additions & 0 deletions
@@ -86,6 +86,7 @@ struct GluonLayouts {
   py::handle BlockedLayout;
   py::handle SliceLayout;
   py::handle DistributedLinearLayout;
+  py::handle NVMMADistributedLayout;
   py::handle NVMMASharedLayout;
   py::handle SwizzledSharedLayout;
 
@@ -96,6 +97,8 @@ struct GluonLayouts {
     SliceLayout = py::object(layouts.attr("SliceLayout")).release();
     DistributedLinearLayout =
         py::object(layouts.attr("DistributedLinearLayout")).release();
+    NVMMADistributedLayout =
+        py::object(layouts.attr("NVMMADistributedLayout")).release();
     NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
     SwizzledSharedLayout =
         py::object(layouts.attr("SwizzledSharedLayout")).release();
@@ -131,6 +134,14 @@ py::object layoutToGluon(Attribute layout) {
         ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
         ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
         toStdVector(ArrayRef(llvm::to_vector(ll.getOutDimSizes()))));
+  } else if (auto mma = dyn_cast<ttg::NvidiaMmaEncodingAttr>(layout)) {
+    auto ctaLayout = mma.getCTALayout();
+    return layouts.NVMMADistributedLayout(
+        std::vector<unsigned>{mma.getVersionMajor(), mma.getVersionMinor()},
+        toStdVector(mma.getWarpsPerCTA()),
+        toStdVector(ctaLayout.getCTAsPerCGA()),
+        toStdVector(ctaLayout.getCTASplitNum()),
+        toStdVector(ctaLayout.getCTAOrder()), toStdVector(mma.getInstrShape()));
   } else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
     auto ctaLayout = nvmma.getCTALayout();
     return layouts.NVMMASharedLayout(
@@ -224,6 +235,20 @@ void init_gluon_ir(py::module &&m) {
                 /*requiresSurjective=*/true);
             return ttg::LinearEncodingAttr::get(ctx, ll);
           })
+      .def("get_mma_layout",
+           [](GluonOpBuilder &self, std::vector<unsigned> &version,
+              std::vector<unsigned> &warpsPerCta,
+              std::vector<unsigned> &ctasPerCga,
+              std::vector<unsigned> &ctaSplitNum,
+              std::vector<unsigned> &ctaOrder,
+              std::vector<unsigned> &instrShape) -> Attribute {
+             auto ctx = self.getContext();
+             auto ctaLayout = self.getChecked<ttg::CTALayoutAttr>(
+                 ctx, ctasPerCga, ctaSplitNum, ctaOrder);
+             return self.getChecked<ttg::NvidiaMmaEncodingAttr>(
+                 ctx, version[0], version[1], warpsPerCta, ctaLayout,
+                 instrShape);
+           })
       .def("get_nvmma_shared_layout",
           [](GluonOpBuilder &self, unsigned swizzleByteWidth,
              unsigned elementBitwidth, bool transposed, bool fp4Padded,
@@ -359,6 +384,14 @@ void init_gluon_ir(py::module &&m) {
             auto op = self.create<triton::SplitOp>(TypeRange{resTy, resTy}, a);
             return py::make_tuple(op->getResult(0), op->getResult(1));
           })
+      .def("create_warpgroup_mma",
+           [](GluonOpBuilder &self, Value a, Value b, Value acc, Value useAcc,
+              triton::InputPrecision precision = triton::InputPrecision::IEEE,
+              int maxNumImpreciseAcc = 0, bool isAsync = false) -> Value {
+             return self.create<ttng::WarpGroupDotOp>(
+                 a, b, acc, useAcc, precision, maxNumImpreciseAcc, isAsync);
+           })
+
       .def("create_tmem_alloc",
           [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
             return self.create<ttng::TMEMAllocOp>(resultTy, value);
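User code does not call get_mma_layout or create_warpgroup_mma directly; they are reached through NVMMADistributedLayout._to_ir and hopper.warpgroup_mma, added later in this commit. A sketch of the round trip (the MLIR attribute shown in the comment is taken from the filecheck test below):

from triton.experimental.gluon import language as ttgl

# Python-side layout object ...
mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                         instr_shape=[16, 32, 16])
# ... lowers via builder.get_mma_layout(...) to the MLIR attribute
#   #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
# and layoutToGluon maps that attribute back to an equivalent Python object.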

python/test/gluon/test_core.py

Lines changed: 43 additions & 1 deletion
@@ -1,11 +1,12 @@
 import torch
 import pytest
 
-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_hopper_or_newer, is_hopper
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia.ampere import async_copy, mbarrier
 from triton.experimental.gluon.language.nvidia.hopper import tma
+from triton.experimental.gluon.language.nvidia import hopper
 
 
 @gluon.jit
@@ -96,3 +97,44 @@ def test_async_copy_mbarrier():
     async_copy_mbarrier_kernel[(1, )](out, inp, inp.shape[0], XBLOCK=32, YBLOCK=32)
     torch.testing.assert_close(out[:20], inp)
     torch.testing.assert_close(out[20:], torch.zeros((12, 32), **tensor_opts))
+
+
+@gluon.jit
+def warpgroup_mma_kernel(a, b, out, M: ttgl.constexpr, N: ttgl.constexpr, K: ttgl.constexpr):
+    block_layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0])
+    mma_layout: ttgl.constexpr = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
+                                                             instr_shape=[16, 32, 16])
+    nvmma_layout: ttgl.constexpr = ttgl.NVMMASharedLayout(swizzle_byte_width=64, element_bitwidth=16, rank=2)
+
+    a_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+    a_offs_n = ttgl.arange(0, K, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+    b_offs_m = ttgl.arange(0, K, layout=ttgl.SliceLayout(1, block_layout))[:, None]
+    b_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, block_layout))[None, :]
+
+    out_offs_m = ttgl.arange(0, M, layout=ttgl.SliceLayout(1, mma_layout))[:, None]
+    out_offs_n = ttgl.arange(0, N, layout=ttgl.SliceLayout(0, mma_layout))[None, :]
+
+    acc = ttgl.zeros([M, N], dtype=a.dtype.element_ty, layout=mma_layout)
+    A = ttgl.load(a + a_offs_m * K + a_offs_n)
+    B = ttgl.load(b + b_offs_m * N + b_offs_n)
+
+    a_shmem = ttgl.allocate_shared_memory(ttgl.float16, [M, K], nvmma_layout, A)
+    b_shmem = ttgl.allocate_shared_memory(ttgl.float16, [K, N], nvmma_layout, B)
+
+    acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc)
+
+    ttgl.store(out + out_offs_m * N + out_offs_n, acc)
+
+
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper")
+def test_warpgroup_mma():
+    torch.manual_seed(0)
+    M, N, K = 64, 32, 32
+    a = torch.randn((M, K), device="cuda", dtype=torch.float16)
+    b = torch.randn((K, N), device="cuda", dtype=torch.float16)
+    out = torch.zeros((M, N), device="cuda", dtype=torch.float16)
+    warpgroup_mma_kernel[(1, )](a, b, out, M, N, K)
+
+    ref = torch.matmul(a, b)
+
+    torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-1)
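To exercise just the new test, a standard pytest invocation can be wrapped in Python as below (assumes a Hopper GPU and that it is run from the repository root):

import pytest

# Equivalent to: pytest python/test/gluon/test_core.py -k test_warpgroup_mma -v
pytest.main(["python/test/gluon/test_core.py", "-k", "test_warpgroup_mma", "-v"])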

python/test/gluon/test_frontend.py

Lines changed: 37 additions & 1 deletion
@@ -8,11 +8,12 @@
 from triton.experimental import gluon
 from triton.experimental.gluon import language as ttgl
 from triton.experimental.gluon.language.nvidia import blackwell
+from triton.experimental.gluon.language.nvidia import hopper
 from triton.experimental.gluon.language.nvidia.blackwell import mbarrier, tma, TensorMemoryLayout, async_copy
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton._filecheck import filecheck_test, run_parser
 import triton.language as tl
-from triton._internal_testing import is_cuda, is_ampere_or_newer, is_blackwell, is_hopper_or_newer
+from triton._internal_testing import is_cuda, is_ampere_or_newer, is_blackwell, is_hopper, is_hopper_or_newer
 from triton.compiler.errors import CompilationError, CompileTimeAssertionFailure
 
 TARGET_PAT = re.compile('ttg.target = "[^"]*"')
@@ -446,6 +447,41 @@ def test_tcgen05_mma(fresh_knobs):
     """)
 
 
+@gluon.jit
+def warpgroup_mma_kernel(nvmma_layout: ttgl.constexpr, acc_layout: ttgl.constexpr):
+    a = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    b = ttgl.allocate_shared_memory(ttgl.float16, [128, 128], nvmma_layout)
+    acc = ttgl.full([128, 128], 0, dtype=ttgl.float16, layout=acc_layout)
+    hopper.warpgroup_mma(a, b, acc)
+
+
+@pytest.mark.skipif(not is_hopper(), reason="Requires Hopper WGMMA")
+def test_warpgroup_mma(fresh_knobs):
+    knobs.compilation.disable_line_info = True
+
+    nvmma_layout = ttgl.NVMMASharedLayout(swizzle_byte_width=128, element_bitwidth=16, rank=2)
+    mma_layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1], instr_shape=[16, 32, 16])
+    h = warpgroup_mma_kernel.warmup(nvmma_layout, mma_layout, grid=(1, ))
+    expecttest.assert_expected_inline(
+        anonymize_ir(h.asm["source"]), """\
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 32, 16]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @warpgroup_mma_kernel() attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %1 = ttg.local_alloc : () -> !ttg.memdesc<128x128xf16, #shared, #smem, mutable> loc(#loc)
+    %cst = arith.constant 0.000000e+00 : f16 loc(#loc)
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x128xf16, #mma> loc(#loc)
+    %true = arith.constant true loc(#loc)
+    %2 = ttng.warp_group_dot %0, %1, %cst_0, %true {inputPrecision = 0 : i32} : !ttg.memdesc<128x128xf16, #shared, #smem, mutable> * !ttg.memdesc<128x128xf16, #shared, #smem, mutable> -> tensor<128x128xf16, #mma> loc(#loc)
+    tt.return loc(#loc)
+  } loc(#loc)
+} loc(#loc)
+#loc = loc(unknown)
+""")
+
+
 @gluon.jit
 def async_tma_kernel(input_desc, XBLOCK: ttgl.constexpr):
     smem = ttgl.allocate_shared_memory(ttgl.float16, [XBLOCK, XBLOCK], input_desc.layout)

python/triton/experimental/gluon/language/_layouts.py

Lines changed: 32 additions & 0 deletions
@@ -6,6 +6,7 @@
     "BlockedLayout",
     "SliceLayout",
     "DistributedLinearLayout",
+    "NVMMADistributedLayout",
    "NVMMASharedLayout",
     "SwizzledSharedLayout",
 ]
@@ -133,6 +134,37 @@ def mangle(self):
         return f"DLL{self.reg_bases}_{self.lane_bases}_{self.warp_bases}_{self.block_bases}_{self.shape}DLL"
 
 
+@dataclass(frozen=True)
+class NVMMADistributedLayout(DistributedLayout):
+    version: List[int]
+    warps_per_cta: List[int]
+    instr_shape: List[int]
+    ctas_per_cga: Optional[List[int]] = None
+    cta_split_num: Optional[List[int]] = None
+    cta_order: Optional[List[int]] = None
+
+    def __post_init__(self):
+        super().__setattr__("version", _unwrap_if_constexpr(self.version))
+        super().__setattr__("warps_per_cta", _unwrap_if_constexpr(self.warps_per_cta))
+        super().__setattr__("instr_shape", _unwrap_if_constexpr(self.instr_shape))
+        super().__setattr__("ctas_per_cga", _unwrap_if_constexpr(self.ctas_per_cga))
+        super().__setattr__("cta_split_num", _unwrap_if_constexpr(self.cta_split_num))
+        super().__setattr__("cta_order", _unwrap_if_constexpr(self.cta_order))
+
+        rank = 2
+        _realize_cta_layout(self, rank)
+        assert len(self.ctas_per_cga) == rank
+        assert len(self.cta_split_num) == rank
+        assert len(self.cta_order) == rank
+
+    def _to_ir(self, builder):
+        return builder.get_mma_layout(self.version, self.warps_per_cta, self.ctas_per_cga, self.cta_split_num,
+                                      self.cta_order, self.instr_shape)
+
+    def mangle(self) -> str:
+        return f"MMA_{self.version}_{self.warps_per_cta}_{self.instr_shape}_{self.ctas_per_cga}_{self.cta_split_num}_{self.cta_order}_MMA"
+
+
 class SharedLayout:
     pass
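Only version, warps_per_cta, and instr_shape are required; the CTA fields default to a trivial single-CTA layout through _realize_cta_layout. A small sketch (the exact realized defaults are an assumption based on how the other layouts fill their CTA fields):

from triton.experimental.gluon import language as ttgl

layout = ttgl.NVMMADistributedLayout(version=[3, 0], warps_per_cta=[4, 1],
                                     instr_shape=[16, 32, 16])
# Assumed defaults after __post_init__: ctas_per_cga == [1, 1], cta_split_num == [1, 1].
print(layout.mangle())  # e.g. "MMA_[3, 0]_[4, 1]_[16, 32, 16]_..._MMA", used in kernel name mangling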

python/triton/experimental/gluon/language/nvidia/hopper/__init__.py

Lines changed: 17 additions & 1 deletion
@@ -2,10 +2,26 @@
 from . import mbarrier, tma
 from ... import _core
 
-__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma"]
+__all__ = ["async_copy", "fence_async_shared", "mbarrier", "tma", "warpgroup_mma"]
 
 
 @_core.builtin
 def fence_async_shared(cluster=False, _semantic=None):
     cluster = _core._unwrap_if_constexpr(cluster)
     _semantic.builder.create_fence_async_shared(cluster)
+
+
+@_core.builtin
+def warpgroup_mma(a, b, acc, *, use_acc=True, precision=None, max_num_imprecise_acc=0, is_async=False, _semantic=None):
+    use_acc = _semantic.to_tensor(use_acc)
+
+    if precision is None:
+        precision = _semantic.builder.options.default_dot_input_precision
+
+    precision = _semantic._str_to_dot_input_precision(precision)
+    max_num_imprecise_acc = _core._unwrap_if_constexpr(max_num_imprecise_acc)
+    is_async = _core._unwrap_if_constexpr(is_async)
+
+    handle = _semantic.builder.create_warpgroup_mma(a.handle, b.handle, acc.handle, use_acc.handle, precision,
+                                                    max_num_imprecise_acc, is_async)
+    return _core.tensor(handle, acc.type)
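The keyword arguments mirror the WarpGroupDotOp operands. A hedged sketch of a call site inside a @gluon.jit kernel, with the defaults spelled out (shapes and layouts as in test_core.py above):

# Inside a @gluon.jit kernel, after a_shmem, b_shmem, and acc are set up as in
# test_core.py; the explicit keywords below simply restate the defaults.
acc = hopper.warpgroup_mma(a_shmem, b_shmem, acc,
                           use_acc=True,             # start from the values already in acc
                           precision=None,           # fall back to the default dot input precision
                           max_num_imprecise_acc=0,  # only relevant for fp8 operands
                           is_async=False)           # async WGMMA comes in a follow-up PR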

test/TritonGPU/loop-pipeline-hopper.mlir

Lines changed: 8 additions & 8 deletions
@@ -559,13 +559,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 #smem = #ttg.shared_memory
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: _kernel_matmul_dependency
-  tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>, %arg1: !tt.ptr<f8E4M3FNUZ> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} {
+  tt.func public @_kernel_matmul_dependency(%arg0: tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked>, %arg1: !tt.ptr<f8E4M3FN> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) attributes {noinline = false} {
     %cst = arith.constant dense<0> : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>
     %cst_0 = arith.constant 1.000000e+00 : f32
     %c8_i32 = arith.constant 8 : i32
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma>
     %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
-    %1 = tt.splat %arg1 : !tt.ptr<f8E4M3FNUZ> -> tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
+    %1 = tt.splat %arg1 : !tt.ptr<f8E4M3FN> -> tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
     %2:4 = scf.for %arg6 = %c8_i32 to %arg3 step %c8_i32 iter_args(%arg7 = %c8_i32, %arg8 = %c8_i32, %arg9 = %cst_1, %arg10 = %arg5) -> (i32, i32, tensor<128x128xf32, #mma>, tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked1}>>) : i32 {
       %3 = arith.addi %arg7, %c8_i32 : i32
       %4 = arith.cmpi eq, %3, %c8_i32 : i32
@@ -586,12 +586,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ
       %9 = arith.addi %8, %0 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>>
       %10 = tt.expand_dims %9 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1>
       %11 = tt.broadcast %10 : tensor<128x1xi32, #blocked1> -> tensor<128x128xi32, #blocked1>
-      %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>, tensor<128x128xi32, #blocked1>
-      %13 = tt.load %arg0 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked>
-      %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FNUZ, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem>
-      %15 = tt.load %12 : tensor<128x128x!tt.ptr<f8E4M3FNUZ>, #blocked1>
-      %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FNUZ, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem>
-      %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FNUZ, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FNUZ, #shared1, #smem> -> tensor<128x128xf32, #mma>
+      %12 = tt.addptr %1, %11 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>, tensor<128x128xi32, #blocked1>
+      %13 = tt.load %arg0 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked>
+      %14 = ttg.local_alloc %13 : (tensor<128x128xf8E4M3FN, #blocked>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem>
+      %15 = tt.load %12 : tensor<128x128x!tt.ptr<f8E4M3FN>, #blocked1>
+      %16 = ttg.local_alloc %15 : (tensor<128x128xf8E4M3FN, #blocked1>) -> !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem>
+      %17 = ttng.warp_group_dot %14, %16, %arg9 {inputPrecision = 0 : i32, maxNumImpreciseAcc = 1073741824 : i32} : !ttg.memdesc<128x128xf8E4M3FN, #shared, #smem> * !ttg.memdesc<128x128xf8E4M3FN, #shared1, #smem> -> tensor<128x128xf32, #mma>
       %18 = tt.splat %7 : f32 -> tensor<128x128xf32, #mma>
       %19 = arith.mulf %17, %18 : tensor<128x128xf32, #mma>
       %20 = scf.if %6 -> (tensor<128x128xf32, #mma>) {
