@@ -118,13 +118,70 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   return scratchConfig;
 }
 
+unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
+  if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
+    ReduceOpHelper helper(reduceOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto scanOp = dyn_cast<ScanOp>(op)) {
+    ScanLoweringHelper helper(scanOp);
+    return helper.getScratchSizeInBytes();
+  }
+  if (auto histogram = dyn_cast<HistogramOp>(op)) {
+    auto dstTy = histogram.getType();
+    int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
+        op->getParentOfType<ModuleOp>());
+    return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
+           std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
+  }
+  if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
+    auto srcTy = cvtLayout.getSrc().getType();
+    auto dstTy = cvtLayout.getType();
+    auto srcEncoding = srcTy.getEncoding();
+    auto dstEncoding = dstTy.getEncoding();
+    if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
+        mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
+      // Conversions from/to shared memory do not need scratch memory.
+      return 0;
+    }
+    // ConvertLayoutOp where neither the input nor the output uses a shared
+    // layout.
+    // TODO: Besides implementing ConvertLayoutOp via shared memory, it is
+    // also possible to realize it with other approaches, such as a warp
+    // shuffle, under restricted conditions.
+    auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
+    auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
+    return isa<PointerType>(srcTy.getElementType())
+               ? elems * kPtrBitWidth / 8
+               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
+  }
+  if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
+    auto value = op->getOperand(0);
+    // Only a scalar atomic requires scratch memory; make this explicit for
+    // readability.
+    if (dyn_cast<RankedTensorType>(value.getType())) {
+      return 0;
+    }
+    auto smemShape = getRepShapeForAtomic(op->getResult(0));
+    auto elems = getNumScratchElements(smemShape);
+    auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
+    assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
+    return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+  }
+  if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
+    constexpr int32_t kTMASize = 128;
+    return kTMASize;
+  }
+  return 0;
+}
+
 class AllocationAnalysis {
 public:
   AllocationAnalysis(Operation *operation,
                      Allocation::FuncAllocMapT *funcAllocMap,
-                     Allocation *allocation)
+                     Allocation *allocation,
+                     AllocationAnalysisScratchSizeFn scratchSizeGetter)
       : operation(operation), funcAllocMap(funcAllocMap),
-        allocation(allocation) {
+        allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
     run();
   }
 
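Note: the hunk above passes an `AllocationAnalysisScratchSizeFn` into the analysis without showing its declaration, which lives in the corresponding Allocation header rather than in this diff. A minimal sketch of what that declaration presumably looks like (an assumption based on how the callback is used here, not part of this commit):

// Presumed declaration in the Allocation header: a per-operation hook that
// returns the required scratch-memory size in bytes for `op`.
using AllocationAnalysisScratchSizeFn = std::function<unsigned(Operation *)>;

// Default implementation, defined in the hunk above; callers that do not
// need backend-specific sizing would pass this function.
unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op);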
@@ -177,77 +234,19 @@ class AllocationAnalysis {
 
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
-    const size_t scratchAlignment = 128;
-    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
-      ReduceOpHelper helper(reduceOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
-      ScanLoweringHelper helper(scanOp);
-      unsigned bytes = helper.getScratchSizeInBytes();
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
-      auto dstTy = histogram.getType();
-      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
-          op->getParentOfType<ModuleOp>());
-      auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
-                   std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
-      auto srcTy = cvtLayout.getSrc().getType();
-      auto dstTy = cvtLayout.getType();
-      auto srcEncoding = srcTy.getEncoding();
-      auto dstEncoding = dstTy.getEncoding();
-      if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
-          mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
-        // Conversions from/to shared memory do not need scratch memory.
-        return;
-      }
-      // ConvertLayoutOp with both input/output non-shared_layout
-      // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
-      // also possible to realize it with other approaches in restricted
-      // conditions, such as warp-shuffle
-      auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-      auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
-      auto bytes =
-          isa<PointerType>(srcTy.getElementType())
-              ? elems * kPtrBitWidth / 8
-              : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                          scratchAlignment);
-    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
-      auto value = op->getOperand(0);
-      // only scalar requires scratch memory
-      // make it explicit for readability
-      if (dyn_cast<RankedTensorType>(value.getType())) {
-        // nothing to do
-      } else {
-        auto smemShape = getRepShapeForAtomic(op->getResult(0));
-        auto elems = getNumScratchElements(smemShape);
-        auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
-        assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
-        auto bytes =
-            elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
-        maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
-                                                            scratchAlignment);
-      }
-    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
+    constexpr size_t scratchAlignment = 128;
+    if (auto callOp = dyn_cast<CallOpInterface>(op)) {
       auto callable = callOp.resolveCallable();
       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
       auto *funcAlloc = &(*funcAllocMap)[funcOp];
       auto bytes = funcAlloc->getSharedMemorySize();
       maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto createTensormap =
-                   dyn_cast<ExperimentalTensormapCreateOp>(op)) {
-      constexpr int32_t kTMASize = 128;
-      constexpr int32_t kTMAAlign = 128;
-      maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
-                                                          kTMAAlign);
+      return;
     }
+    unsigned bytes = scratchSizeGetter(op);
+    maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
+                                                        scratchAlignment);
   }
 
   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
@@ -547,13 +546,16 @@ class AllocationAnalysis {
   Allocation::FuncAllocMapT *funcAllocMap;
   Allocation *allocation;
   BufferRangeMapT bufferRange;
+  AllocationAnalysisScratchSizeFn scratchSizeGetter;
 };
 
 } // namespace triton
 
-template <>
-void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
-  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
+void Allocation::run(
+    FuncAllocMapT &funcAllocMap,
+    triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
+  triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
+                             scratchSizeGetter);
 }
 
 std::map<Operation *, SmallVector<Allocation::BufferId>>
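With `Allocation::run` now taking a `scratchSizeGetter`, a backend can override the scratch sizing for selected ops while deferring to the default logic everywhere else. A hedged sketch of such a callback, assuming `defaultAllocationAnalysisScratchSizeFn` is exported from the `triton` namespace; `MyBackendOp` and the 2 KiB figure are placeholders, not part of this commit:

// Backend-specific scratch sizing: MyBackendOp is a hypothetical op whose
// lowering needs a fixed amount of scratch memory.
unsigned backendScratchSizeFn(Operation *op) {
  if (isa<MyBackendOp>(op))
    return 2 * 1024; // reserve 2 KiB for this op (illustrative value)
  // Fall back to the upstream sizing logic for all other ops.
  return triton::defaultAllocationAnalysisScratchSizeFn(op);
}

Such a function would be forwarded to `Allocation::run` (and from there to `AllocationAnalysis`) as the new `scratchSizeGetter` argument.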