Skip to content

Commit 93acad2

Browse files
committed
cleanup
1 parent 9af1f7f commit 93acad2

File tree

5 files changed

+249
-343
lines changed

5 files changed

+249
-343
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
236236
return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
237237
}
238238

239+
ArrayRef<int64_t> getDistributeShape() {
240+
return getTensorDescShape();
241+
}
242+
239243
}];
240244
}
241245

@@ -266,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
266270
xegpu::TensorDescType getTensorDescType() {
267271
return getTensorDesc().getType();
268272
}
273+
274+
SmallVector<OpFoldResult> getMixedOffsets() {
275+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
276+
auto dynamics = getOffsets();
277+
if (statics.size() == 0 && dynamics.size() == 0)
278+
return {};
279+
return getMixedValues(statics, dynamics, getContext());
280+
}
281+
282+
xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
283+
return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
284+
}
285+
286+
ArrayRef<int64_t> getDistributeShape() {
287+
return getTensorDescType().getShape();
288+
}
289+
269290
}];
270291

271292
let assemblyFormat = [{
@@ -347,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
347368
xegpu::TensorDescType getTensorDescType() {
348369
return getTensorDesc().getType();
349370
}
371+
372+
SmallVector<OpFoldResult> getMixedOffsets() {
373+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
374+
auto dynamics = getOffsets();
375+
if (statics.size() == 0 && dynamics.size() == 0)
376+
return {};
377+
return getMixedValues(statics, dynamics, getContext());
378+
}
379+
380+
xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
381+
return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
382+
}
383+
384+
ArrayRef<int64_t> getDistributeShape() {
385+
return getTensorDescType().getShape();
386+
}
387+
388+
350389
}];
351390

352391
let assemblyFormat = [{
@@ -421,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
421460
xegpu::TensorDescType getTensorDescType() {
422461
return getTensorDesc().getType();
423462
}
463+
464+
SmallVector<OpFoldResult> getMixedOffsets() {
465+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
466+
auto dynamics = getOffsets();
467+
if (statics.size() == 0 && dynamics.size() == 0)
468+
return {};
469+
return getMixedValues(statics, dynamics, getContext());
470+
}
471+
472+
xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
473+
return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
474+
}
475+
476+
ArrayRef<int64_t> getDistributeShape() {
477+
return getTensorDescType().getShape();
478+
}
479+
424480
}];
425481

426482
let assemblyFormat = [{
@@ -644,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
644700
xegpu::TensorDescType getTensorDescType() {
645701
return dyn_cast<xegpu::TensorDescType>(getSourceType());
646702
}
703+
647704
}];
648705

649706
let assemblyFormat = [{
@@ -1185,6 +1242,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11851242
SmallVector<OpFoldResult> getMixedOffsets() {
11861243
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
11871244
}
1245+
1246+
ArrayRef<int64_t> getDistributeShape() {
1247+
return getRes().getType().getShape();
1248+
}
11881249
}];
11891250

11901251
let hasVerifier = 1;
@@ -1223,6 +1284,11 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
12231284
SmallVector<OpFoldResult> getMixedOffsets() {
12241285
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12251286
}
1287+
1288+
ArrayRef<int64_t> getDistributeShape() {
1289+
return getData().getType().getShape();
1290+
}
1291+
12261292
}];
12271293

12281294
let hasVerifier = 1;

0 commit comments

Comments
 (0)