Skip to content

Commit 3b8c43b

Browse files
committed
Implemented improvement in linalg debufferize to work through inversesubmap, and a more shophisticated approach taken for tracking current tensors via a tree of regions
1 parent 0edd38e commit 3b8c43b

File tree

4 files changed

+1015
-185
lines changed

4 files changed

+1015
-185
lines changed

include/polygeist/PolygeistOps.td

Lines changed: 96 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,19 +260,111 @@ def TypeAlignOp : Polygeist_Op<"typeAlign", [Pure]> {
260260
let hasCanonicalizer = 1;
261261
}
262262

263+
//Add check for result to be same as original memref/tensor type
264+
def SubmapInverseOp : Polygeist_Op<"submapInverse", [Pure, ViewLikeOpInterface]> {
265+
let summary = "Inverse submap operation for scatter-back semantics";
266+
let description = [{
267+
The `polygeist.submapInverse` operation scatters a modified view back into
268+
the original base tensor/memref, preserving elements not covered by the view.
269+
270+
This is the inverse operation to `polygeist.submap` and is essential for
271+
debufferization of strided memory operations.
272+
273+
Example:
274+
```mlir
275+
// Scatter strided view back into base tensor
276+
%base_updated = polygeist.submapInverse(%base, %modified_view, %stride, %size)
277+
<{map = affine_map<(d0)[s0] -> (d0 * s0)>}>
278+
: (tensor<100xf32>, tensor<50xf32>) -> tensor<100xf32>
279+
280+
// Semantics: base_updated[i*stride] = modified_view[i]
281+
// base_updated[other] = base[other] (preserved)
282+
```
283+
}];
284+
285+
let arguments = (ins
286+
Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the original base">:$base_original,
287+
Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the modified view">:$view_modified,
288+
Variadic<Index>:$indices_and_sizes,
289+
AffineMapAttr:$map
290+
);
291+
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]> : $result);
292+
let hasFolder = 1;
293+
let hasCanonicalizer = 1;
294+
295+
let assemblyFormat = [{
296+
`(` $base_original `,` $view_modified (`,` $indices_and_sizes^)? `)`
297+
attr-dict `:` functional-type(operands, results)
298+
}];
299+
300+
let extraClassDeclaration = [{
301+
::mlir::ValueRange getSymbols() { return getOperands().slice(2, getMap().getNumSymbols()); }
302+
::mlir::ValueRange getSizes() {
303+
auto shapedType = ::llvm::cast<::mlir::ShapedType>(getType());
304+
return getOperands().slice(getMap().getNumSymbols()+2, shapedType.getShape().size());
305+
}
306+
::mlir::Value getViewSource() { return getBaseOriginal(); }
307+
308+
// Type compatibility helpers
309+
bool isMemRefVariant() {
310+
return ::llvm::isa<::mlir::MemRefType>(getBaseOriginal().getType());
311+
}
312+
bool isTensorVariant() {
313+
return ::llvm::isa<::mlir::TensorType>(getBaseOriginal().getType());
314+
}
315+
}];
316+
}
317+
263318
def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> {
264-
let arguments = (ins Arg<AnyMemRef, "the reference to load from">:$memref,
319+
let summary = "Submap operation for strided view extraction";
320+
let description = [{
321+
The `polygeist.submap` operation creates a strided view of a tensor/memref
322+
by applying an affine map to extract elements. This is used to represent
323+
strided access patterns in a composable way.
324+
325+
The operation works in both memref and tensor contexts, enabling
326+
debufferization of strided operations.
327+
328+
Example:
329+
```mlir
330+
// Extract every other element (stride=2)
331+
%view = polygeist.submap(%base, %stride, %size)
332+
<{map = affine_map<(d0)[s0] -> (d0 * s0)>}>
333+
: tensor<100xf32> -> tensor<50xf32>
334+
335+
// Semantics: view[i] = base[i * stride]
336+
```
337+
}];
338+
339+
let arguments = (ins
340+
Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the base to view">:$base,
265341
Variadic<Index>:$indices_and_sizes,
266342
AffineMapAttr:$map
267343
);
268-
let results = (outs AnyMemRef : $result);
344+
let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]> : $result);
269345
let hasFolder = 1;
270346
let hasCanonicalizer = 1;
347+
348+
let assemblyFormat = [{
349+
`(` $base (`,` $indices_and_sizes^)? `)`
350+
attr-dict `:` functional-type(operands, results)
351+
}];
271352

272353
let extraClassDeclaration = [{
273354
::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()); }
274-
::mlir::ValueRange getSizes() { return getOperands().slice(getMap().getNumSymbols()+1, getType().getShape().size()); }
275-
::mlir::Value getViewSource() { return getMemref(); }
355+
::mlir::ValueRange getSizes() {
356+
auto shapedType = ::llvm::cast<::mlir::ShapedType>(getType());
357+
return getOperands().slice(getMap().getNumSymbols()+1, shapedType.getShape().size());
358+
}
359+
::mlir::Value getViewSource() { return getBase(); }
360+
361+
// Type compatibility helpers
362+
bool isMemRefVariant() {
363+
return ::llvm::isa<::mlir::MemRefType>(getBase().getType());
364+
}
365+
bool isTensorVariant() {
366+
return ::llvm::isa<::mlir::TensorType>(getBase().getType());
367+
}
276368
}];
277369
}
278370

