Skip to content

Commit 0ea853b

Browse files
committed
Update
1 parent 4969757 commit 0ea853b

File tree

2 files changed

+0
-134
lines changed

2 files changed

+0
-134
lines changed

lib/Conversion/StructuredToMemref/StructuredToMemref.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ static memref::SubViewOp getSubview(int rank, ArrayRef<OpFoldResult> dims,
5959
offsets, dims, strides);
6060
}
6161

62-
static Value getPtr(Value v) {
63-
while (auto op = v.getDefiningOp()) {
64-
v = op->getOperand(0);
65-
}
66-
return v;
67-
}
68-
6962
namespace {
7063

7164
struct MakeTensorPtrConverter

lib/Conversion/StructuredToMemref/StructuredToMemrefPass.cpp

Lines changed: 0 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -144,133 +144,6 @@ class StructuredToMemrefPass
144144
ttx::TritonTilingExtDialect, memref::MemRefDialect>();
145145
}
146146

147-
LogicalResult convertArgsToMemrefType() {
148-
auto moduleOp = getOperation();
149-
150-
RewritePatternSet patterns(&getContext());
151-
ConversionTarget target(getContext());
152-
TritonFunctionSignatureConverter typeConverter;
153-
154-
// Update function signature to use memrefs
155-
target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
156-
return typeConverter.isSignatureLegal(op.getFunctionType());
157-
});
158-
159-
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
160-
patterns, typeConverter);
161-
162-
return applyPartialConversion(moduleOp, target, std::move(patterns));
163-
}
164-
165-
// We leverage the 1->N conversion infrastructure to convert tt.addptr for
166-
// scalar to memref.reinterpret_cast.
167-
//
168-
// A tt.addptr has the following form:
169-
//
170-
// %new_ptr = tt.addptr %ptr %offset
171-
//
172-
// where %new_ptr and %ptr have tt.ptr type, and %offset is of index type.
173-
//
174-
// With this form, there can be a chain of tt.addptr where we keep adding
175-
// offsets to an existing pointer:
176-
//
177-
// %ptr_1 = tt.addptr %arg0 %offset
178-
// %ptr_2 = tt.addptr %ptr_1 %offset
179-
// %ptr_3 = tt.addptr %ptr_2 %offset
180-
//
181-
// Now, we want to lower each tt.addptr to a memref.reinterpret_cast so that
182-
// the pointers can be used by affine.load and affine.store (lowered from
183-
// tt.load and tt.store).
184-
//
185-
// A memref.reinterpret_cast op also takes an offset and returns a memref in a
186-
// similar fashion to tt.addptr:
187-
//
188-
// %reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes:
189-
// [1], strides: [1] : memref<*xf32> to memref<1xf32, strided<[1], offset:
190-
// ?>>
191-
//
192-
// However, since the semantic of memref.reinterpret_cast is different,
193-
// the following lowering would be incorrect for the sequence of tt.addptr
194-
// above:
195-
//
196-
// %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
197-
// %cast_2 = memref.reinterpret_cast %cast_1 to offset [%offset]
198-
// %cast_3 = memref.reinterpret_cast %cast_2 to offset [%offset]
199-
//
200-
// The above sequence is equivalent to:
201-
//
202-
// %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset]
203-
// %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset]
204-
// %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset]
205-
//
206-
// In other word, memref.reinterpret_cast ignores the current offset of the
207-
// input buffer.
208-
//
209-
// Therefore, we have to manually track the offset for each addptr by lowering
210-
// to the following form:
211-
//
212-
// %offset_1 = arith.addi %cst_0 %offset
213-
// %cast_1 = memref.reinterpret_cast %arg0 to offset [%offset_1]
214-
//
215-
// %offset_2 = arith.addi %offset_1 %offset
216-
// %cast_2 = memref.reinterpret_cast %arg0 to offset [%offset_2]
217-
//
218-
// %offset_3 = arith.addi %offset_2 %offset
219-
// %cast_3 = memref.reinterpret_cast %arg0 to offset [%offset_3]
220-
//
221-
// Each tt.addptr is lowered to a pair of arith.addi that accumulates the
222-
// current offset before using that offset to the reinterpret_cast.
223-
LogicalResult convertAddPtrToReinterpretCast() {
224-
auto moduleOp = getOperation();
225-
226-
RewritePatternSet patterns(&getContext());
227-
228-
auto context = &getContext();
229-
OneToNTypeConverter converter;
230-
converter.addConversion([](Type type) { return type; });
231-
232-
// We are doing a 1->2 type conversion here, where a triton pointer type
233-
// maps to a pair of {memref, index} type for the the buffer and offset.
234-
converter.addConversion(
235-
[context](triton::PointerType ptrType, SmallVectorImpl<Type> &types)
236-
-> std::optional<LogicalResult> {
237-
types = SmallVector<Type>{getMemrefTypeForScalarPtr(ptrType, context),
238-
IndexType::get(context)};
239-
return success();
240-
});
241-
242-
// Hooks to compute the correct materialization, "argument" and "source"
243-
// materialization are used when we need to convert a pair of {memref,
244-
// index} type back to the original triton pointer type.
245-
// These are used when there are ops that still need to use the original
246-
// pointer type. For instance, we convert the result of tt.addptr from
247-
// tt.ptr type to a pair of {memref, index}, but the original ptr result is
248-
// still being used by another tt.load or tt.store.
249-
converter.addArgumentMaterialization(buildCastOp);
250-
converter.addSourceMaterialization(buildCastOp);
251-
252-
// Compute the target materialization, given a value with the pointer type,
253-
// convert that value to a pair of {memref, index} type.
254-
converter.addTargetMaterialization(buildCastAndOffsetOps);
255-
256-
patterns.add<ScalarAddptrConverter>(converter, context);
257-
258-
scf::populateSCFStructuralOneToNTypeConversions(converter, patterns);
259-
260-
if (failed(applyPartialOneToNConversion(getOperation(), converter,
261-
std::move(patterns)))) {
262-
return failure();
263-
}
264-
265-
PassManager pm(&getContext(), moduleOp.getOperationName());
266-
pm.addPass(createCanonicalizerPass());
267-
if (failed(runPipeline(pm, getOperation()))) {
268-
return failure();
269-
}
270-
271-
return success();
272-
}
273-
274147
void runOnOperation() override {
275148
auto moduleOp = getOperation();
276149

0 commit comments

Comments
 (0)