@@ -232,6 +232,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
232
232
return static_cast<unsigned>(MemorySpace::Global);
233
233
}
234
234
235
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
236
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
237
+ }
238
+
239
+ ArrayRef<int64_t> getDataShape() {
240
+ return getTensorDescShape();
241
+ }
242
+
235
243
}];
236
244
}
237
245
@@ -262,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
262
270
xegpu::TensorDescType getTensorDescType() {
263
271
return getTensorDesc().getType();
264
272
}
273
+
274
+ SmallVector<OpFoldResult> getMixedOffsets() {
275
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
276
+ auto dynamics = getOffsets();
277
+ if (statics.size() == 0 && dynamics.size() == 0)
278
+ return {};
279
+ return getMixedValues(statics, dynamics, getContext());
280
+ }
281
+
282
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
283
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
284
+ }
285
+
286
+ ArrayRef<int64_t> getDataShape() {
287
+ return getTensorDescType().getShape();
288
+ }
289
+
265
290
}];
266
291
267
292
let assemblyFormat = [{
@@ -343,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
343
368
xegpu::TensorDescType getTensorDescType() {
344
369
return getTensorDesc().getType();
345
370
}
371
+
372
+ SmallVector<OpFoldResult> getMixedOffsets() {
373
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
374
+ auto dynamics = getOffsets();
375
+ if (statics.size() == 0 && dynamics.size() == 0)
376
+ return {};
377
+ return getMixedValues(statics, dynamics, getContext());
378
+ }
379
+
380
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
381
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
382
+ }
383
+
384
+ ArrayRef<int64_t> getDataShape() {
385
+ return getTensorDescType().getShape();
386
+ }
387
+
388
+
346
389
}];
347
390
348
391
let assemblyFormat = [{
@@ -417,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
417
460
xegpu::TensorDescType getTensorDescType() {
418
461
return getTensorDesc().getType();
419
462
}
463
+
464
+ SmallVector<OpFoldResult> getMixedOffsets() {
465
+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
466
+ auto dynamics = getOffsets();
467
+ if (statics.size() == 0 && dynamics.size() == 0)
468
+ return {};
469
+ return getMixedValues(statics, dynamics, getContext());
470
+ }
471
+
472
+ xegpu::DistributeLayoutAttr getLayoutAttr() {
473
+ return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
474
+ }
475
+
476
+ ArrayRef<int64_t> getDataShape() {
477
+ return getTensorDescType().getShape();
478
+ }
479
+
420
480
}];
421
481
422
482
let assemblyFormat = [{
@@ -640,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
640
700
xegpu::TensorDescType getTensorDescType() {
641
701
return dyn_cast<xegpu::TensorDescType>(getSourceType());
642
702
}
703
+
643
704
}];
644
705
645
706
let assemblyFormat = [{
@@ -1150,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1150
1211
let arguments = (ins XeGPU_MemDesc:$mem_desc,
1151
1212
Variadic<Index>: $offsets,
1152
1213
DenseI64ArrayAttr: $const_offsets,
1153
- OptionalAttr<LayoutTrait >:$layout
1214
+ OptionalAttr<DistributeLayoutAttr >:$layout
1154
1215
);
1155
1216
let results = (outs XeGPU_ValueType:$res);
1156
1217
let assemblyFormat = [{
@@ -1175,12 +1236,16 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
1175
1236
1176
1237
let builders = [
1177
1238
OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
1178
- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait ": $layout)>,
1239
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr ": $layout)>,
1179
1240
];
1180
1241
let extraClassDeclaration = [{
1181
1242
SmallVector<OpFoldResult> getMixedOffsets() {
1182
1243
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1183
1244
}
1245
+
1246
+ ArrayRef<int64_t> getDataShape() {
1247
+ return getRes().getType().getShape();
1248
+ }
1184
1249
}];
1185
1250
1186
1251
let hasVerifier = 1;
@@ -1194,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1194
1259
XeGPU_MemDesc:$mem_desc,
1195
1260
Variadic<Index>: $offsets,
1196
1261
DenseI64ArrayAttr: $const_offsets,
1197
- OptionalAttr<LayoutTrait >:$layout
1262
+ OptionalAttr<DistributeLayoutAttr >:$layout
1198
1263
);
1199
1264
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
1200
1265
prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,12 +1278,17 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
1213
1278
}];
1214
1279
let builders = [
1215
1280
OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
1216
- "llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait ": $layout)>,
1281
+ "llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr ": $layout)>,
1217
1282
];
1218
1283
let extraClassDeclaration = [{
1219
1284
SmallVector<OpFoldResult> getMixedOffsets() {
1220
1285
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
1221
1286
}
1287
+
1288
+ ArrayRef<int64_t> getDataShape() {
1289
+ return getData().getType().getShape();
1290
+ }
1291
+
1222
1292
}];
1223
1293
1224
1294
let hasVerifier = 1;
0 commit comments