@@ -236,6 +236,10 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
236236 return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getType().getLayout());
237237 }
238238
239+ ArrayRef<int64_t> getDistributeShape() {
240+ return getTensorDescShape();
241+ }
242+
239243 }];
240244}
241245
@@ -266,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
266270 xegpu::TensorDescType getTensorDescType() {
267271 return getTensorDesc().getType();
268272 }
273+
274+ SmallVector<OpFoldResult> getMixedOffsets() {
275+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
276+ auto dynamics = getOffsets();
277+ if (statics.size() == 0 && dynamics.size() == 0)
278+ return {};
279+ return getMixedValues(statics, dynamics, getContext());
280+ }
281+
282+ xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
283+ return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
284+ }
285+
286+ ArrayRef<int64_t> getDistributeShape() {
287+ return getTensorDescType().getShape();
288+ }
289+
269290 }];
270291
271292 let assemblyFormat = [{
@@ -347,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
347368 xegpu::TensorDescType getTensorDescType() {
348369 return getTensorDesc().getType();
349370 }
371+
372+ SmallVector<OpFoldResult> getMixedOffsets() {
373+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
374+ auto dynamics = getOffsets();
375+ if (statics.size() == 0 && dynamics.size() == 0)
376+ return {};
377+ return getMixedValues(statics, dynamics, getContext());
378+ }
379+
380+ xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
381+ return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
382+ }
383+
384+ ArrayRef<int64_t> getDistributeShape() {
385+ return getTensorDescType().getShape();
386+ }
387+
388+
350389 }];
351390
352391 let assemblyFormat = [{
@@ -421,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
421460 xegpu::TensorDescType getTensorDescType() {
422461 return getTensorDesc().getType();
423462 }
463+
464+ SmallVector<OpFoldResult> getMixedOffsets() {
465+ auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
466+ auto dynamics = getOffsets();
467+ if (statics.size() == 0 && dynamics.size() == 0)
468+ return {};
469+ return getMixedValues(statics, dynamics, getContext());
470+ }
471+
472+ xegpu::DistributeLayoutAttrInterface getLayoutAttr() {
473+ return dyn_cast_if_present<xegpu::DistributeLayoutAttrInterface>(getTensorDescType().getLayout());
474+ }
475+
476+ ArrayRef<int64_t> getDistributeShape() {
477+ return getTensorDescType().getShape();
478+ }
479+
424480 }];
425481
426482 let assemblyFormat = [{
@@ -644,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
644700 xegpu::TensorDescType getTensorDescType() {
645701 return dyn_cast<xegpu::TensorDescType>(getSourceType());
646702 }
703+
647704 }];
648705
649706 let assemblyFormat = [{
@@ -1185,6 +1242,10 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11851242 SmallVector<OpFoldResult> getMixedOffsets() {
11861243 return getMixedValues(getConstOffsets(), getOffsets(), getContext());
11871244 }
1245+
1246+ ArrayRef<int64_t> getDistributeShape() {
1247+ return getRes().getType().getShape();
1248+ }
11881249 }];
11891250
11901251 let hasVerifier = 1;
@@ -1223,6 +1284,11 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
12231284 SmallVector<OpFoldResult> getMixedOffsets() {
12241285 return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12251286 }
1287+
1288+ ArrayRef<int64_t> getDistributeShape() {
1289+ return getData().getType().getShape();
1290+ }
1291+
12261292 }];
12271293
12281294 let hasVerifier = 1;
0 commit comments