@@ -118,70 +118,13 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
118118 return scratchConfig;
119119}
120120
121- unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
122-   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
123-     ReduceOpHelper helper(reduceOp);
124-     return helper.getScratchSizeInBytes();
125-   }
126-   if (auto scanOp = dyn_cast<ScanOp>(op)) {
127-     ScanLoweringHelper helper(scanOp);
128-     return helper.getScratchSizeInBytes();
129-   }
130-   if (auto histogram = dyn_cast<HistogramOp>(op)) {
131-     auto dstTy = histogram.getType();
132-     int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
133-         op->getParentOfType<ModuleOp>());
134-     return std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
135-            std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
136-   }
137-   if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
138-     auto srcTy = cvtLayout.getSrc().getType();
139-     auto dstTy = cvtLayout.getType();
140-     auto srcEncoding = srcTy.getEncoding();
141-     auto dstEncoding = dstTy.getEncoding();
142-     if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
143-         mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
144-       // Conversions from/to shared memory do not need scratch memory.
145-       return 0;
146-     }
147-     // ConvertLayoutOp with both input/output non-shared_layout
148-     // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
149-     // also possible to realize it with other approaches in restricted
150-     // conditions, such as warp-shuffle
151-     auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
152-     auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
153-     return isa<PointerType>(srcTy.getElementType())
154-                ? elems * kPtrBitWidth / 8
155-                : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
156-   }
157-   if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
158-     auto value = op->getOperand(0);
159-     // only scalar requires scratch memory
160-     // make it explicit for readability
161-     if (dyn_cast<RankedTensorType>(value.getType())) {
162-       return 0;
163-     }
164-     auto smemShape = getRepShapeForAtomic(op->getResult(0));
165-     auto elems = getNumScratchElements(smemShape);
166-     auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
167-     assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
168-     return elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
169-   }
170-   if (auto createTensormap = dyn_cast<ExperimentalTensormapCreateOp>(op)) {
171-     constexpr int32_t kTMASize = 128;
172-     return kTMASize;
173-   }
174-   return 0;
175- }
176-
177121class AllocationAnalysis {
178122public:
179123   AllocationAnalysis(Operation *operation,
180124                      Allocation::FuncAllocMapT *funcAllocMap,
181-                        Allocation *allocation,
182-                        AllocationAnalysisScratchSizeFn scratchSizeGetter)
125+                        Allocation *allocation)
183126       : operation(operation), funcAllocMap(funcAllocMap),
184-         allocation(allocation), scratchSizeGetter(scratchSizeGetter) {
127+         allocation(allocation) {
185128     run();
186129 }
187130
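
For context: the hunk above removes the pluggable `AllocationAnalysisScratchSizeFn` hook, so `AllocationAnalysis` now computes scratch sizes itself (see the next hunk) and its constructor drops the callback argument. A minimal before/after sketch of a call site; `op`, `funcAllocMap`, and `allocation` are illustrative locals, not lines taken from this diff:

```cpp
// Before: the caller had to thread a scratch-size callback through
// (defaultAllocationAnalysisScratchSizeFn or a backend-specific override).
// triton::AllocationAnalysis(op, &funcAllocMap, &allocation,
//                            defaultAllocationAnalysisScratchSizeFn);

// After: the analysis owns the scratch-size computation internally.
triton::AllocationAnalysis(op, &funcAllocMap, &allocation);
```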
@@ -234,19 +177,77 @@ class AllocationAnalysis {
234177
235178   /// Initializes temporary shared memory for a given operation.
236179   void getScratchValueSize(Operation *op) {
237-     constexpr size_t scratchAlignment = 128;
238-     if (auto callOp = dyn_cast<CallOpInterface>(op)) {
180+     const size_t scratchAlignment = 128;
181+     if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
182+       ReduceOpHelper helper(reduceOp);
183+       unsigned bytes = helper.getScratchSizeInBytes();
184+       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
185+                                                            scratchAlignment);
186+     } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
187+       ScanLoweringHelper helper(scanOp);
188+       unsigned bytes = helper.getScratchSizeInBytes();
189+       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
190+                                                            scratchAlignment);
191+     } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
192+       auto dstTy = histogram.getType();
193+       int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
194+           op->getParentOfType<ModuleOp>());
195+       auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
196+                    std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
197+       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
198+                                                            scratchAlignment);
199+     } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
200+       auto srcTy = cvtLayout.getSrc().getType();
201+       auto dstTy = cvtLayout.getType();
202+       auto srcEncoding = srcTy.getEncoding();
203+       auto dstEncoding = dstTy.getEncoding();
204+       if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
205+           mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
206+         // Conversions from/to shared memory do not need scratch memory.
207+         return;
208+       }
209+       // ConvertLayoutOp with both input/output non-shared_layout
210+       // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
211+       // also possible to realize it with other approaches in restricted
212+       // conditions, such as warp-shuffle
213+       auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
214+       auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
215+       auto bytes =
216+           isa<PointerType>(srcTy.getElementType())
217+               ? elems * kPtrBitWidth / 8
218+               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
219+       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
220+                                                            scratchAlignment);
221+     } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
222+       auto value = op->getOperand(0);
223+       // only scalar requires scratch memory
224+       // make it explicit for readability
225+       if (dyn_cast<RankedTensorType>(value.getType())) {
226+         // nothing to do
227+       } else {
228+         auto smemShape = getRepShapeForAtomic(op->getResult(0));
229+         auto elems = getNumScratchElements(smemShape);
230+         auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
231+         assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
232+         auto bytes =
233+             elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
234+         maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
235+                                                              scratchAlignment);
236+       }
237+     } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
239238       auto callable = callOp.resolveCallable();
240239       auto funcOp = dyn_cast<FunctionOpInterface>(callable);
241240       auto *funcAlloc = &(*funcAllocMap)[funcOp];
242241       auto bytes = funcAlloc->getSharedMemorySize();
243242       maybeAddScratchBuffer<BufferT::BufferKind::Virtual>(op, bytes,
244243                                                            scratchAlignment);
245-       return;
244+     } else if (auto createTensormap =
245+                    dyn_cast<ExperimentalTensormapCreateOp>(op)) {
246+       constexpr int32_t kTMASize = 128;
247+       constexpr int32_t kTMAAlign = 128;
248+       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, kTMASize,
249+                                                            kTMAAlign);
246250     }
247-     unsigned bytes = scratchSizeGetter(op);
248-     maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
249-                                                          scratchAlignment);
250251   }
251252
252253   void getValueAlias(Value value, SharedMemoryAliasAnalysis &analysis) {
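
Every non-trivial branch in the hunk above converts an element count to bytes the same way: `elems * std::max(8, bitWidth) / 8`, so a sub-byte element type (e.g. i1) is still charged one full byte of scratch per element. A self-contained sketch of that rounding rule; `scratchBytes` is an illustrative helper, not a function from this file:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Illustrative helper mirroring the rounding used in the branches above:
// each element is charged at least one full byte of scratch memory.
static int64_t scratchBytes(int64_t numElements, int64_t elemBitWidth) {
  return numElements * std::max<int64_t>(8, elemBitWidth) / 8;
}

int main() {
  assert(scratchBytes(1024, 16) == 2048); // fp16: two bytes per element
  assert(scratchBytes(1024, 1) == 1024);  // i1: still one byte per element
  assert(scratchBytes(256, 64) == 2048);  // 64-bit elements: eight bytes each
  return 0;
}
```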
@@ -546,16 +547,13 @@ class AllocationAnalysis {
546547 Allocation::FuncAllocMapT *funcAllocMap;
547548 Allocation *allocation;
548549 BufferRangeMapT bufferRange;
549- AllocationAnalysisScratchSizeFn scratchSizeGetter;
550550};
551551
552552} // namespace triton
553553
554- void Allocation::run(
555-     FuncAllocMapT &funcAllocMap,
556-     triton::AllocationAnalysisScratchSizeFn scratchSizeGetter) {
557-   triton::AllocationAnalysis(getOperation(), &funcAllocMap, this,
558-                              scratchSizeGetter);
554+ template <>
555+ void Allocation::run<triton::AllocationAnalysis>(FuncAllocMapT &funcAllocMap) {
556+   triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
559557}
560558
561559std::map<Operation *, SmallVector<Allocation::BufferId>>
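
With `Allocation::run` now specialized on the analysis type, a caller selects the analysis through the template argument instead of passing a callback. A rough sketch of a call site, assuming `run` is declared as a member template in the header and that `funcOp` and `funcAllocMap` are set up by a surrounding module-level driver (both names are assumptions, not lines from this diff):

```cpp
// Sketch only: analyze one function and record its buffers.
Allocation allocation(funcOp); // funcOp: assumed to be the function being analyzed
allocation.run<triton::AllocationAnalysis>(funcAllocMap);
// A backend that needs different scratch-size rules would now supply its own
// analysis class (and run<> specialization) rather than a scratch-size callback.
```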