Commit 31bfb67

Revert "Revert "[Triton][Allocation] Enable getScratchValueSize specialization (#5070)""

This reverts commit c17a0fb.

1 parent 29d27d7 · commit 31bfb67

File tree (4 files changed: +123 −76 lines)

  include/triton/Analysis/Allocation.h
  lib/Analysis/Allocation.cpp
  test/Analysis/test-allocation.mlir
  test/lib/Analysis/TestAllocation.cpp

include/triton/Analysis/Allocation.h
Lines changed: 14 additions & 6 deletions

@@ -18,6 +18,12 @@ namespace mlir {
 namespace triton {
 class AllocationAnalysis;

+/// Callback to allow backends to specify target-specific scratch sizes for
+/// some operations.
+using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
+
+unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
+
 // To convert a tensor from one layout to another, we need to allocate a
 // temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
 // require multiple iterations, with each iteration involving multiple
@@ -141,7 +147,8 @@ class Allocation {
   explicit Allocation(Operation *operation) : operation(operation) {}

   /// Runs allocation analysis on the given top-level operation.
-  template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
+  void run(FuncAllocMapT &funcAllocMap,
+           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);

   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }
@@ -255,17 +262,18 @@ class ModuleAllocation : public CallGraph<Allocation> {
 public:
   using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

-  template <typename AllocationAnalysis = triton::AllocationAnalysis>
-  static ModuleAllocation get(ModuleOp moduleOp) {
-    ModuleAllocation res(moduleOp);
-    res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+  ModuleAllocation(ModuleOp moduleOp,
+                   triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
+                       triton::defaultAllocationAnalysisScratchSizeFn)
+      : CallGraph<Allocation>(moduleOp) {
+    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
         [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
         // Post-order node walk callback
         [&](FunctionOpInterface funcOp) {
-          auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
+          auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
           if (inserted)
-            iter->second.template run<AllocationAnalysis>(res.funcMap);
+            iter->second.run(funcMap, scratchSizeGetter);
         });
-    return res;
   }
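
A minimal usage sketch of the new hook, assuming a hypothetical backend op MyBackendOp; only the constructor and the two declarations above come from this commit:

  // Sketch: a backend-specific scratch-size callback that special-cases one
  // (hypothetical) op and defers to the upstream default for everything else.
  #include "triton/Analysis/Allocation.h"

  static unsigned myBackendScratchSizeFn(mlir::Operation *op) {
    if (llvm::isa<MyBackendOp>(op)) // hypothetical op, assumed fixed need
      return 256;
    return mlir::triton::defaultAllocationAnalysisScratchSizeFn(op);
  }

  void runAllocation(mlir::ModuleOp moduleOp) {
    // The getter is an optional trailing parameter that defaults to
    // defaultAllocationAnalysisScratchSizeFn, so existing call sites can
    // still construct ModuleAllocation(moduleOp) unchanged.
    mlir::ModuleAllocation allocation(moduleOp, myBackendScratchSizeFn);
  }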

lib/Analysis/Allocation.cpp
Lines changed: 71 additions & 69 deletions

@@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   return scratchConfig;
 }

+unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
+  if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
+    ReduceOpHelper helper(reduceOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto scanOp = dyn_cast<ScanOp>(op)) {
+    ScanLoweringHelper helper(scanOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto histogram = dyn_cast<HistogramOp>(op)) {
+    auto dstTy = histogram.getType();
+    int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
+        op->getParentOfType<ModuleOp>());
+    return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
+           std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+  }
+  if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
+    auto srcTy = cvtLayout.getSrc().getType();
+    auto dstTy = cvtLayout.getType();
+    auto srcEncoding = srcTy.getEncoding();
+    auto dstEncoding = dstTy.getEncoding();
+    if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
+        mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
+      // Conversions from/to shared memory do not need scratch memory.
+      return 0;
+    }
+    // ConvertLayoutOp with both input/output non-shared_layout
+    // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
+    // also possible to realize it with other approaches in restricted
+    // conditions, such as warp-shuffle
+    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
+    return isa<PointerType>(srcTy.getElementType())
+               ? elems * kPtrBitWidth / 8
+               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+  }
+  if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
+    auto value = op->getOperand(0);
+    // only scalar requires scratch memory
+    // make it explicit for readability
+    if (dyn_cast<RankedTensorType>(value.getType())) {
+      return 0;
+    }
+    auto smemShape = getRepShapeForAtomic(op->getResult(0));
+    auto elems = getNumScratchElements(smemShape);
+    auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
+    assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
+    return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+  }
+  if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
+    constexpr int32_t kTMASize = 128;
+    return kTMASize;
+  }
+  return 0;
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation,
                      Allocation::FuncAllocMapT *funcAllocMap,
-                     Allocation *allocation)
+                     Allocation *allocation,
+                     AllocationAnalysisScratchSizeFn scratchSizeGetter)
       : operation(operation), funcAllocMap(funcAllocMap),
-        allocation(allocation) {
+        allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
     run();
   }

@@ -177,77 +234,19 @@ class AllocationAnalysis {

   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
-    const size_t scratchAlignment = 128;
-    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-      ReduceOpHelper helper(reduceOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
-      ScanLoweringHelper helper(scanOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
-      auto dstTy = histogram.getType();
-      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
-          op->getParentOfType<ModuleOp>());
-      auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
-                   std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
-      auto srcTy = cvtLayout.getSrc().getType();
-      auto dstTy = cvtLayout.getType();
-      auto srcEncoding = srcTy.getEncoding();
-      auto dstEncoding = dstTy.getEncoding();
-      if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
-          mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
-        // Conversions from/to shared memory do not need scratch memory.
-        return;
-      }
-      // ConvertLayoutOp with both input/output non-shared_layout
-      // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
-      // also possible to realize it with other approaches in restricted
-      // conditions, such as warp-shuffle
-      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-      auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
-      auto bytes =
-          isa<PointerType>(srcTy.getElementType())
-              ? elems * kPtrBitWidth / 8
-              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
-      auto value = op->getOperand(0);
-      // only scalar requires scratch memory
-      // make it explicit for readability
-      if (dyn_cast<RankedTensorType>(value.getType())) {
-        // nothing to do
-      } else {
-        auto smemShape = getRepShapeForAtomic(op->getResult(0));
-        auto elems = getNumScratchElements(smemShape);
-        auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
-        assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
-        auto bytes =
-            elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                            scratchAlignment);
-      }
-    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
+    constexpr size_t scratchAlignment = 128;
+    if (auto callOp = dyn_cast<CallOpInterface>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
       maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto createTensormap =
-                   dyn_cast<ExperimentalTensormapCreateOp>(op)) {
-      constexpr int32_t kTMASize = 128;
-      constexpr int32_t kTMAAlign = 128;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
-                                                          kTMAAlign);
+      return;
     }
+    unsigned bytes = scratchSizeGetter(op);
+    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                        scratchAlignment);
   }

   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -547,13 +546,16 @@ class AllocationAnalysis {
   Allocation::FuncAllocMapT *funcAllocMap;
   Allocation *allocation;
   BufferRangeMapT bufferRange;
+  AllocationAnalysisScratchSizeFn scratchSizeGetter;
 };

 } // namespace triton

-template <>
-void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
-  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
+void Allocation::run(
+    FuncAllocMapT &funcAllocMap,
+    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
+  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
+                             scratchSizeGetter);
 }

 std::map<Operation *, SmallVector<Allocation::BufferId>>
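
To make the default callback's arithmetic concrete, a worked example under assumed inputs (the shapes below are illustrative, not from the commit):

  // HistogramOp producing tensor<64xi32>, module with 32 threads per warp:
  //   max(numElements = 64, threadsPerWarp = 32) * max(8, bitWidth = 32) / 8
  //   = 64 * 32 / 8 = 256 bytes of scratch.
  // Scalar AtomicRMWOp on a !tt.ptr<f32> operand (tensor operands return 0):
  //   elems = 1, so 1 * max(8, 32) / 8 = 4 bytes of scratch.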

test/Analysis/test-allocation.mlir
Lines changed: 7 additions & 0 deletions

@@ -1,4 +1,11 @@
 // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
+// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128
+
+// Check there are no lines with a size different to 128 and we have at least a line with size 128.
+
+// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
+// CHECK-128: scratch offset = {{.*}}, size = 128
+// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}

 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
test/lib/Analysis/TestAllocation.cpp
Lines changed: 31 additions & 1 deletion

@@ -5,21 +5,42 @@ using namespace mlir;

 namespace {

+unsigned getScratchSize128(Operation *) { return 128; }
+
+enum class GetScratchSizeFunction {
+  None,
+  ValidConstant,
+};
+
 struct TestAllocationPass
     : public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {

   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

+  TestAllocationPass() = default;
+  TestAllocationPass(const TestAllocationPass &other)
+      : PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}
+
   StringRef getArgument() const final { return "test-print-allocation"; }
   StringRef getDescription() const final {
     return "print the result of the allocation pass";
   }

+  ModuleAllocation getModuleAllocation() {
+    switch (getScratchSizeFunction) {
+    case GetScratchSizeFunction::None:
+      return {getOperation()};
+    case GetScratchSizeFunction::ValidConstant:
+      return {getOperation(), getScratchSize128};
+    }
+    llvm_unreachable("Unhandled case");
+  }
+
   void runOnOperation() override {
     auto &os = llvm::errs();
     ModuleOp moduleOp = getOperation();
     // Convert to std::string can remove quotes from opName
-    ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
+    ModuleAllocation moduleAllocation = getModuleAllocation();
     moduleOp.walk([&](triton::FuncOp funcOp) {
       auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
       os << opName << "\n";
@@ -48,6 +69,15 @@ struct TestAllocationPass
       os << "size = " << allocation->getSharedMemorySize() << "\n";
     });
   }
+
+  Option<GetScratchSizeFunction> getScratchSizeFunction{
+      *this, "get-scratch-size-function",
+      llvm::cl::desc("Custom scratch size function to use"),
+      llvm::cl::init(GetScratchSizeFunction::None),
+      llvm::cl::values(
+          clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
+          clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
+                     "ValidConstant"))};
 };

 } // namespace
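
Invocation mirrors the new RUN line in test/Analysis/test-allocation.mlir (input.mlir is a placeholder):

  triton-opt input.mlir -split-input-file --mlir-disable-threading \
      -test-print-allocation="get-scratch-size-function=ValidConstant"

The added default and copy constructors follow the usual MLIR idiom for passes that declare Option<> members: cl-style options suppress the implicit copy constructor, and the pass manager needs to copy the pass when cloning pipelines.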
