Commit 5c05106

[BE] TCGen5MMAScaledOp accepts scales in shared memory (#6019)
Enable `TCGen5MMAScaledOp` to accept scales in shared memory, if efficient lowering from shmem to tmem is possible. This removes the use of `TMEMCopyOp` from most of the compiler stack, it is left only for the purposes of lowering `TCGen5MMAScaledOp`. It simplifies IR analysis a little, and improves semantics of `TCGen5MMAScaledOp`, as no assumptions about HW pipelining of tmem copy -> mma are needed. --------- Co-authored-by: Thomas Raoux <[email protected]>
1 parent e38a482 commit 5c05106
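
At the TTGIR level, the change looks roughly like the sketch below. It is condensed from the lit tests updated in this commit; the SSA names (%scale_vals, %a, %b, %acc, %a_scale_smem, %b_scale_smem, %a_scale_tmem, %b_scale_tmem) are illustrative and the trailing operand-type lists are elided.

  // Before: blocked scales were staged into TMEM with an explicit copy ahead of the MMA.
  %a_scale_smem = ttg.local_alloc %scale_vals : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
  %a_scale_tmem = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
  ttng.tmem_copy %a_scale_smem, %a_scale_tmem, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>) -> ()
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %a_scale_tmem, %b_scale_tmem, %true, %true lhs = e5m2 rhs = e5m2 : (...) -> ()

  // After: the scaled MMA op takes the shared-memory descriptor directly; the
  // tmem_alloc/tmem_copy pair is materialized later, during MMA lowering.
  %a_scale_smem = ttg.local_alloc %scale_vals : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %a_scale_smem, %b_scale_smem, %true, %true lhs = e5m2 rhs = e5m2 : (...) -> ()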

File tree

12 files changed, +152 -74 lines


include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td

Lines changed: 2 additions & 2 deletions
@@ -332,7 +332,7 @@ def TTNG_TCGen5MMAOp : TTNG_Op<"tc_gen5_mma", [DeclareOpInterfaceMethods<MemoryE
 
   let description = [{
     $d += matrix_multiply($a, $b).
-    If not barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
     If there is a barrier the result will be safe to read after a barrier wait.
     If $two_ctas is set the op will execute a matmul across two contiguous CTAs, it will read the data distributed across the two CTAs.
     and syncronize both CTAs if the op is synchronous.
@@ -355,7 +355,7 @@ def TTNG_TCGen5MMAScaledOp : TTNG_Op<"tc_gen5_mma_scaled", [DeclareOpInterfaceMe
 
   let description = [{
     $d += matrix_multiply(scale($lhs, $lhs_scale), scale(rlhs, $rhs_scale))
-    If not barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
+    If no barrier is given the op is assumed to be synchronous otherwise the op will trigger a commit/arrive on the given barrier.
     If there is a barrier the result will be safe to read after a barrier wait.
   }];

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 2 deletions
@@ -580,8 +580,8 @@ Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
       loadConsumer = cvt;
     } else {
       // Unrecognized pattern, bail out. In practice, this implies that MMA
-      // pipelining will not apply to the scaled dot op, since tmem_copy would
-      // not be inserted before the pipeline pass.
+      // pipelining will not apply to the scaled dot op, since scales will not
+      // be passed through SMEM to tc_gen5_mma_scaled.
       return scale;
     }
   }

lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp

Lines changed: 33 additions & 26 deletions
@@ -141,32 +141,53 @@ class FuseTransMMAV3Plus : public OpRewritePattern<LocalAllocOp> {
 
 // Inject TMEM copy instructions into IR to efficiently load blocked scales for
 // scaled dot
-class InjectTMemCopy
-    : public OpRewritePattern<triton::nvidia_gpu::TMEMAllocOp> {
+class UseShmemForScales
+    : public OpRewritePattern<triton::nvidia_gpu::TCGen5MMAScaledOp> {
 public:
-  using OpRewritePattern<triton::nvidia_gpu::TMEMAllocOp>::OpRewritePattern;
+  using OpRewritePattern<
+      triton::nvidia_gpu::TCGen5MMAScaledOp>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp tmemAlloc,
+  LogicalResult matchAndRewrite(triton::nvidia_gpu::TCGen5MMAScaledOp mmaOp,
                                 PatternRewriter &rewriter) const override {
-    auto dstType = tmemAlloc.getResult().getType();
+    auto aScale = mmaOp.getAScale();
+    auto bScale = mmaOp.getBScale();
+    LogicalResult ret = failure();
+    if (aScale && isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
+                      aScale.getType().getEncoding())) {
+      if (rewriteOperand(mmaOp.getAScaleMutable(), rewriter).succeeded())
+        ret = success();
+    }
+    if (bScale && isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
+                      bScale.getType().getEncoding())) {
+      if (rewriteOperand(mmaOp.getBScaleMutable(), rewriter).succeeded())
+        ret = success();
+    }
+    return ret;
+  }
 
-    // Only applies to TMEMAlloc with scales encoding
-    if (!isa<triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(
-            dstType.getEncoding())) {
+private:
+  LogicalResult rewriteOperand(OpOperand &opOperand,
+                               PatternRewriter &rewriter) const {
+    auto src = cast<TypedValue<MemDescType>>(opOperand.get());
+    auto tmemAlloc = src.getDefiningOp<triton::nvidia_gpu::TMEMAllocOp>();
+    if (!tmemAlloc) {
      return failure();
    }
+    auto dstType = tmemAlloc.getResult().getType();
 
    if (!tmemAlloc.getSrc()) {
      return failure();
    }
 
    // Look for a sequence
    // local_load
-    // -> reshape(..., (BLOCK_MN / 128, BLOCK_K / scale_vec_size / 4, 32, 4, 4)
+    // -> reshape(..., (BLOCK_MN / 128, BLOCK_K / scale_vec_size / 4, 32, 4,
+    // 4)
    // -> transpose(..., (0, 3, 2, 1, 4))
    // -> reshape(..., (BLOCK_MN, BLOCK_K / scale_vec_size)
    // -> tmem_alloc
-    // and replace it with tmem_alloc -> tmem_copy
+    // -> tc_gen_mma_scaled
+    // and replace it with local_alloc -> tc_gen5_mma_scaled
    auto scale2DShape = dstType.getShape();
    auto blockMN = scale2DShape[0];
    auto numScales = scale2DShape[1];
@@ -195,24 +216,10 @@ class InjectTMemCopy
    if (!localLoad || !isTmemCopyCompatible(localLoad.getSrc().getType())) {
      return failure();
    }
-    MemDescType newType = MemDescType::get(
-        dstType.getShape(), dstType.getElementType(), dstType.getEncoding(),
-        dstType.getMemorySpace(), /*mutableMemory=*/true);
-    Value newTmemAlloc = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
-        tmemAlloc.getLoc(), newType, Value());
-
-    // Since tcgen05.cp followed by tcgen05.mma is guaranteed to execute in that
-    // order, we do not need to wait for the completion of the copy before MMA.
-    rewriter.create<triton::nvidia_gpu::TMEMCopyOp>(
-        newTmemAlloc.getLoc(), localLoad.getSrc(), newTmemAlloc,
-        Value() /* barrier */);
-
-    rewriter.replaceOp(tmemAlloc, newTmemAlloc);
-
+    opOperand.assign(localLoad.getSrc());
    return success();
  }
 
-private:
  template <typename Op> Op getNextOp(Value op) const {
    while (auto cvtOp = op.getDefiningOp<ConvertLayoutOp>()) {
      op = cvtOp.getSrc();
@@ -285,7 +292,7 @@ class TritonGPUOptimizeDotOperandsPass
    mlir::RewritePatternSet patterns(context);
    patterns.add<SwizzleShmemConvert>(context);
    patterns.add<FuseTransMMAV3Plus>(context);
-    patterns.add<InjectTMemCopy>(context);
+    patterns.add<UseShmemForScales>(context);
    ConvertLayoutOp::getCanonicalizationPatterns(patterns, context);
    if (failed(applyPatternsGreedily(m, std::move(patterns))))
      signalPassFailure();
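
In IR terms, the renamed pattern matches a scale operand of tc_gen5_mma_scaled whose defining tmem_alloc is fed by the load/reshape/transpose/reshape chain (possibly through convert_layout ops) rooted at a shared-memory allocation, and rebinds the operand to that allocation, leaving the old chain to be cleaned up as dead code. A condensed sketch, adapted from the scales_in_shmem test below; value names are illustrative and most types and layouts are elided:

  %la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
  %ll = ttg.local_load %la : ...
  %r1 = tt.reshape %ll : ...                               // (BLOCK_MN / 128, BLOCK_K / scale_vec_size / 4, 32, 4, 4)
  %tr = tt.trans %r1 {order = array<i32: 0, 3, 2, 1, 4>} : ...
  %r2 = tt.reshape %tr : ...                               // (BLOCK_MN, BLOCK_K / scale_vec_size)
  %tm = ttng.tmem_alloc %r2 : ...
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %tm, %b_scale, ...  // before: scale taken from TMEM
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %la, %b_scale, ...  // after: scale operand rebound to shared memory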

lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp

Lines changed: 1 addition & 1 deletion
@@ -560,7 +560,7 @@ bool loadRequiresAdditionalBuffer(Operation *loadOp) {
   ttg::LocalAllocOp alloc =
       dyn_cast<ttg::LocalAllocOp>(*loadOp->getUsers().begin());
   if (alloc && alloc->hasOneUse()) {
-    if (isa<ttng::TMEMCopyOp>(*alloc->getUsers().begin())) {
+    if (isa<ttng::TCGen5MMAScaledOp>(*alloc->getUsers().begin())) {
      return true;
    }
  }

lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp

Lines changed: 3 additions & 7 deletions
@@ -571,18 +571,14 @@ void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
 }
 
 bool isSafeToPipeline(ttng::TCGen5MMAScaledOp scaledDot) {
-  auto getNumUsers = [](Value value) {
-    return std::distance(value.user_begin(), value.user_end());
-  };
-
   auto isCopiedByTMEMCopy = [=](Value scale) {
-    if (getNumUsers(scale) != 2) {
-      // MMA and TMEM copy must be the only users
+    if (!scale.hasOneUse()) {
+      // Should be used only by the scaled dot op
      return false;
    }
 
    for (auto user : scale.getUsers()) {
-      if (!isa<ttng::TMEMCopyOp, ttng::TCGen5MMAScaledOp>(user)) {
+      if (!isa<ttng::TCGen5MMAScaledOp>(user)) {
        // If the scale is used by TMEM copy and the only other user is the
        // scaled dot op, MMA pipelining is safe to apply.
        return false;

lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp

Lines changed: 47 additions & 3 deletions
@@ -29,7 +29,7 @@ class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {
       return failure();
     MLIRContext *ctx = op.getContext();
     Location loc = op.getLoc();
-    Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx);
+    Attribute sharedMemorySpace = SharedMemorySpaceAttr::get(ctx);
     auto barrierCTALayout = CTALayoutAttr::get(
         /*context=*/ctx, /*CTAsPerCGA=*/{1},
         /*CTASplitNum=*/{1}, /*CTAOrder=*/{0});
@@ -51,6 +51,50 @@ class SyncMMALowering : public OpRewritePattern<TCGen5MMAOpTy> {
   }
 };
 
+struct TCGen5MMAScaleSharedToTmemConversion
+    : public OpRewritePattern<TCGen5MMAScaledOp> {
+  using OpRewritePattern<TCGen5MMAScaledOp>::OpRewritePattern;
+
+  bool lowerScaleToTmem(OpOperand &operand, PatternRewriter &rewriter) const {
+    Location loc = operand.getOwner()->getLoc();
+    MLIRContext *context = operand.getOwner()->getContext();
+    Attribute tensorMemorySpace = TensorMemorySpaceAttr::get(context);
+    auto oldType = cast<MemDescType>(operand.get().getType());
+    Type elType = oldType.getElementType();
+    SwizzledSharedEncodingAttr oldEncoding =
+        cast<SwizzledSharedEncodingAttr>(oldType.getEncoding());
+    CTALayoutAttr CTALayout = getCTALayout(oldEncoding);
+    ArrayRef<unsigned> CTASplitNum = CTALayout.getCTASplitNum();
+    ArrayRef<int64_t> shape = oldType.getAllocShape();
+    Attribute scaleEncoding = TensorMemoryScalesEncodingAttr::get(
+        context, CTASplitNum[0], CTASplitNum[1]);
+    Type scaleAType =
+        MemDescType::get(shape, elType, scaleEncoding, tensorMemorySpace,
+                         /*mutableMemory=*/true);
+    auto tmemAlloc = rewriter.create<TMEMAllocOp>(loc, scaleAType, Value());
+    rewriter.create<TMEMCopyOp>(loc, operand.get(), tmemAlloc,
+                                /*barrier*/ Value());
+    operand.set(tmemAlloc);
+    return true;
+  }
+
+  LogicalResult matchAndRewrite(TCGen5MMAScaledOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    MLIRContext *context = op->getContext();
+    auto aScaleType = op.getAScale().getType();
+    auto bScaleType = op.getBScale().getType();
+    bool anyChanged = false;
+    if (isa<SwizzledSharedEncodingAttr>(aScaleType.getEncoding())) {
+      anyChanged = lowerScaleToTmem(op.getAScaleMutable(), rewriter);
+    }
+    if (isa<SwizzledSharedEncodingAttr>(bScaleType.getEncoding())) {
+      anyChanged = lowerScaleToTmem(op.getBScaleMutable(), rewriter);
+    }
+    return LogicalResult::success(anyChanged);
+  }
+};
+
 class TritonNvidiaGPUMMALoweringPass
     : public TritonNvidiaGPUMMALoweringPassBase<
           TritonNvidiaGPUMMALoweringPass> {
@@ -61,8 +105,8 @@ class TritonNvidiaGPUMMALoweringPass
 
     mlir::RewritePatternSet patterns(context);
     patterns
-        .add<SyncMMALowering<TCGen5MMAOp>, SyncMMALowering<TCGen5MMAScaledOp>>(
-            context);
+        .add<SyncMMALowering<TCGen5MMAOp>, SyncMMALowering<TCGen5MMAScaledOp>,
+             TCGen5MMAScaleSharedToTmemConversion>(context);
    if (applyPatternsGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
  }
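
The new TCGen5MMAScaleSharedToTmemConversion pattern is where shared-memory scale operands are finally turned back into TMEM operands: for each scale with a swizzled shared encoding it allocates a mutable TMEM buffer with the scales encoding, emits a barrier-less tmem_copy into it, and redirects the MMA operand. A rough sketch of the resulting IR, with illustrative names and elided type lists:

  // Input to MMA lowering: scale operands still live in shared memory.
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %a_scale_smem, %b_scale_smem, %true, %true lhs = e5m2 rhs = e5m2 : (...) -> ()

  // After TCGen5MMAScaleSharedToTmemConversion (shown for one operand; both are converted in practice):
  %a_scale_tmem = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
  ttng.tmem_copy %a_scale_smem, %a_scale_tmem, : (...) -> ()
  ttng.tc_gen5_mma_scaled %a, %b, %acc, %a_scale_tmem, %b_scale_smem, %true, %true lhs = e5m2 rhs = e5m2 : (...) -> ()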

python/test/unit/language/test_matmul.py

Lines changed: 2 additions & 2 deletions
@@ -484,6 +484,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
                b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
     ttgir = out.asm["ttgir"]
+    ptx = out.asm["ptx"]
 
     def flatten_scale(scale):
         num_chunk_m, num_chunk_k, _, _, _ = scale.shape
@@ -508,8 +509,7 @@ def flatten_scale(scale):
     if USE_2D_SCALE_LOAD:
         # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
         # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
-        assert "tmem_copy" in ttgir
-
+        assert "tcgen05.cp" in ptx
     if NUM_STAGES > 1:
         if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
             load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2

test/TritonGPU/dot-operands.mlir

Lines changed: 28 additions & 11 deletions
@@ -95,27 +95,44 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num-
 
 // -----
 
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
 #shared1 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #smem = #ttg.shared_memory
 #blocked4 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked8 = #ttg.blocked<{sizePerThread = [1, 1, 1, 2, 4], threadsPerWarp = [1, 1, 16, 2, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 3, 2, 1, 0]}>
 #blocked9 = #ttg.blocked<{sizePerThread = [1, 2, 1, 1, 4], threadsPerWarp = [1, 2, 16, 1, 1], warpsPerCTA = [2, 1, 2, 1, 1], order = [4, 1, 2, 3, 0]}>
 #blocked10 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 32, 1, 1], warpsPerCTA = [1, 1, 1, 1, 4], order = [4, 3, 2, 1, 0]}>
 #blocked11 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
 #tmem_scales = #ttng.tensor_memory_scales_encoding<>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
-  // CHECK-LABEL: @inject_tmem_copy
-  // CHECK: ttng.tmem_alloc {{.*}}, mutable
-  // CHECK: ttng.tmem_copy
+  // CHECK-LABEL: @scales_in_shmem
+  // CHECK: %[[A_LA:.*]] = ttg.local_alloc
+  // CHECK: %[[B_LA:.*]] = ttg.local_alloc
+  // CHECK: ttng.tc_gen5_mma_scaled {{.*}}, %[[A_LA]], %[[B_LA]],
 
-  tt.func public @inject_tmem_copy(%scale: tensor<2x512x!tt.ptr<i8>, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32}) attributes {noinline = false} {
-    %75 = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
-    %180 = ttg.local_load %75 : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
-    %183 = tt.reshape %180 : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
-    %184 = tt.trans %183 {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
-    %187 = ttg.convert_layout %184 : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
-    %188 = tt.reshape %187 : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
-    %190 = ttng.tmem_alloc %188 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+  tt.func public @scales_in_shmem(
+    %scale: tensor<2x512x!tt.ptr<i8>, #blocked4> {tt.contiguity = 16 : i32, tt.divisibility = 16 : i32},
+    %A_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+    %B_sh: !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>,
+    %acc_tm: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>
+  ) attributes {noinline = false} {
+    %true = arith.constant true
+    %A_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
+    %B_la = ttg.local_alloc : () -> !ttg.memdesc<2x512xi8, #shared1, #smem, mutable>
+    %A_ll = ttg.local_load %A_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
+    %B_ll = ttg.local_load %B_la : !ttg.memdesc<2x512xi8, #shared1, #smem, mutable, 3x2x512> -> tensor<2x512xi8, #blocked4>
+    %A_r = tt.reshape %A_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
+    %B_r = tt.reshape %B_ll : tensor<2x512xi8, #blocked4> -> tensor<2x1x32x4x4xi8, #blocked8>
+    %A_tr = tt.trans %A_r {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
+    %B_tr = tt.trans %B_r {order = array<i32: 0, 3, 2, 1, 4>} : tensor<2x1x32x4x4xi8, #blocked8> -> tensor<2x4x32x1x4xi8, #blocked9>
+    %A_cv = ttg.convert_layout %A_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
+    %B_cv = ttg.convert_layout %B_tr : tensor<2x4x32x1x4xi8, #blocked9> -> tensor<2x4x32x1x4xi8, #blocked10>
+    %A_r2 = tt.reshape %A_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
+    %B_r2 = tt.reshape %B_cv : tensor<2x4x32x1x4xi8, #blocked10> -> tensor<256x4xi8, #blocked11>
+    %A_tm = ttng.tmem_alloc %A_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+    %B_tm = ttng.tmem_alloc %B_r2 : (tensor<256x4xi8, #blocked11>) -> !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>
+    ttng.tc_gen5_mma_scaled %A_sh, %B_sh, %acc_tm, %A_tm, %B_tm, %true, %true lhs = e5m2 rhs = e5m2 {loop.cluster = 0 : i32, loop.stage = 2 : i32} : (!ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, !ttg.memdesc<256x4xi8, #tmem_scales, #ttng.tensor_memory>, i1, i1) -> ()
    tt.return
  }
 }

test/TritonGPU/loop-pipeline-blackwell.mlir

Lines changed: 1 addition & 6 deletions
@@ -384,16 +384,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     %122 = tt.load %arg19 : tensor<1x2x32x4x4x!tt.ptr<i8>, #blocked2>
 
     %137 = ttg.local_alloc %121 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
-    %130 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
-    ttng.tmem_copy %137, %130, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>) -> ()
-
     %139 = ttg.local_alloc %122 : (tensor<1x2x32x4x4xi8, #blocked2>) -> !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>
-    %131 = ttng.tmem_alloc : () -> !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>
-    ttng.tmem_copy %139, %131, : (!ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<128x8xi8, #tmem_scales, #ttng.tensor_memory, mutable>) -> ()
 
     %127 = ttng.tmem_alloc %arg15 : (tensor<128x128xf32, #blocked4>) -> !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>
 
-    ttng.tc_gen5_mma_scaled %118, %120, %127, %130, %131, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory, mutable>, !ttg.memdesc<128x8xi8, #ttng.tensor_memory_scales_encoding<>, #ttng.tensor_memory, mutable>, i1, i1) -> ()
+    ttng.tc_gen5_mma_scaled %118, %120, %127, %137, %139, %true, %true lhs = e5m2 rhs = e5m2 : (!ttg.memdesc<128x256xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<256x128xf8E5M2, #shared, #ttg.shared_memory>, !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, !ttg.memdesc<1x2x32x4x4xi8, #shared1, #smem>, i1, i1) -> ()
     %132 = ttng.tmem_load %127 : !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable> -> tensor<128x128xf32, #blocked4>
 
     %133 = tt.addptr %arg16, %incr_A : tensor<128x256x!tt.ptr<f8E5M2>, #blocked>, tensor<128x256xi32, #blocked>
