@@ -231,7 +231,55 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
231231 multiple blocks according to round-robin distribution rules.}],
232232 "FailureOr<SmallVector<SmallVector<Value>>>",
233233 "getOffsets",
234- (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
234+ (ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
235+ InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
236+ to some other layout according to given permutation of (0...n-1).}],
237+ /*retTy=*/"bool",
238+ /*methodName=*/"isTransposeOf",
239+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
240+ /*methodBody=*/[{
241+ if (!other)
242+ return false;
243+ if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
244+ return false;
245+ // check if the permutation is valid
246+ int64_t rank = $_self.getRank();
247+ SmallVector<bool, 8> seen(rank, false);
248+ for (const auto &ta : llvm::enumerate(perm)) {
249+ if (ta.value() < 0 || ta.value() >= rank)
250+ return false;
251+ if (seen[ta.value()])
252+ return false;
253+ seen[ta.value()] = true;
254+ }
255+ auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
256+ for (const auto &ta : llvm::enumerate(perm)) {
257+ if (src[ta.index()] != dst[ta.value()])
258+ return false;
259+ }
260+ return true;
261+ };
262+ // check sgLayout
263+ if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
264+ return false;
265+ // check sgData
266+ if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
267+ return false;
268+ // check instData
269+ if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
270+ return false;
271+ // check laneLayout
272+ if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
273+ return false;
274+ // check laneData
275+ if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
276+ return false;
277+ return true;
278+ }]>,
279+ InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
280+ /*retTy=*/"bool",
281+ /*methodName=*/"isSliceOf",
282+ /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
235283 ];
236284}
237285
@@ -433,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
433481 FailureOr<SmallVector<SmallVector<Value>>>
434482 getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
435483
484+ /// Check if this is slice of some other layout.
485+ bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
486+
436487 }];
437488
438489 let assemblyFormat = "`<` struct(params) `>`";
@@ -594,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
594645 FailureOr<SmallVector<SmallVector<Value>>>
595646 getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
596647
648+ /// Check if this is slice of some other layout.
649+ bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
650+
597651 }];
598652
599653 let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
0 commit comments