Commit 4fb6cc2

[AMD] Added local_alloc refinement
1 parent 2abf233 commit 4fb6cc2

File tree

2 files changed: +125 -0 lines changed

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
// RUN: triton-opt %s -split-input-file -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize | FileCheck %s

#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}>
#smem = #ttg.shared_memory

// CHECK-LABEL: @local_alloc_refinement
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @local_alloc_refinement(%arg0: tensor<64x16xf16, #blocked>) attributes {noinline = false} {
    // CHECK: [[OFFSET_12:%.*]] = arith.constant 12 : i32
    // CHECK: [[OFFSET_8:%.*]] = arith.constant 8 : i32
    // CHECK: [[OFFSET_4:%.*]] = arith.constant 4 : i32
    // CHECK: [[OFFSET_0:%.*]] = arith.constant 0 : i32
    // CHECK: [[ALLOC:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
    // CHECK: [[SUBVIEW_0:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_0]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SLICE_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
    // CHECK: ttg.local_store [[SLICE_0]], [[SUBVIEW_0]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SUBVIEW_1:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_4]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SLICE_1:%.*]] = amdgpu.extract_slice %arg0 [0, 4] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
    // CHECK: ttg.local_store [[SLICE_1]], [[SUBVIEW_1]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SUBVIEW_2:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_8]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SLICE_2:%.*]] = amdgpu.extract_slice %arg0 [0, 8] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
    // CHECK: ttg.local_store [[SLICE_2]], [[SUBVIEW_2]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SUBVIEW_3:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_12]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: [[SLICE_3:%.*]] = amdgpu.extract_slice %arg0 [0, 12] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked>
    // CHECK: ttg.local_store [[SLICE_3]], [[SUBVIEW_3]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16>
    // CHECK: amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
    // CHECK: ttg.local_dealloc [[ALLOC]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
    %0 = ttg.local_alloc %arg0 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>
    amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant<refine_ops>}
    tt.return
  }
}
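
Note: the rewrite the CHECK lines above verify can be summarized as the before/after sketch below. It assumes the 64x16 f16 input and layouts from this test; the SSA names and the single full-size store are illustrative, not literal pass output. The new LocalAllocOpPattern only performs the alloc rewrite shown here; the existing LocalStoreOp refinement pattern and canonicalization then split the store into the four 64x4 slices the test checks.

// Before: an immutable local_alloc initialized directly from registers.
%0 = ttg.local_alloc %arg0 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem>

// After LocalAllocOpPattern (sketch): a mutable alloc with a leading unit dim,
// a full-shape subview that takes over the uses of %0, an explicit local_store
// of the operand, and a local_dealloc placed after the last user.
%c0 = arith.constant 0 : i32
%alloc = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>
%view = ttg.memdesc_subview %alloc[%c0, %c0, %c0] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 1x64x16>
ttg.local_store %arg0, %view : tensor<64x16xf16, #blocked> -> !ttg.memdesc<64x16xf16, #shared, #smem, mutable, 1x64x16>
// ... uses of %view (formerly %0) ...
ttg.local_dealloc %alloc : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>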

third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp

Lines changed: 90 additions & 0 deletions
@@ -692,6 +692,92 @@ struct LocalStoreOpPattern
  }
};

struct LocalAllocOpPattern
    : public RefineRewritePattern<triton::gpu::LocalAllocOp> {
  LocalAllocOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : RefineRewritePattern(context, benefit) {}

  // Refines non-mutable `LocalAllocOp` ops. The non-mutable variant is used as
  // the non-pipelined version of the op. To be able to refine the op, we
  // replace the non-mutable variant with a mutable one, which requires a
  // `LocalDeallocOp` after the last user of the `LocalAllocOp` result.
  // A `LocalStoreOp` is used to move the data from registers to the LDS.
  // The refinement of the resulting `LocalStoreOp` is left to the dedicated
  // rewrite pattern.
  LogicalResult apply(triton::gpu::LocalAllocOp op,
                      PatternRewriter &rewriter) const override {
    auto ctx = op->getContext();
    auto loc = op.getLoc();
    auto alignment = op.getAlignment();

    if (op->getNumOperands() == 0)
      return failure();

    auto allocType = cast<triton::gpu::MemDescType>(op.getResult().getType());
    auto origShape = allocType.getShape();
    SmallVector<int64_t> newShape(origShape);
    SmallVector<int64_t> newAllocShape(allocType.getAllocShape());

    if (newShape.size() == 2) {
      newShape.insert(newShape.begin(), 1);
    }
    assert(newShape.size() == 3);

    if (newAllocShape.size() == 2) {
      newAllocShape.insert(newAllocShape.begin(), 1);
    }
    assert(newAllocShape.size() == 3);

    auto newAllocType = triton::gpu::MemDescType::get(
        ctx, newShape, allocType.getElementType(), allocType.getEncoding(),
        allocType.getMemorySpace(),
        /*mutableMemory=*/true, newAllocShape);

    rewriter.setInsertionPointAfter(op);
    auto newAlloc =
        rewriter.create<triton::gpu::LocalAllocOp>(loc, newAllocType);
    newAlloc->setAttrs(op->getAttrs());

    auto newSubviewType = triton::gpu::MemDescType::get(
        ctx, origShape, allocType.getElementType(), allocType.getEncoding(),
        allocType.getMemorySpace(),
        /*mutableMemory=*/true, newAllocShape);

    auto offset = createOffset({}, {0, 0, 0}, rewriter, loc);
    auto newSubview = rewriter.create<ttg::MemDescSubviewOp>(
        loc, newSubviewType, newAlloc, offset);
    rewriter.create<ttg::LocalStoreOp>(loc, op.getOperand(0), newSubview);

    mlir::Operation *lastUser = nullptr;
    for (auto *user : op.getResult().getUsers()) {
      if (!lastUser || user->isBeforeInBlock(lastUser) == false) {
        lastUser = user;
      }
    }

    Operation &lastOpInBlock = op->getBlock()->back();
    const bool noUsers = lastUser == nullptr;
    const bool isLastInstr = noUsers
                                 ? false
                                 : mlir::OperationEquivalence::isEquivalentTo(
                                       &lastOpInBlock, lastUser,
                                       mlir::OperationEquivalence::Flags::None);
    if (noUsers || isLastInstr) {
      rewriter.setInsertionPoint(&lastOpInBlock);
    } else {
      rewriter.setInsertionPointAfter(lastUser);
    }

    rewriter.create<triton::gpu::LocalDeallocOp>(loc, newAlloc.getResult());

    op.replaceAllUsesWith(newSubview.getResult());
    rewriter.eraseOp(op);

    return success();
  }
};

struct ReduceOpPattern : public RefineRewritePattern<triton::ReduceOp> {
  ReduceOpPattern(MLIRContext *context, PatternBenefit benefit = 1)
      : RefineRewritePattern(context, benefit) {}
@@ -1134,6 +1220,10 @@ struct TritonAMDGPURefineOps
      return signalPassFailure();
    }

    RewritePatternSet primaryPatterns(context);
    primaryPatterns.add<LocalAllocOpPattern>(context, /*benefit=*/1);
    walkAndApplyPatterns(func, std::move(primaryPatterns));

    RewritePatternSet patterns(context);
    patterns.add<LocalLoadOpPattern>(context, /*benefit=*/1);
    patterns.add<DotOpPattern>(context, /*benefit=*/1);