@@ -19,6 +19,7 @@ include "imex/Dialect/XeTile/IR/XeTileTypes.td"
19
19
include "imex/Dialect/XeTile/IR/XeTileAttrs.td"
20
20
21
21
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
22
+ include "mlir/Interfaces/ViewLikeInterface.td"
22
23
23
24
// Base class for dialect operations. This operation inherits from the base
24
25
// `Op` class in OpBase.td, and provides:
@@ -28,7 +29,8 @@ include "mlir/Dialect/Vector/IR/VectorAttributes.td"
28
29
class XeTile_Op<string mnemonic, list<Trait> traits = []> :
29
30
Op<XeTile_Dialect, mnemonic, traits>;
30
31
31
- def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]> {
32
+ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments,
33
+ ViewLikeOpInterface, OffsetSizeAndStrideOpInterface]> {
32
34
let summary = "Describes an XeTile with reference to a base memref";
33
35
let description = [{
34
36
The "init_tile" operation is used to describe a 2D region (i.e. tile) in gloabl memory.
@@ -103,15 +105,26 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
103
105
104
106
}];
105
107
106
- let arguments = (ins XeTile_BaseAddrType:$source,
107
- Variadic<Index>:$offsets,
108
- DenseI64ArrayAttr:$static_offsets,
109
- Variadic<Index>:$dynamic_shape,
110
- Variadic<Index>:$dynamic_strides
108
+ let arguments = (ins XeTile_BaseAddrType: $source,
109
+ Variadic<Index>: $offsets,
110
+ Variadic<Index>: $sizes,
111
+ Variadic<Index>: $strides,
112
+ DenseI64ArrayAttr: $const_offsets,
113
+ OptionalAttr<DenseI64ArrayAttr>: $const_sizes,
114
+ OptionalAttr<DenseI64ArrayAttr>: $const_strides
111
115
);
112
116
113
117
let results = (outs XeTile: $tile);
114
118
119
+
120
+ let assemblyFormat = [{
121
+ $source ``
122
+ custom<DynamicIndexList>($offsets, $const_offsets)
123
+ (`,` custom<DynamicIndexList>($sizes, $const_sizes)^
124
+ `,` custom<DynamicIndexList>($strides, $const_strides))?
125
+ attr-dict `:` type($source) `->` qualified(type($tile))
126
+ }];
127
+
115
128
let builders = [
116
129
// creating init_tile op with static memref
117
130
OpBuilder<(ins "xetile::TileType":$resultType,
@@ -121,12 +134,10 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
121
134
OpBuilder<(ins "xetile::TileType":$resultType,
122
135
"mlir::Value":$source,
123
136
"llvm::ArrayRef<mlir::OpFoldResult>":$offsets,
124
- "llvm::ArrayRef<mlir::Value >":$dynamic_shape ,
125
- "llvm::ArrayRef<mlir::Value >":$dynamic_strides )>
137
+ "llvm::ArrayRef<mlir::OpFoldResult >":$sizes ,
138
+ "llvm::ArrayRef<mlir::OpFoldResult >":$strides )>
126
139
];
127
140
128
- let hasCustomAssemblyFormat = true;
129
-
130
141
let extraClassDeclaration = [{
131
142
/// get source type, could be a memref or an integer
132
143
mlir::Type getSourceType() {return getSource().getType();}
@@ -163,16 +174,6 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
163
174
return getType().getShape();
164
175
}
165
176
166
- /// check if the offsets are static
167
- bool hasStaticOffsets() {
168
- return !mlir::ShapedType::isDynamicShape(getStaticOffsets());
169
- }
170
-
171
- /// check if a given dim in static offsets has a static value
172
- bool hasStaticOffsetAtDim(int dim) {
173
- return !mlir::ShapedType::isDynamic(getStaticOffsets()[dim]);
174
- }
175
-
176
177
/// check if the source memref has static shape info
177
178
/// this method will fail if the source is not a memref
178
179
bool sourceMemRefHasStaticShape() {
@@ -187,14 +188,48 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
187
188
return mlir::cast<mlir::MemRefType>(getSourceType()).getShape();
188
189
}
189
190
190
- /// check if dynamic shape arguments are present
191
- bool hasDynamicShape() {
192
- return getDynamicShape().size();
191
+ /// check if dynamic size arguments are present
192
+ bool hasSizeArgs() {
193
+ auto sizes = getConstSizes().value_or(llvm::ArrayRef<int64_t>({}));
194
+ return sizes.size();
193
195
}
194
196
195
197
/// check if dynamic stride arguments are present
196
- bool hasDynamicStrides() {
197
- return getDynamicStrides().size();
198
+ bool hasStrideArgs() {
199
+ auto strides = getConstStrides().value_or(llvm::ArrayRef<int64_t>({}));
200
+ return strides.size();
201
+ }
202
+
203
+ /// Get static offsets.
204
+ llvm::ArrayRef<int64_t> getStaticOffsets() {
205
+ return getConstOffsets();
206
+ }
207
+
208
+ /// Get the static sizes.
209
+ llvm::ArrayRef<int64_t> getStaticSizes() {
210
+ if (getConstSizes().has_value())
211
+ return getConstSizes().value();
212
+ // At this point, the source must be a memref with static shape.
213
+ assert(sourceMemRefHasStaticShape() && "The source memref does not have static shape.");
214
+ return getSourceMemrefStaticShape();
215
+ }
216
+
217
+ /// Get the static strides.
218
+ llvm::ArrayRef<int64_t> getStaticStrides() {
219
+ if (getConstStrides().has_value())
220
+ return getConstStrides().value();
221
+ // At this point, the source must be a memref with static shape.
222
+ assert(sourceMemRefHasStaticShape() &&
223
+ "The source memref does not have static shape.");
224
+ llvm::SmallVector<int64_t> strides;
225
+ int64_t offset;
226
+ auto memrefType = mlir::dyn_cast<mlir::MemRefType>(getSourceType());
227
+ assert(mlir::succeeded(
228
+ mlir::getStridesAndOffset(memrefType, strides, offset)) &&
229
+ "Failed to get strides and offset. Invalid source memref.");
230
+ // Reuse the op storage.
231
+ setConstStrides(strides);
232
+ return getConstStrides().value();
198
233
}
199
234
200
235
mlir::Attribute getSourceMemorySpace() {
@@ -212,34 +247,23 @@ def XeTile_InitTileOp : XeTile_Op<"init_tile", [Pure, AttrSizedOperandSegments]>
212
247
return 0;
213
248
}
214
249
215
- /// Returns the offsets info to the source. It consolidates
216
- /// information from both dynamic_offsets and static_offsets
217
- /// parameters. static_offsets parameter always has the expected
218
- /// ranks with some dim could have mlir::ShapeType::kDynamic value
219
- /// indicating the corresponding value should be from dynamic_offsets.
220
- // llvm::SmallVector<mlir::OpFoldResult> getOffsets() {
221
- // llvm::SmallVector<mlir::OpFoldResult> offsets;
222
- // auto dynamicOffsets = getOffsets(); // from offsets variable
223
- // auto staticOffsets = getStaticOffsets(); // from static_offsets attribute
224
-
225
- // // in case static_offsets is missing
226
- // if (staticOffsets.size() == 0) {
227
- // offsets.assign(dynamicOffsets.begin(), dynamicOffsets.end());
228
- // return offsets;
229
- // }
230
-
231
- // for (size_t i = 0, j = 0; i < staticOffsets.size(); i++) {
232
- // if (mlir::ShapedType::isDynamic(staticOffsets[i])) {
233
- // assert(j < dynamicOffsets.size());
234
- // offsets.push_back(dynamicOffsets[j++]);
235
- // } else {
236
- // auto ty = mlir::IndexType::get(getContext());
237
- // auto attr = mlir::IntegerAttr::get(ty, staticOffsets[i]);
238
- // offsets.push_back(attr);
239
- // }
240
- // }
241
- // return offsets;
242
- // }
250
+ /// Return the expected rank of each of the`static_offsets`,
251
+ /// `static_shape` and `static_strides` attributes.
252
+ std::array<unsigned, 3> getArrayAttrMaxRanks() {
253
+ unsigned rank;
254
+ if (auto ty = llvm::dyn_cast<mlir::MemRefType>(getSourceType())) {
255
+ rank = ty.getRank();
256
+ } else {
257
+ rank = (unsigned)getMixedOffsets().size();
258
+ }
259
+ return {rank, rank, rank};
260
+ }
261
+
262
+ /// Return the number of leading operands before the `offsets`,
263
+ /// `shape` and `strides` operands.
264
+ static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }
265
+
266
+ mlir::Value getViewSource() { return getSource(); }
243
267
244
268
}];
245
269
0 commit comments