Skip to content

Commit 916c75f

Browse files
committed
add slice attribute utils
1 parent 82486fa commit 916c75f

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,11 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
275275
if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
276276
return false;
277277
return true;
278-
}]>
278+
}]>,
279+
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
280+
/*retTy=*/"bool",
281+
/*methodName=*/"isSliceOf",
282+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
279283
];
280284
}
281285

@@ -477,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
477481
FailureOr<SmallVector<SmallVector<Value>>>
478482
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
479483

484+
/// Check if this is slice of some other layout.
485+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
486+
480487
}];
481488

482489
let assemblyFormat = "`<` struct(params) `>`";
@@ -638,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
638645
FailureOr<SmallVector<SmallVector<Value>>>
639646
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
640647

648+
/// Check if this is slice of some other layout.
649+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
650+
641651
}];
642652

643653
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
1515
#include "mlir/IR/Builders.h"
1616
#include "mlir/IR/DialectImplementation.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/ADT/TypeSwitch.h"
1819
#include "llvm/Support/Debug.h"
1920

@@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
409410
shape);
410411
}
411412

413+
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
414+
auto flattenedThis = flatten();
415+
// If other is a LayoutAttr, just compare directly with parent of
416+
// flattenedThis.
417+
if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
418+
return flattenedThis.getParent() == otherLayout;
419+
// If other is a SliceAttr, flatten it first before comparing.
420+
auto otherFlattened = dyn_cast<xegpu::SliceAttr>(other).flatten();
421+
// Both must have common parent LayoutAttr.
422+
if (flattenedThis.getParent() != otherFlattened.getParent())
423+
return false;
424+
// otherFlattened's sliced dims must be a subset of flattenedThis's sliced
425+
// dims.
426+
llvm::SmallDenseSet<int64_t> thisDims(
427+
flattenedThis.getDims().asArrayRef().begin(),
428+
flattenedThis.getDims().asArrayRef().end());
429+
return llvm::all_of(otherFlattened.getDims().asArrayRef(),
430+
[&](int64_t dim) { return thisDims.contains(dim); });
431+
}
432+
412433
//===----------------------------------------------------------------------===//
413434
// XeGPU_RangeAttr
414435
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)