Skip to content

Commit a2643ea

Browse files
authored
Enable getMixedOffsets/Sizes/Strides interface for XeTile InitTileOp. (#829)
1 parent 408bcb1 commit a2643ea

File tree

5 files changed

+157
-299
lines changed

5 files changed

+157
-299
lines changed

include/imex/Dialect/XeTile/IR/XeTileOps.td

Lines changed: 77 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ include "imex/Dialect/XeTile/IR/XeTileTypes.td"
1919
include "imex/Dialect/XeTile/IR/XeTileAttrs.td"
2020

2121
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22+
include "mlir/Interfaces/ViewLikeInterface.td"
2223

2324
// Base class for dialect operations. This operation inherits from the base
2425
// `Op` class in OpBase.td, and provides:
@@ -28,7 +29,8 @@ include "mlir/Dialect/Vector/IR/VectorAttributes.td"
2829
class XeTile_Op<string mnemonic, list<Trait> traits = []> :
2930
Op<XeTile_Dialect, mnemonic, traits>;
3031

31-
def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> {
32+
def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
33+
ViewLikeOpInterface, OffsetSizeAndStrideOpInterface]> {
3234
let summary = "Describes an XeTile with reference to a base memref";
3335
let description = [{
3436
The "init_tile" operation is used to describe a 2D region (i.e. tile) in gloabl memory.
@@ -103,15 +105,26 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
103105

104106
}];
105107

106-
let arguments = (ins XeTile_BaseAddrType:$source,
107-
Variadic<Index>:$offsets,
108-
DenseI64ArrayAttr:$static_offsets,
109-
Variadic<Index>:$dynamic_shape,
110-
Variadic<Index>:$dynamic_strides
108+
let arguments = (ins XeTile_BaseAddrType: $source,
109+
Variadic<Index>: $offsets,
110+
Variadic<Index>: $sizes,
111+
Variadic<Index>: $strides,
112+
DenseI64ArrayAttr: $const_offsets,
113+
OptionalAttr<DenseI64ArrayAttr>: $const_sizes,
114+
OptionalAttr<DenseI64ArrayAttr>: $const_strides
111115
);
112116

113117
let results = (outs XeTile: $tile);
114118

119+
120+
let assemblyFormat = [{
121+
$source ``
122+
custom<DynamicIndexList>($offsets, $const_offsets)
123+
(`,` custom<DynamicIndexList>($sizes, $const_sizes)^
124+
`,` custom<DynamicIndexList>($strides, $const_strides))?
125+
attr-dict `:` type($source) `->` qualified(type($tile))
126+
}];
127+
115128
let builders = [
116129
// creating init_tile op with static memref
117130
OpBuilder<(ins "xetile::TileType":$resultType,
@@ -121,12 +134,10 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
121134
OpBuilder<(ins "xetile::TileType":$resultType,
122135
"mlir::Value":$source,
123136
"llvm::ArrayRef<mlir::OpFoldResult>":$offsets,
124-
"llvm::ArrayRef<mlir::Value>":$dynamic_shape,
125-
"llvm::ArrayRef<mlir::Value>":$dynamic_strides)>
137+
"llvm::ArrayRef<mlir::OpFoldResult>":$sizes,
138+
"llvm::ArrayRef<mlir::OpFoldResult>":$strides)>
126139
];
127140

128-
let hasCustomAssemblyFormat = true;
129-
130141
let extraClassDeclaration = [{
131142
/// get source type, could be a memref or an integer
132143
mlir::Type getSourceType() {return getSource().getType();}
@@ -163,16 +174,6 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
163174
return getType().getShape();
164175
}
165176

166-
/// check if the offsets are static
167-
bool hasStaticOffsets() {
168-
return !mlir::ShapedType::isDynamicShape(getStaticOffsets());
169-
}
170-
171-
/// check if a given dim in static offsets has a static value
172-
bool hasStaticOffsetAtDim(int dim) {
173-
return !mlir::ShapedType::isDynamic(getStaticOffsets()[dim]);
174-
}
175-
176177
/// check if the source memref has static shape info
177178
/// this method will fail if the source is not a memref
178179
bool sourceMemRefHasStaticShape() {
@@ -187,14 +188,48 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
187188
return mlir::cast<mlir::MemRefType>(getSourceType()).getShape();
188189
}
189190

190-
/// check if dynamic shape arguments are present
191-
bool hasDynamicShape() {
192-
return getDynamicShape().size();
191+
/// check if dynamic size arguments are present
192+
bool hasSizeArgs() {
193+
auto sizes = getConstSizes().value_or(llvm::ArrayRef<int64_t>({}));
194+
return sizes.size();
193195
}
194196

195197
/// check if dynamic stride arguments are present
196-
bool hasDynamicStrides() {
197-
return getDynamicStrides().size();
198+
bool hasStrideArgs() {
199+
auto strides = getConstStrides().value_or(llvm::ArrayRef<int64_t>({}));
200+
return strides.size();
201+
}
202+
203+
/// Get static offsets.
204+
llvm::ArrayRef<int64_t> getStaticOffsets() {
205+
return getConstOffsets();
206+
}
207+
208+
/// Get the static sizes.
209+
llvm::ArrayRef<int64_t> getStaticSizes() {
210+
if (getConstSizes().has_value())
211+
return getConstSizes().value();
212+
// At this point, the source must be a memref with static shape.
213+
assert(sourceMemRefHasStaticShape() && "The source memref does not have static shape.");
214+
return getSourceMemrefStaticShape();
215+
}
216+
217+
/// Get the static strides.
218+
llvm::ArrayRef<int64_t> getStaticStrides() {
219+
if (getConstStrides().has_value())
220+
return getConstStrides().value();
221+
// At this point, the source must be a memref with static shape.
222+
assert(sourceMemRefHasStaticShape() &&
223+
"The source memref does not have static shape.");
224+
llvm::SmallVector<int64_t> strides;
225+
int64_t offset;
226+
auto memrefType = mlir::dyn_cast<mlir::MemRefType>(getSourceType());
227+
assert(mlir::succeeded(
228+
mlir::getStridesAndOffset(memrefType, strides, offset)) &&
229+
"Failed to get strides and offset. Invalid source memref.");
230+
// Reuse the op storage.
231+
setConstStrides(strides);
232+
return getConstStrides().value();
198233
}
199234

200235
mlir::Attribute getSourceMemorySpace() {
@@ -212,34 +247,23 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
212247
return 0;
213248
}
214249

215-
/// Returns the offsets info to the source. It consolidates
216-
/// information from both dynamic_offsets and static_offsets
217-
/// parameters. static_offsets parameter always has the expected
218-
/// ranks with some dim could have mlir::ShapeType::kDynamic value
219-
/// indicating the corresponding value should be from dynamic_offsets.
220-
// llvm::SmallVector<mlir::OpFoldResult> getOffsets() {
221-
// llvm::SmallVector<mlir::OpFoldResult> offsets;
222-
// auto dynamicOffsets = getOffsets(); // from offsets variable
223-
// auto staticOffsets = getStaticOffsets(); // from static_offsets attribute
224-
225-
// // in case static_offsets is missing
226-
// if (staticOffsets.size() == 0) {
227-
// offsets.assign(dynamicOffsets.begin(), dynamicOffsets.end());
228-
// return offsets;
229-
// }
230-
231-
// for (size_t i = 0, j = 0; i < staticOffsets.size(); i++) {
232-
// if (mlir::ShapedType::isDynamic(staticOffsets[i])) {
233-
// assert(j < dynamicOffsets.size());
234-
// offsets.push_back(dynamicOffsets[j++]);
235-
// } else {
236-
// auto ty = mlir::IndexType::get(getContext());
237-
// auto attr = mlir::IntegerAttr::get(ty, staticOffsets[i]);
238-
// offsets.push_back(attr);
239-
// }
240-
// }
241-
// return offsets;
242-
// }
250+
/// Return the expected rank of each of the`static_offsets`,
251+
/// `static_shape` and `static_strides` attributes.
252+
std::array<unsigned, 3> getArrayAttrMaxRanks() {
253+
unsigned rank;
254+
if (auto ty = llvm::dyn_cast<mlir::MemRefType>(getSourceType())) {
255+
rank = ty.getRank();
256+
} else {
257+
rank = (unsigned)getMixedOffsets().size();
258+
}
259+
return {rank, rank, rank};
260+
}
261+
262+
/// Return the number of leading operands before the `offsets`,
263+
/// `shape` and `strides` operands.
264+
static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
265+
266+
mlir::Value getViewSource() { return getSource(); }
243267

244268
}];
245269

0 commit comments

Comments
 (0)