@@ -260,19 +260,111 @@ def TypeAlignOp : Polygeist_Op<"typeAlign", [Pure]> {
260260 let hasCanonicalizer = 1;
261261}
262262
// AllTypesMatch implements the required check that the result type is
// identical to the original base memref/tensor type (scatter-back must
// produce a value of the same type it started from).
def SubmapInverseOp : Polygeist_Op<"submapInverse", [Pure, ViewLikeOpInterface,
    AllTypesMatch<["base_original", "result"]>]> {
  let summary = "Inverse submap operation for scatter-back semantics";
  let description = [{
    The `polygeist.submapInverse` operation scatters a modified view back into
    the original base tensor/memref, preserving elements not covered by the view.

    This is the inverse operation to `polygeist.submap` and is essential for
    debufferization of strided memory operations.

    Example:
    ```mlir
    // Scatter strided view back into base tensor
    %base_updated = polygeist.submapInverse(%base, %modified_view, %stride, %size)
        <{map = affine_map<(d0)[s0] -> (d0 * s0)>}>
        : (tensor<100xf32>, tensor<50xf32>, index, index) -> tensor<100xf32>

    // Semantics: base_updated[i*stride] = modified_view[i]
    //            base_updated[other]    = base[other] (preserved)
    ```
  }];

  let arguments = (ins
    Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the original base">:$base_original,
    Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the modified view">:$view_modified,
    Variadic<Index>:$indices_and_sizes,
    AffineMapAttr:$map
  );
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$result);
  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    `(` $base_original `,` $view_modified (`,` $indices_and_sizes^)? `)`
    attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    // Symbol operands bound by the affine map; they start at index 2,
    // immediately after the two leading shaped operands
    // (base_original, view_modified).
    ::mlir::ValueRange getSymbols() { return getOperands().slice(2, getMap().getNumSymbols()); }
    // Size operands follow the symbol operands; one per dimension of the
    // result type. NOTE(review): the result here is the base shape — confirm
    // the size count is intended to track the base rank, not the view rank.
    ::mlir::ValueRange getSizes() {
      auto shapedType = ::llvm::cast<::mlir::ShapedType>(getType());
      return getOperands().slice(getMap().getNumSymbols() + 2, shapedType.getShape().size());
    }
    // ViewLikeOpInterface hook: this op is a view over the original base.
    ::mlir::Value getViewSource() { return getBaseOriginal(); }

    // Type compatibility helpers: the op accepts either memrefs or tensors,
    // and clients use these to branch on which variant they are handling.
    bool isMemRefVariant() {
      return ::llvm::isa<::mlir::MemRefType>(getBaseOriginal().getType());
    }
    bool isTensorVariant() {
      return ::llvm::isa<::mlir::TensorType>(getBaseOriginal().getType());
    }
  }];
}
317+
def SubmapOp : Polygeist_Op<"submap", [Pure, ViewLikeOpInterface]> {
  let summary = "Submap operation for strided view extraction";
  let description = [{
    The `polygeist.submap` operation creates a strided view of a tensor/memref
    by applying an affine map to extract elements. This is used to represent
    strided access patterns in a composable way.

    The operation works in both memref and tensor contexts, enabling
    debufferization of strided operations.

    Example:
    ```mlir
    // Extract every other element (stride=2)
    %view = polygeist.submap(%base, %stride, %size)
        <{map = affine_map<(d0)[s0] -> (d0 * s0)>}>
        : (tensor<100xf32>, index, index) -> tensor<50xf32>

    // Semantics: view[i] = base[i * stride]
    ```
  }];

  let arguments = (ins
    Arg<AnyTypeOf<[AnyMemRef, AnyTensor]>, "the base to view">:$base,
    Variadic<Index>:$indices_and_sizes,
    AffineMapAttr:$map
  );
  let results = (outs AnyTypeOf<[AnyMemRef, AnyTensor]>:$result);
  let hasFolder = 1;
  let hasCanonicalizer = 1;

  let assemblyFormat = [{
    `(` $base (`,` $indices_and_sizes^)? `)`
    attr-dict `:` functional-type(operands, results)
  }];

  let extraClassDeclaration = [{
    // Symbol operands bound by the affine map; they start at index 1,
    // immediately after the single shaped operand ($base).
    ::mlir::ValueRange getSymbols() { return getOperands().slice(1, getMap().getNumSymbols()); }
    // Size operands follow the symbol operands; one per dimension of the
    // result (view) type.
    ::mlir::ValueRange getSizes() {
      auto shapedType = ::llvm::cast<::mlir::ShapedType>(getType());
      return getOperands().slice(getMap().getNumSymbols() + 1, shapedType.getShape().size());
    }
    // ViewLikeOpInterface hook: this op is a view over $base.
    ::mlir::Value getViewSource() { return getBase(); }

    // Type compatibility helpers: the op accepts either memrefs or tensors,
    // and clients use these to branch on which variant they are handling.
    bool isMemRefVariant() {
      return ::llvm::isa<::mlir::MemRefType>(getBase().getType());
    }
    bool isTensorVariant() {
      return ::llvm::isa<::mlir::TensorType>(getBase().getType());
    }
  }];
}
278370
0 commit comments