lib/polygeist/Ops.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5769,7 +5769,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern<polygeist::SubmapOp> {
57695769
/// %x = ... : memref<4x5xf32>
57705770
// %y = memref.cast %x : memref<4x5xf32> -> memref<?x?xf32>
57715771
//
5772-
auto source_memref = op.getMemref();
5772+
auto source_memref = op.getBase();
57735773
bool isIdentity = op.getMap().isIdentity();
57745774
bool isInputSameDim = llvm::all_of(
57755775
llvm::zip_equal(op.getSizes(),
@@ -5785,7 +5785,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern<polygeist::SubmapOp> {
57855785
});
57865786
if (isIdentity && isInputSameDim) {
57875787
rewriter.replaceOpWithNewOp<memref::CastOp>(op, op.getType(),
5788-
op.getMemref());
5788+
op.getBase());
57895789
return success();
57905790
}
57915791
if (auto sapOp = source_memref.getDefiningOp<polygeist::SubmapOp>()) {
@@ -5797,7 +5797,7 @@ struct SubMapOpCanonicalize : public OpRewritePattern<polygeist::SubmapOp> {
57975797
operands.append(op.getSymbols().begin(), op.getSymbols().end());
57985798
operands.append(op.getSizes().begin(), op.getSizes().end());
57995799
rewriter.replaceOpWithNewOp<polygeist::SubmapOp>(
5800-
op, op.getType(), sapOp.getMemref(), operands, new_map);
5800+
op, op.getType(), sapOp.getBase(), operands, new_map);
58015801
return success();
58025802
}
58035803
return failure();
@@ -5990,7 +5990,7 @@ static bool canConvertSubmapToSubView(polygeist::SubmapOp submapOp) {
59905990
auto map = submapOp.getMap();
59915991
auto sizes = submapOp.getSizes();
59925992
auto symbols = submapOp.getSymbols();
5993-
auto source_memref = submapOp.getMemref();
5993+
auto source_memref = submapOp.getBase();
59945994

59955995
// 0. Only convert if map has symbols
59965996
if (submapOp.getMap().getNumSymbols() == 0) {
@@ -6111,7 +6111,7 @@ struct SubmapToSubviewOp : public OpRewritePattern<polygeist::SubmapOp> {
61116111
for (Value size : conversionInfo.sizes) {
61126112
sizeValues.push_back(size);
61136113
}
6114-
rewriter.replaceOpWithNewOp<memref::SubViewOp>(submapOp, submapOp.getType(), submapOp.getMemref(), offsetValues, sizeValues, strideValues);
6114+
rewriter.replaceOpWithNewOp<memref::SubViewOp>(submapOp, submapOp.getBase(), offsetValues, sizeValues, strideValues);
61156115
return success();
61166116
}
61176117
};
@@ -6535,7 +6535,7 @@ class LoadSubMap final : public OpRewritePattern<affine::AffineLoadOp> {
65356535

65366536
auto submap_map = subMapOp.getMap();
65376537
auto submap_operands = subMapOp.getSymbols();
6538-
auto source_memref = subMapOp.getMemref();
6538+
auto source_memref = subMapOp.getBase();
65396539

65406540
auto load_map = op.getAffineMap();
65416541
auto load_operands = op.getMapOperands();
@@ -6567,7 +6567,7 @@ class StoreSubMap final : public OpRewritePattern<affine::AffineStoreOp> {
65676567

65686568
auto submap_map = subMapOp.getMap();
65696569
auto submap_operands = subMapOp.getSymbols();
6570-
auto source_memref = subMapOp.getMemref();
6570+
auto source_memref = subMapOp.getBase();
65716571

65726572
auto load_map = op.getAffineMap();
65736573
auto load_operands = op.getMapOperands();
@@ -6705,3 +6705,20 @@ void polygeist::SubmapOp::getCanonicalizationPatterns(
67056705
// results.insert<LoadSubMap, StoreSubMap, DimSubMap>(context);
67066706
}
67076707

6708+
//===----------------------------------------------------------------------===//
6709+
// SubmapInverseOp
6710+
//===----------------------------------------------------------------------===//
6711+
6712+
OpFoldResult mlir::polygeist::SubmapInverseOp::fold(
6713+
mlir::polygeist::SubmapInverseOp::FoldAdaptor adaptor) {
6714+
// TODO: Add folding logic for SubmapInverseOp
6715+
// For now, just return nullptr (no folding)
6716+
return nullptr;
6717+
}
6718+
6719+
void polygeist::SubmapInverseOp::getCanonicalizationPatterns(
6720+
RewritePatternSet &results, MLIRContext *context) {
6721+
// TODO: Add canonicalization patterns for SubmapInverseOp
6722+
// For now, leave empty
6723+
}
6724+

0 commit comments

Comments
 (0)