@@ -144,133 +144,6 @@ class StructuredToMemrefPass
144144 ttx::TritonTilingExtDialect, memref::MemRefDialect>();
145145 }
146146
147- LogicalResult convertArgsToMemrefType () { // Converts func.func signatures in this module using TritonFunctionSignatureConverter.
148- auto moduleOp = getOperation ();
149-
150- RewritePatternSet patterns (&getContext ());
151- ConversionTarget target (getContext ());
152- TritonFunctionSignatureConverter typeConverter;
153-
154- // Update function signatures to use memrefs: a func.func is considered legal only when typeConverter reports its function type as a legal signature.
155- target.addDynamicallyLegalOp <func::FuncOp>([&](func::FuncOp op) {
156- return typeConverter.isSignatureLegal (op.getFunctionType ());
157- });
158-
159- populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
160- patterns, typeConverter);
161-
162- return applyPartialConversion (moduleOp, target, std::move (patterns)); // result of the conversion is returned to the caller as-is
163- }
164-
165- // We leverage the 1->N conversion infrastructure to convert scalar
166- // tt.addptr ops to memref.reinterpret_cast.
167- //
168- // A tt.addptr has the following form:
169- //
170- // %new_ptr = tt.addptr %ptr %offset
171- //
172- // where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.
173- //
174- // With this form, there can be a chain of tt.addptr where we keep adding
175- // offsets to an existing pointer:
176- //
177- // %ptr_1 = tt.addptr %arg0 %offset
178- // %ptr_2 = tt.addptr %ptr_1 %offset
179- // %ptr_3 = tt.addptr %ptr_2 %offset
180- //
181- // Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
182- // the pointers can be used by affine.load and affine.store (lowered from
183- // tt.load and tt.store).
184- //
185- // A memref.reinterpret_cast op also takes an offset and returns a memref in a
186- // similar fashion to tt.addptr:
187- //
188- // %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes:
189- // [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset:
190- // ?>>
191- //
192- // However, since the semantics of memref.reinterpret_cast are different,
193- // the following lowering would be incorrect for the sequence of tt.addptr
194- // above:
195- //
196- // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
197- // %cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset]
198- // %cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset]
199- //
200- // The above sequence is equivalent to:
201- //
202- // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
203- // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset]
204- // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset]
205- //
206- // In other words, memref.reinterpret_cast ignores the current offset of the
207- // input buffer.
208- //
209- // Therefore, we have to manually track the offset for each addptr by lowering
210- // to the following form:
211- //
212- // %offset_1 = arith.addi %cst_0 %offset
213- // %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1]
214- //
215- // %offset_2 = arith.addi %offset_1 %offset
216- // %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2]
217- //
218- // %offset_3 = arith.addi %offset_2 %offset
219- // %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3]
220- //
221- // Each tt.addptr is lowered to an arith.addi that accumulates the current
222- // offset, followed by a memref.reinterpret_cast that uses that offset.
223- LogicalResult convertAddPtrToReinterpretCast () {
224- auto moduleOp = getOperation ();
225-
226- RewritePatternSet patterns (&getContext ());
227-
228- auto context = &getContext ();
229- OneToNTypeConverter converter;
230- converter.addConversion ([](Type type) { return type; }); // identity rule: returns the type unchanged
231-
232- // We are doing a 1->2 type conversion here, where a triton pointer type
233- // maps to a pair of {memref, index} types for the buffer and the offset.
234- converter.addConversion (
235- [context](triton::PointerType ptrType, SmallVectorImpl<Type> &types)
236- -> std::optional<LogicalResult> {
237- types = SmallVector<Type>{getMemrefTypeForScalarPtr (ptrType, context),
238- IndexType::get (context)};
239- return success ();
240- });
241-
242- // Hooks to compute the correct materializations. The "argument" and
243- // "source" materializations are used when we need to convert a pair of
244- // {memref, index} values back to the original triton pointer type.
245- // They are needed when some ops still use the original pointer type.
246- // For instance, we convert the result of tt.addptr from tt.ptr type to a
247- // pair of {memref, index}, but the original ptr result is still being
248- // used by another tt.load or tt.store.
249- converter.addArgumentMaterialization (buildCastOp);
250- converter.addSourceMaterialization (buildCastOp);
251-
252- // Register the target materialization: given a value with the pointer
253- // type, convert that value to a pair of {memref, index} values.
254- converter.addTargetMaterialization (buildCastAndOffsetOps);
255-
256- patterns.add <ScalarAddptrConverter>(converter, context);
257-
258- scf::populateSCFStructuralOneToNTypeConversions (converter, patterns);
259-
260- if (failed (applyPartialOneToNConversion (getOperation (), converter,
261- std::move (patterns)))) {
262- return failure ();
263- }
264-
265- PassManager pm (&getContext (), moduleOp.getOperationName ()); // nested pipeline anchored on this op's name
266- pm.addPass (createCanonicalizerPass ()); // run the canonicalizer on the converted module
267- if (failed (runPipeline (pm, getOperation ()))) {
268- return failure ();
269- }
270-
271- return success ();
272- }
273-
274147 void runOnOperation () override {
275148 auto moduleOp = getOperation ();
276149
0 commit comments