@@ -191,6 +191,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
191191 InterfaceMethod<"Get the rank of attribute",
192192 "int64_t",
193193 "getRank">,
194+ InterfaceMethod<"Get the order field of the attribute as integer array",
195+ "DenseI32ArrayAttr",
196+ "getOrder">,
194197 InterfaceMethod<"Get the num of effective subgroups",
195198 "int64_t",
196199 "getNumSubgroups", (ins), [{
@@ -253,33 +256,40 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
253256 seen[ta.value()] = true;
254257 }
255258 auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
259+ // If both `dst` and `src` are empty, conservatively return true
260+ // here because some layout fields can be empty.
261+ if (dst.empty() && src.empty())
262+ return true;
256263 for (const auto &ta : llvm::enumerate(perm)) {
257264 if (src[ta.index()] != dst[ta.value()])
258265 return false;
259266 }
260267 return true;
261268 };
262- // check sgLayout
269+ // Check sgLayout
263270 if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
264271 return false;
265- // check sgData
272+ // Check sgData
266273 if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
267274 return false;
268- // check instData
275+ // Check instData
269276 if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
270277 return false;
271- // check laneLayout
278+ // Check laneLayout
272279 if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
273280 return false;
274- // check laneData
281+ // Check laneData
275282 if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
276283 return false;
284+ // Check order if both sides have order field.
285+ if ($_self.getOrder() && other.getOrder()) {
286+ auto thisOrderAsInt = llvm::to_vector_of<int64_t>($_self.getOrder().asArrayRef());
287+ auto otherOrderAsInt = llvm::to_vector_of<int64_t>(other.getOrder().asArrayRef());
288+ if (!checkTranspose(thisOrderAsInt, otherOrderAsInt, perm))
289+ return false;
290+ }
277291 return true;
278- }]>,
279- InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
280- /*retTy=*/"bool",
281- /*methodName=*/"isSliceOf",
282- /*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
292+ }]>
283293 ];
284294}
285295
@@ -481,9 +491,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
481491 FailureOr<SmallVector<SmallVector<Value>>>
482492 getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
483493
484- /// Check if this is slice of some other layout.
485- bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
486-
487494 }];
488495
489496 let assemblyFormat = "`<` struct(params) `>`";
@@ -645,9 +652,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
645652 FailureOr<SmallVector<SmallVector<Value>>>
646653 getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
647654
648- /// Check if this is slice of some other layout.
649- bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
650-
651655 }];
652656
653657 let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
0 commit comments