
Commit 77a7b4c

masahi authored and binarybana committed
[Blackwell] Enable MMA pipelining for scaled dot when TMEM copy is used (triton-lang#5812)
This PR enables MMA pipelining for scaled dot. The main difficulty this PR overcomes is the dependency cycle between TMEM copy rewriting and SWP: currently, TMEM copy rewriting relies on SWP to put the loading of scales into SMEM, while to apply MMA pipelining during SWP, TMEM copy rewriting needs to have happened beforehand.

I propose to break the cycle by having the loading of scales go through `local_alloc` and `local_load` in `AccelerateMatmul`. This way, TMEM copy rewriting happens during [the first call to OptimizeDotOperands](https://github.com/triton-lang/triton/blob/1e0e51c4aeb3e1beea000da5d0e494f8b9ac40dd/third_party/nvidia/backend/compiler.py#L260), before SWP, and the local alloc and load added in `AccelerateMatmul` are eliminated during SWP. Adding a local alloc for scales there is a bit ad hoc, since scales do not need to be in SMEM, but other solutions, such as decoupling MMA pipelining from SWP, would be more difficult.

The other changes in this PR make SWP recognize the loading of scales when there is a TMEM copy between the scale load and the MMA.

@ThomasRaoux @pawelszczerbuk @csullivan @mbrookhart @binarybana

---------

Co-authored-by: Masahiro Masuda <[email protected]>
Co-authored-by: Jason Knight <[email protected]>
1 parent 5d2a1d2 commit 77a7b4c
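For illustration (not part of the commit), here is a minimal sketch of how one might confirm from the TTGIR dump whether TMEM copy rewriting and MMA pipelining took effect, assuming a compiled kernel handle such as the `out` object used in `test_blocked_scale_mxfp` below; the substrings checked mirror the assertions added in this PR's tests:

```python
# Hedged sketch: inspect a compiled Triton kernel's TTGIR dump to see whether
# TMEM copy rewriting and MMA pipelining were applied to a scaled dot.
# `compiled_kernel` is assumed to be the handle returned by a kernel launch
# (e.g. the `out` object in test_blocked_scale_mxfp below).
def report_scaled_dot_pipelining(compiled_kernel):
    ttgir = compiled_kernel.asm["ttgir"]               # textual TritonGPU IR
    has_tmem_copy = "tmem_copy" in ttgir               # scales staged via ttng.tmem_copy
    has_mma_pipelining = "ttng.wait_barrier" in ttgir  # barriers inserted by MMA pipelining
    return has_tmem_copy, has_mma_pipelining
```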

File tree: 8 files changed (+386, -36 lines)


lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 81 additions & 4 deletions

@@ -184,6 +184,25 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }
 
+static LocalAllocOp
+getSharedMemoryScale(Value arg, mlir::PatternRewriter &rewriter, Location loc) {
+  OpBuilder::InsertionGuard g(rewriter);
+  auto argType = cast<RankedTensorType>(arg.getType());
+  assert(argType.getEncoding() && "unexpected tensor type");
+  auto newOrder = getOrder(argType.getEncoding());
+
+  Attribute SharedMemorySpace =
+      SharedMemorySpaceAttr::get(argType.getContext());
+  auto CTALayout = getCTALayout(argType.getEncoding());
+  // No swizzling for scale for now
+  auto newLayout = SwizzledSharedEncodingAttr::get(argType.getContext(), 1, 1,
+                                                   1, newOrder, CTALayout);
+  auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
+                                  newLayout, SharedMemorySpace);
+  rewriter.setInsertionPointAfterValue(arg);
+  return rewriter.create<LocalAllocOp>(loc, newType, arg);
+}
+
 SmallVector<unsigned, 3>
 getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
                 int numWarps, const SmallVector<unsigned, 3> &instrShape) {

@@ -575,6 +594,60 @@ class BlockedToMMAv5 : public mlir::OpRewritePattern<DotOp> {
   }
 };
 
+Value addSmemStageToScaleLoad(Value scale, mlir::PatternRewriter &rewriter) {
+  /*
+    Rewrite load(scale) -> local_load(local_alloc(load(scale))).
+    This function does not add anything to the final IR when num_stages > 1,
+    but it makes it easy to apply TMEM copy rewriting later.
+
+    Since scales are stored in TMEM for MMAv5 scaled dot, loading of scales
+    does not need to be put into SMEM. But in practice, the software pipeliner
+    puts loading of scales into multi-buffered SMEM. At that point, the SMEM
+    allocation created here is eliminated.
+  */
+  OpBuilder::InsertionGuard g(rewriter);
+  auto op = scale.getDefiningOp();
+  Operation *loadConsumer = nullptr;
+
+  if (!op)
+    return scale;
+
+  while (!isa<LoadOp>(op)) {
+    if (auto reshape = dyn_cast<ReshapeOp>(op)) {
+      op = reshape.getSrc().getDefiningOp();
+      loadConsumer = reshape;
+    } else if (auto trans = dyn_cast<TransOp>(op)) {
+      op = trans.getSrc().getDefiningOp();
+      loadConsumer = trans;
+    } else if (auto cvt = dyn_cast<ConvertLayoutOp>(op)) {
+      op = cvt.getSrc().getDefiningOp();
+      loadConsumer = cvt;
+    } else {
+      // Unrecognized pattern, bail out. In practice, this implies that MMA
+      // pipelining will not apply to the scaled dot op, since tmem_copy would
+      // not be inserted before the pipeline pass.
+      return scale;
+    }
+  }
+
+  auto scaleAfterLoad = op->getResult(0);
+  auto scaleSmemAlloc =
+      getSharedMemoryScale(scaleAfterLoad, rewriter, op->getLoc());
+
+  rewriter.setInsertionPointAfterValue(scaleSmemAlloc);
+  auto localLoad = rewriter.create<LocalLoadOp>(
+      op->getLoc(), scaleAfterLoad.getType(), scaleSmemAlloc);
+
+  rewriter.replaceAllUsesExcept(scaleAfterLoad, localLoad.getResult(),
+                                scaleSmemAlloc);
+
+  if (loadConsumer) {
+    return scale;
+  } else {
+    return localLoad;
+  }
+}
+
 class ScaledBlockedToMMAv5
     : public mlir::OpRewritePattern<triton::DotScaledOp> {
   int computeCapability;

@@ -688,10 +761,14 @@ class ScaledBlockedToMMAv5
         oldScaleAType.getShape(), oldScaleAType.getElementType(), scaleALayout);
     RankedTensorType newScaleBType = RankedTensorType::get(
         oldScaleBType.getShape(), oldScaleBType.getElementType(), scaleBLayout);
-    Value newScaleA = rewriter.create<ConvertLayoutOp>(loc, newScaleAType,
-                                                       dotOp.getLhsScale());
-    Value newScaleB = rewriter.create<ConvertLayoutOp>(loc, newScaleBType,
-                                                       dotOp.getRhsScale());
+
+    auto lhsScale = addSmemStageToScaleLoad(dotOp.getLhsScale(), rewriter);
+    auto rhsScale = addSmemStageToScaleLoad(dotOp.getRhsScale(), rewriter);
+
+    Value newScaleA =
+        rewriter.create<ConvertLayoutOp>(loc, newScaleAType, lhsScale);
+    Value newScaleB =
+        rewriter.create<ConvertLayoutOp>(loc, newScaleBType, rhsScale);
     Value scaleA = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
         loc, scaleAType, newScaleA);
     Value scaleB = rewriter.create<triton::nvidia_gpu::TMEMAllocOp>(
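For illustration only (not part of the diff), the producer walk performed by `addSmemStageToScaleLoad` above can be sketched in Python with toy stand-ins for MLIR operations; `ToyOp` and the op-name strings are hypothetical:

```python
# Toy analogue of the while-loop in addSmemStageToScaleLoad: hop backwards
# through reshape/trans/convert_layout until the defining load is found,
# or bail out on an unrecognized producer (no MMA pipelining in that case).
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyOp:                       # hypothetical stand-in for an MLIR Operation
    name: str                      # e.g. "tt.load", "tt.reshape", ...
    src: Optional["ToyOp"] = None  # a single producer is enough for this sketch

PASS_THROUGH = {"tt.reshape", "tt.trans", "ttg.convert_layout"}

def find_scale_load(op: Optional[ToyOp]) -> Optional[ToyOp]:
    """Return the tt.load feeding a scale operand, if reachable."""
    while op is not None:
        if op.name == "tt.load":
            return op              # found: this value gets staged through SMEM
        if op.name in PASS_THROUGH:
            op = op.src            # keep walking the producer chain
        else:
            return None            # unrecognized pattern: bail out
    return None

# e.g. load -> trans -> reshape, as in the 5D-scale lit test added below
chain = ToyOp("tt.reshape", src=ToyOp("tt.trans", src=ToyOp("tt.load")))
assert find_scale_load(chain).name == "tt.load"
```

In the real pass, once the load is found, its result is staged through the `local_alloc`/`local_load` pair created by `getSharedMemoryScale`, as the C++ above shows.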

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 10 additions & 0 deletions

@@ -181,6 +181,16 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
         dfs(defOp, finalUser, distance);
       }
     }
+    if (auto tmemAlloc = dyn_cast<nvidia_gpu::TMEMAllocOp>(op)) {
+      if (!tmemAlloc.getSrc()) {
+        for (auto user : tmemAlloc.getResult().getUsers()) {
+          if (auto tmemCopy = dyn_cast<nvidia_gpu::TMEMCopyOp>(user)) {
+            dfs(tmemCopy.getSrc().getDefiningOp(), finalUser, distance);
+            break;
+          }
+        }
+      }
+    }
   };
 
   bool seenDot = false;

lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 13 additions & 4 deletions

@@ -177,15 +177,18 @@ static int createAsyncCopy(scf::ForOp forOp, tt::LoadOp loadOp, Value alloc,
   Operation *wait = builder.createWithStage<ttg::AsyncWaitOp>(
       loc, stageForFirstUse, clusterForFirstUse, commit->getResult(0), 0);
 
-  auto loadIsMMAv3Shared = loadToInfo[loadOp].isMMAv3Shared;
-
   // Extract part.
   SmallVector<Value> loadOffsets(allocTy.getRank(), zero);
   loadOffsets[0] = extractIdx;
   auto viewLoad = builder.createWithStage<ttg::MemDescSubviewOp>(
       loc, stageForFirstUse, clusterForFirstUse, subviewTy, alloc, loadOffsets);
-  if (loadIsMMAv3Shared) {
-    auto alloc = cast<ttg::LocalAllocOp>((*loadOp->getUsers().begin()));
+
+  if (loadToInfo[loadOp].isMMAv3Shared || loadToInfo[loadOp].isMMAv5Scale) {
+    auto user = *loadOp->getUsers().begin();
+    assert(isa<triton::gpu::LocalAllocOp>(user) &&
+           "Loading of MMAv3 operands and MMAv5 scale is expected to be "
+           "consumed by LocalAlloc.");
+    auto alloc = cast<ttg::LocalAllocOp>(user);
     tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult());
     alloc.erase();
   } else {

@@ -455,6 +458,12 @@ getTransitiveUserInBlock(Operation *baseOp, scf::ForOp &forOp) {
     for (Operation *user : op->getUsers())
       if (user->getBlock() == op->getBlock())
        dfs(user, baseOp, anyOp);
+    if (auto tmemCopy = dyn_cast<triton::nvidia_gpu::TMEMCopyOp>(op)) {
+      auto tmemAlloc =
+          tmemCopy.getDst()
+              .getDefiningOp<triton::nvidia_gpu::TMEMAllocOp>();
+      dfs(tmemAlloc, baseOp, anyOp);
+    }
   };
   // We are matching the behavior before refactoring:
   // For loops without num_stage attributes, we check for dot users.

lib/Dialect/TritonGPU/Transforms/Pipeliner/TC05MMAPipeline.cpp

Lines changed: 36 additions & 3 deletions

@@ -593,6 +593,31 @@ void createBarrierAndWaitOps(IRRewriter &builder, scf::ForOp forOp,
   annotateWithPipelineStage(builder, info.phase.getDefiningOp(), 0);
 }
 
+bool isSafeToPipeline(ttng::TCGen5MMAScaledOp scaledDot) {
+  auto getNumUsers = [](Value value) {
+    return std::distance(value.user_begin(), value.user_end());
+  };
+
+  auto isCopiedByTMEMCopy = [=](Value scale) {
+    if (getNumUsers(scale) != 2) {
+      // MMA and TMEM copy must be the only users
+      return false;
+    }
+
+    for (auto user : scale.getUsers()) {
+      if (!isa<ttng::TMEMCopyOp, ttng::TCGen5MMAScaledOp>(user)) {
+        // If the scale is used by TMEM copy and the only other user is the
+        // scaled dot op, MMA pipelining is safe to apply.
+        return false;
+      }
+    }
+    return true;
+  };
+
+  return isCopiedByTMEMCopy(scaledDot.getAScale()) &&
+         isCopiedByTMEMCopy(scaledDot.getBScale());
+}
+
 // Find MMAs eligible for pipelining and lower them by:
 // 1. Hoisting the accumulator allocation outside of the loop.
 // 2. Creating a barrier alloc and lowering the MMA to MMA + wait barrier.

@@ -603,9 +628,17 @@ FailureOr<scf::ForOp> preProcessLoopForTC05MMAPipelining(scf::ForOp forOp,
   SmallVector<Operation *> mmaOps;
   forOp.walk([&](Operation *op) {
     // Skip MMA nested in another forOp
-    if (isa<ttng::TCGen5MMAOp>(op) &&
-        op->getParentOfType<scf::ForOp>() == forOp) {
-      mmaOps.push_back(op);
+    if (op->getParentOfType<scf::ForOp>() == forOp) {
+      if (isa<ttng::TCGen5MMAOp>(op)) {
+        mmaOps.push_back(op);
+      } else if (auto scaledDot = dyn_cast<ttng::TCGen5MMAScaledOp>(op)) {
+        if (isSafeToPipeline(scaledDot)) {
+          mmaOps.push_back(op);
+        } else {
+          op->emitWarning("Skipping pipelining of an MMAv5 scaled op because "
+                          "TMEM copy is not used.");
+        }
+      }
     }
   });
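A simplified Python analogue of `isSafeToPipeline` above (illustrative only; users are modeled as op-name strings rather than MLIR operations): pipelining is attempted only when each scale has exactly two users and every user is either the TMEM copy or the scaled MMA itself.

```python
# Illustrative sketch of the user check in isSafeToPipeline above.
ALLOWED_SCALE_USERS = {"ttng.tmem_copy", "ttng.tc_gen5_mma_scaled"}

def copied_by_tmem_copy(scale_users):
    # MMA and TMEM copy must be the only users of the scale buffer.
    return len(scale_users) == 2 and all(u in ALLOWED_SCALE_USERS for u in scale_users)

def is_safe_to_pipeline(a_scale_users, b_scale_users):
    return copied_by_tmem_copy(a_scale_users) and copied_by_tmem_copy(b_scale_users)

# Both scales feed a tmem_copy plus the scaled MMA -> safe to pipeline.
assert is_safe_to_pipeline(["ttng.tmem_copy", "ttng.tc_gen5_mma_scaled"],
                           ["ttng.tmem_copy", "ttng.tc_gen5_mma_scaled"])
# A scale with an extra consumer (e.g. a convert_layout) blocks pipelining.
assert not is_safe_to_pipeline(["ttng.tmem_copy", "ttng.tc_gen5_mma_scaled", "ttg.convert_layout"],
                               ["ttng.tmem_copy", "ttng.tc_gen5_mma_scaled"])
```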

python/test/unit/language/test_matmul.py

Lines changed: 22 additions & 19 deletions

@@ -352,12 +352,9 @@ def test_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, device):
     rtol = 0.0001
     torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
 
-    if NUM_STAGES > 1:
-        # TODO: Remove this check once MMA pipelining is working for these cases
-        if M >= BLOCK_M and N >= BLOCK_N and K >= BLOCK_K:
-            # Verify that MMA pipelining has been applied
-            # FIXME: Scaled dot pipelining is DISABLED
-            assert "ttng.wait_barrier" not in out.asm["ttgir"]
+    # Pipelining of dot_scaled requires tmem_copy to be used, which in turn
+    # requires the scales to be in the blocked layout in global memory.
+    assert "ttng.wait_barrier" not in out.asm["ttgir"]
 
 
 def _knob_promote_lhs_to_tmem(monkeypatch):

@@ -437,13 +434,21 @@ def block_scale_mxfp_matmul( #
     tl.store(output_ptrs, accumulator, mask=c_mask)
 
 
+def _knob_disable_ptxas_opt(monkeypatch):
+    monkeypatch.setenv("DISABLE_PTXAS_OPT", "1")
+
+
 @pytest.mark.parametrize("M, N, K", [(1024, 512, 512), (998, 111, 512), (63, 128, 512)])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(128, 128, 128), (256, 128, 128), (128, 256, 128),
                                                        (128, 128, 256), (128, 256, 256)])
 @pytest.mark.parametrize("NUM_STAGES", [1, 2, 4])
 @pytest.mark.parametrize("USE_2D_SCALE_LOAD", [False, True])
 @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 10, reason="Requires compute capability >= 10")
-def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device):
+def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_2D_SCALE_LOAD, device, monkeypatch):
+    if NUM_STAGES == 1 and USE_2D_SCALE_LOAD:
+        # Disabling ptxas optimization as a temporary workaround, otherwise the test does not pass
+        _knob_disable_ptxas_opt(monkeypatch)
+
     if BLOCK_N == 256 and BLOCK_K == 256:
         NUM_STAGES = min(NUM_STAGES, 2)
     elif BLOCK_K == 256:

@@ -467,6 +472,7 @@ def test_blocked_scale_mxfp(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_STAGES, USE_
                             a_scale.stride(2), a_scale.stride(3), a.stride(0), a.stride(1), b.stride(0),
                             b.stride(1), output.stride(0), output.stride(1), BLOCK_M, BLOCK_N, BLOCK_K,
                             NUM_STAGES=NUM_STAGES, USE_2D_SCALE_LOAD=USE_2D_SCALE_LOAD)
+    ttgir = out.asm["ttgir"]
 
     def flatten_scale(scale):
         num_chunk_m, num_chunk_k, _, _, _ = scale.shape

@@ -488,30 +494,27 @@ def flatten_scale(scale):
     rtol = 0.0001
     torch.testing.assert_close(ref_out, output, atol=atol, rtol=rtol)
 
-    if NUM_STAGES > 1:
-        ttgir = out.asm["ttgir"]
+    if USE_2D_SCALE_LOAD:
+        # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
+        # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
+        assert "tmem_copy" in ttgir
 
+    if NUM_STAGES > 1:
         if BLOCK_M == BLOCK_K and BLOCK_N == BLOCK_K:
             load_pipelined = ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") == 2
         else:
             load_pipelined = (ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_M}x{BLOCK_K}") and
                               ttgir.count(f"ttg.local_alloc : () -> !ttg.memdesc<{NUM_STAGES}x{BLOCK_K}x{BLOCK_N}"))
 
-        if load_pipelined:
-            # If load is pipelined, MMA pipelining should also kick in
-            # FIXME: Scaled dot pipelining is DISABLED
-            assert "ttng.wait_barrier" not in ttgir
-        else:
+        if load_pipelined and USE_2D_SCALE_LOAD:
+            # If load is pipelined and tmem_copy is used, MMA pipelining should also kick in
+            assert "ttng.wait_barrier" in ttgir
+        elif not load_pipelined:
            # The behavior of load pipelining seems to depend on the size of input tensors.
            # In this test, it fails to pipeline the RHS tensor when N is not a multiple of 128. Pipelining of the LHS tensor
            # does not seem to be affected by the value of M, though.
            print(f"SWP failed for M = {M}, N = {N}")
 
-        if USE_2D_SCALE_LOAD:
-            # Due to an issue in the coalescing pass, tmem_copy can not be generated for the 5D load.
-            # The issue is fixed using the patch from https://github.com/triton-lang/triton/pull/4914
-            assert "tmem_copy" in ttgir
-
 
 @triton.jit
 def lhs_in_tmem_kernel( #

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 43 additions & 1 deletion

@@ -302,12 +302,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   // CHECK-DAG: %[[TRUE:.+]] = arith.constant true
   // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x64xi8, #{{.*}}>) -> !ttg.memdesc<128x64xi8, #{{.*}}, #smem
   // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<64x128xi8, #{{.*}}>) -> !ttg.memdesc<64x128xi8, #{{.*}}, #smem
+  // CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem>
+  // CHECK: ttg.local_load %[[SCALEA_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}>
+  // CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #{{.*}}, #smem>
+  // CHECK: ttg.local_load %[[SCALEB_LOCAL]] : !ttg.memdesc<128x2xi8, #{{.*}}, #smem> -> tensor<128x2xi8, #{{.*}}>
   // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable>
   // CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory>
   // CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x2xi8, #{{.*}}>) -> !ttg.memdesc<128x2xi8, #[[$TMEM1]], #ttng.tensor_memory>
   // CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3
-  tt.func public @mmav5_block_scaled(%a: tensor<128x64xi8, #blocked2>, %scale_a: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xi8, #blocked>, %scale_b: tensor<128x2xi8, #blocked1>) -> tensor<128x128xf32, #blocked> {
+  tt.func public @mmav5_block_scaled(%a: tensor<128x64xi8, #blocked2>, %scale_a_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>, %b: tensor<64x128xi8, #blocked>, %scale_b_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>) -> tensor<128x128xf32, #blocked> {
     %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %scale_a = tt.load %scale_a_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>
+    %scale_b = tt.load %scale_b_ptr: tensor<128x2x!tt.ptr<i8>, #blocked1>
     %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x64xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xi8, #blocked>, tensor<128x2xi8, #blocked1> -> tensor<128x128xf32, #blocked>
     tt.return %d : tensor<128x128xf32, #blocked>
   }

@@ -389,3 +395,39 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
     tt.return %d : tensor<128x128xf32, #blocked>
   }
 }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
+#blocked3 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 1, 8, 4, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 3, 2, 1, 0]}>
+#blocked4 = #ttg.blocked<{sizePerThread = [1, 1, 1, 1, 4], threadsPerWarp = [1, 4, 8, 1, 1], warpsPerCTA = [1, 1, 4, 1, 1], order = [4, 1, 2, 3, 0]}>
+#linear = #ttg.linear<{register = [[0, 1], [0, 2]], lane = [[32, 0], [64, 0], [1, 0], [2, 0], [4, 0]], warp = [[8, 0], [16, 0]], block = []}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-DAG: #[[$TMEM:.+]] = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+  // CHECK-DAG: #[[$TMEM1:.+]] = #ttng.tensor_memory_scales_encoding
+  // CHECK-LABEL: mmav5_block_scaled_5d_scale
+  // CHECK-DAG: %[[TRUE:.+]] = arith.constant true
+  // CHECK-DAG: %[[B:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem
+  // CHECK-DAG: %[[A:.+]] = ttg.local_alloc %{{.*}} : (tensor<128x128xi8, #{{.*}}>) -> !ttg.memdesc<128x128xi8, #{{.*}}, #smem
+  // CHECK-DAG: %[[SCALEA_LOCAL:.+]] = ttg.local_alloc
+  // CHECK: ttg.local_load %[[SCALEA_LOCAL]]
+  // CHECK-DAG: %[[SCALEB_LOCAL:.+]] = ttg.local_alloc
+  // CHECK: ttg.local_load %[[SCALEB_LOCAL]]
+  // CHECK-DAG: %[[ACC:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x128xf32, #{{.*}}>) -> !ttg.memdesc<128x128xf32, #{{.*}}, #ttng.tensor_memory, mutable>
+  // CHECK: %[[SCALEA:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory>
+  // CHECK: %[[SCALEB:.+]] = ttng.tmem_alloc %{{.*}} : (tensor<128x4xi8, #{{.*}}>) -> !ttg.memdesc<128x4xi8, #[[$TMEM1]], #ttng.tensor_memory>
+  // CHECK: ttng.tc_gen5_mma_scaled %[[A]], %[[B]], %[[ACC]], %[[SCALEA]], %[[SCALEB]], %[[TRUE]], %[[TRUE]] lhs = e4m3 rhs = e4m3
+  tt.func public @mmav5_block_scaled_5d_scale(%a: tensor<128x128xi8, #blocked2>, %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>, %b: tensor<128x128xi8, #blocked>, %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>) -> tensor<128x128xf32, #blocked> {
+    %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
+    %scale_a_5d = tt.load %scale_a_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>
+    %scale_a_trans = tt.trans %scale_a_5d {order = array<i32: 0, 3, 2, 1, 4>} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4>
+    %scale_a = tt.reshape %scale_a_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear>
+    %scale_b_5d = tt.load %scale_b_ptr: tensor<1x1x32x4x4x!tt.ptr<i8>, #blocked3>
+    %scale_b_trans = tt.trans %scale_b_5d {order = array<i32: 0, 3, 2, 1, 4>} : tensor<1x1x32x4x4xi8, #blocked3> -> tensor<1x4x32x1x4xi8, #blocked4>
+    %scale_b = tt.reshape %scale_b_trans : tensor<1x4x32x1x4xi8, #blocked4> -> tensor<128x4xi8, #linear>
+    %d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x128xi8, #blocked2>, tensor<128x4xi8, #linear> * tensor<128x128xi8, #blocked>, tensor<128x4xi8, #linear> -> tensor<128x128xf32, #blocked>
+    tt.return %d : tensor<128x128xf32, #blocked>
+  }
+}
