Commit c17a0fb

Revert "[Triton][Allocation] Enable getScratchValueSize specialization (#5070)"

This reverts commit 32b0fce.

1 parent cf5ea93

File tree: 4 files changed, +76 / -123 lines

  include/triton/Analysis/Allocation.h
  lib/Analysis/Allocation.cpp
  test/Analysis/test-allocation.mlir
  test/lib/Analysis/TestAllocation.cpp

include/triton/Analysis/Allocation.h (6 additions, 14 deletions)

```diff
@@ -18,12 +18,6 @@ namespace mlir {
 namespace triton {
 class AllocationAnalysis;
 
-/// Callback to allow backends to specify target-specific scratch sizes for
-/// some operations.
-using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;
-
-unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);
-
 // To convert a tensor from one layout to another, we need to allocate a
 // temporary buffer (i.e., scratch buffer) in shared memory. The conversion may
 // require multiple iterations, with each iteration involving multiple
@@ -147,8 +141,7 @@ class Allocation {
   explicit Allocation(Operation *operation) : operation(operation) {}
 
   /// Runs allocation analysis on the given top-level operation.
-  void run(FuncAllocMapT &funcAllocMap,
-           triton::AllocationAnalysisScratchSizeFn scratchSizeGetter);
+  template <typename AllocationAnalysis> void run(FuncAllocMapT &funcAllocMap);
 
   /// Returns the operation this analysis was constructed from.
   Operation *getOperation() const { return operation; }
@@ -262,18 +255,17 @@ class ModuleAllocation : public CallGraph<Allocation> {
 public:
   using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
 
-  ModuleAllocation(ModuleOp moduleOp,
-                   triton::AllocationAnalysisScratchSizeFn scratchSizeGetter =
-                       triton::defaultAllocationAnalysisScratchSizeFn)
-      : CallGraph<Allocation>(moduleOp) {
-    walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
+  template <typename AllocationAnalysis = triton::AllocationAnalysis>
+  static ModuleAllocation get(ModuleOp moduleOp) {
+    ModuleAllocation res(moduleOp);
+    res.walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
         // Pre-order edge walk callback
         [](CallOpInterface callOp, FunctionOpInterface funcOp) {},
         // Post-order node walk callback
         [&](FunctionOpInterface funcOp) {
           auto [iter, inserted] = res.funcMap.try_emplace(funcOp, funcOp);
           if (inserted)
-            iter->second.run(funcMap, scratchSizeGetter);
+            iter->second.template run<AllocationAnalysis>(res.funcMap);
         });
     return res;
   }
```
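For orientation, the practical difference between the two APIs shows up at the call site. Below is a minimal sketch, assuming only the declarations restored above; the backend analysis type name is hypothetical, not part of the commit:

```cpp
#include "triton/Analysis/Allocation.h"

using namespace mlir;

// With the template API restored by this revert, callers select the analysis
// implementation as a type parameter of the static factory.
ModuleAllocation analyzeWithDefault(ModuleOp moduleOp) {
  // Uses triton::AllocationAnalysis, the default template argument.
  return ModuleAllocation::get(moduleOp);
}

// A backend with its own analysis (hypothetical type name) would write:
//   ModuleAllocation a = ModuleAllocation::get<MyBackendAllocationAnalysis>(moduleOp);
//
// The callback API removed by this revert expressed the same customization as
// a runtime value instead of a type (pre-revert only):
//   ModuleAllocation a(moduleOp, myScratchSizeFn);
```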

lib/Analysis/Allocation.cpp (69 additions, 71 deletions)

```diff
@@ -118,70 +118,13 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   return scratchConfig;
 }
 
-unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
-  if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-    ReduceOpHelper helper(reduceOp);
-    return helper.getScratchSizeInBytes();
-  }
-  if (auto scanOp = dyn_cast<ScanOp>(op)) {
-    ScanLoweringHelper helper(scanOp);
-    return helper.getScratchSizeInBytes();
-  }
-  if (auto histogram = dyn_cast<HistogramOp>(op)) {
-    auto dstTy = histogram.getType();
-    int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
-        op->getParentOfType<ModuleOp>());
-    return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
-           std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
-  }
-  if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
-    auto srcTy = cvtLayout.getSrc().getType();
-    auto dstTy = cvtLayout.getType();
-    auto srcEncoding = srcTy.getEncoding();
-    auto dstEncoding = dstTy.getEncoding();
-    if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
-        mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
-      // Conversions from/to shared memory do not need scratch memory.
-      return 0;
-    }
-    // ConvertLayoutOp with both input/output non-shared_layout
-    // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
-    // also possible to realize it with other approaches in restricted
-    // conditions, such as warp-shuffle
-    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
-    return isa<PointerType>(srcTy.getElementType())
-               ? elems * kPtrBitWidth / 8
-               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-  }
-  if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
-    auto value = op->getOperand(0);
-    // only scalar requires scratch memory
-    // make it explicit for readability
-    if (dyn_cast<RankedTensorType>(value.getType())) {
-      return 0;
-    }
-    auto smemShape = getRepShapeForAtomic(op->getResult(0));
-    auto elems = getNumScratchElements(smemShape);
-    auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
-    assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
-    return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-  }
-  if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
-    constexpr int32_t kTMASize = 128;
-    return kTMASize;
-  }
-  return 0;
-}
-
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation,
                      Allocation::FuncAllocMapT *funcAllocMap,
-                     Allocation *allocation,
-                     AllocationAnalysisScratchSizeFn scratchSizeGetter)
+                     Allocation *allocation)
       : operation(operation), funcAllocMap(funcAllocMap),
-        allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
+        allocation(allocation) {
     run();
   }
 
@@ -234,19 +177,77 @@ class AllocationAnalysis {
 
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
-    constexpr size_t scratchAlignment = 128;
-    if (auto callOp = dyn_cast<CallOpInterface>(op)) {
+    const size_t scratchAlignment = 128;
+    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
+      ReduceOpHelper helper(reduceOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
+      ScanLoweringHelper helper(scanOp);
+      unsigned bytes = helper.getScratchSizeInBytes();
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
+      auto dstTy = histogram.getType();
+      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
+          op->getParentOfType<ModuleOp>());
+      auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
+                   std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
+      auto srcTy = cvtLayout.getSrc().getType();
+      auto dstTy = cvtLayout.getType();
+      auto srcEncoding = srcTy.getEncoding();
+      auto dstEncoding = dstTy.getEncoding();
+      if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
+          mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
+        // Conversions from/to shared memory do not need scratch memory.
+        return;
+      }
+      // ConvertLayoutOp with both input/output non-shared_layout
+      // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
+      // also possible to realize it with other approaches in restricted
+      // conditions, such as warp-shuffle
+      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+      auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
+      auto bytes =
+          isa<PointerType>(srcTy.getElementType())
+              ? elems * kPtrBitWidth / 8
+              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                          scratchAlignment);
+    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
+      auto value = op->getOperand(0);
+      // only scalar requires scratch memory
+      // make it explicit for readability
+      if (dyn_cast<RankedTensorType>(value.getType())) {
+        // nothing to do
+      } else {
+        auto smemShape = getRepShapeForAtomic(op->getResult(0));
+        auto elems = getNumScratchElements(smemShape);
+        auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
+        assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
+        auto bytes =
+            elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                            scratchAlignment);
+      }
+    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
       maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
                                                           scratchAlignment);
-      return;
+    } else if (auto createTensormap =
+                   dyn_cast<ExperimentalTensormapCreateOp>(op)) {
+      constexpr int32_t kTMASize = 128;
+      constexpr int32_t kTMAAlign = 128;
+      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
+                                                          kTMAAlign);
     }
-    unsigned bytes = scratchSizeGetter(op);
-    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                        scratchAlignment);
   }
 
   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -546,16 +547,13 @@ class AllocationAnalysis {
   Allocation::FuncAllocMapT *funcAllocMap;
   Allocation *allocation;
   BufferRangeMapT bufferRange;
-  AllocationAnalysisScratchSizeFn scratchSizeGetter;
 };
 
 } // namespace triton
 
-void Allocation::run(
-    FuncAllocMapT &funcAllocMap,
-    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
-  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
-                             scratchSizeGetter);
+template <>
+void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
+  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
 }
 
 std::map<Operation *, SmallVector<Allocation::BufferId>>
```
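To make the byte formulas concrete, here is the HistogramOp rule from getScratchValueSize above, restated as a standalone function with one worked example. This is a sketch for illustration; the input numbers are made up, not taken from the commit:

```cpp
#include <algorithm>
#include <cassert>

// Restatement of the HistogramOp scratch-size rule: one slot per output
// element (but at least one warp's worth), each at least a byte wide.
unsigned histogramScratchBytes(int numElements, int elemBitWidth,
                               int threadsPerWarp) {
  return std::max(numElements, threadsPerWarp) *
         std::max(8, elemBitWidth) / 8;
}

int main() {
  // A 64-bin i32 histogram with 32 threads per warp needs
  // max(64, 32) * max(8, 32) / 8 = 64 * 4 = 256 bytes of scratch.
  assert(histogramScratchBytes(64, 32, 32) == 256);
  return 0;
}
```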

test/Analysis/test-allocation.mlir (0 additions, 7 deletions)

```diff
@@ -1,11 +1,4 @@
 // RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation 2>&1 | FileCheck %s
-// RUN: triton-opt %s -split-input-file --mlir-disable-threading -test-print-allocation="get-scratch-size-function=ValidConstant" 2>&1 | FileCheck %s --check-prefix=CHECK-128
-
-// Check there are no lines with a size different to 128 and we have at least a line with size 128.
-
-// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
-// CHECK-128: scratch offset = {{.*}}, size = 128
-// CHECK-128-NOT: scratch offset = {{.*}}, size = {{^(128)}}
 
 #AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
 #sliceAd0 = #triton_gpu.slice<{dim = 0, parent = #AL}>
```

test/lib/Analysis/TestAllocation.cpp (1 addition, 31 deletions)

```diff
@@ -5,42 +5,21 @@ using namespace mlir;
 
 namespace {
 
-unsigned getScratchSize128(Operation *) { return 128; }
-
-enum class GetScratchSizeFunction {
-  None,
-  ValidConstant,
-};
-
 struct TestAllocationPass
     : public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
 
-  TestAllocationPass() = default;
-  TestAllocationPass(const TestAllocationPass &other)
-      : PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}
-
   StringRef getArgument() const final { return "test-print-allocation"; }
   StringRef getDescription() const final {
     return "print the result of the allocation pass";
   }
 
-  ModuleAllocation getModuleAllocation() {
-    switch (getScratchSizeFunction) {
-    case GetScratchSizeFunction::None:
-      return {getOperation()};
-    case GetScratchSizeFunction::ValidConstant:
-      return {getOperation(), getScratchSize128};
-    }
-    llvm_unreachable("Unhandled case");
-  }
-
   void runOnOperation() override {
     auto &os = llvm::errs();
     ModuleOp moduleOp = getOperation();
     // Convert to std::string can remove quotes from opName
-    ModuleAllocation moduleAllocation = getModuleAllocation();
+    ModuleAllocation moduleAllocation = ModuleAllocation::get(moduleOp);
     moduleOp.walk([&](triton::FuncOp funcOp) {
       auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
       os << opName << "\n";
@@ -69,15 +48,6 @@ struct TestAllocationPass
       os << "size = " << allocation->getSharedMemorySize() << "\n";
     });
   }
-
-  Option<GetScratchSizeFunction> getScratchSizeFunction{
-      *this, "get-scratch-size-function",
-      llvm::cl::desc("Custom scratch size function to use"),
-      llvm::cl::init(GetScratchSizeFunction::None),
-      llvm::cl::values(
-          clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
-          clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
-                     "ValidConstant"))};
 };
 
 } // namespace
```
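For the record, a condensed sketch of the test wiring that this revert deletes; it relies on the pre-revert callback API and does not build after this commit:

```cpp
#include "triton/Analysis/Allocation.h"

// Pre-revert test hook: every scratch request reports a constant 128 bytes,
// which the deleted CHECK-128 lines in test-allocation.mlir then verified.
static unsigned getScratchSize128(mlir::Operation *) { return 128; }

// It was selected through the deleted pass option:
//   triton-opt ... -test-print-allocation="get-scratch-size-function=ValidConstant"
// which made the pass construct its allocation with the callback:
//   ModuleAllocation moduleAllocation(getOperation(), getScratchSize128);
```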
