Skip to content

Commit 0550d4b

Browse files
committed
Merge branch 'slice_utils' into vector_multi_reduction_distr_refactor
2 parents f9b3933 + 77e8a94 commit 0550d4b

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,55 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
231231
multiple blocks according to round-robin distribution rules.}],
232232
"FailureOr<SmallVector<SmallVector<Value>>>",
233233
"getOffsets",
234-
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
234+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
235+
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
236+
to some other layout according to given permutation of (0...n-1).}],
237+
/*retTy=*/"bool",
238+
/*methodName=*/"isTransposeOf",
239+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
240+
/*methodBody=*/[{
241+
if (!other)
242+
return false;
243+
if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
244+
return false;
245+
// check if the permutation is valid
246+
int64_t rank = $_self.getRank();
247+
SmallVector<bool, 8> seen(rank, false);
248+
for (const auto &ta : llvm::enumerate(perm)) {
249+
if (ta.value() < 0 || ta.value() >= rank)
250+
return false;
251+
if (seen[ta.value()])
252+
return false;
253+
seen[ta.value()] = true;
254+
}
255+
auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
256+
for (const auto &ta : llvm::enumerate(perm)) {
257+
if (src[ta.index()] != dst[ta.value()])
258+
return false;
259+
}
260+
return true;
261+
};
262+
// check sgLayout
263+
if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
264+
return false;
265+
// check sgData
266+
if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
267+
return false;
268+
// check instData
269+
if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
270+
return false;
271+
// check laneLayout
272+
if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
273+
return false;
274+
// check laneData
275+
if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
276+
return false;
277+
return true;
278+
}]>,
279+
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
280+
/*retTy=*/"bool",
281+
/*methodName=*/"isSliceOf",
282+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
235283
];
236284
}
237285

@@ -433,6 +481,9 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
433481
FailureOr<SmallVector<SmallVector<Value>>>
434482
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
435483

484+
/// Check if this is slice of some other layout.
485+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
486+
436487
}];
437488

438489
let assemblyFormat = "`<` struct(params) `>`";
@@ -594,6 +645,9 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
594645
FailureOr<SmallVector<SmallVector<Value>>>
595646
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
596647

648+
/// Check if this is slice of some other layout.
649+
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
650+
597651
}];
598652

599653
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
1515
#include "mlir/IR/Builders.h"
1616
#include "mlir/IR/DialectImplementation.h"
17+
#include "llvm/ADT/STLExtras.h"
1718
#include "llvm/ADT/TypeSwitch.h"
1819
#include "llvm/Support/Debug.h"
1920

@@ -409,6 +410,26 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
409410
shape);
410411
}
411412

413+
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
414+
auto flattenedThis = flatten();
415+
// If other is a LayoutAttr, just compare directly with parent of
416+
// flattenedThis.
417+
if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
418+
return flattenedThis.getParent() == otherLayout;
419+
// If other is a SliceAttr, flatten it first before comparing.
420+
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
421+
// Both must have common parent LayoutAttr.
422+
if (flattenedThis.getParent() != flattenedOther.getParent())
423+
return false;
424+
// otherFlattened's sliced dims must be a subset of flattenedThis's sliced
425+
// dims.
426+
llvm::SmallDenseSet<int64_t> thisDims(
427+
flattenedThis.getDims().asArrayRef().begin(),
428+
flattenedThis.getDims().asArrayRef().end());
429+
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
430+
[&](int64_t dim) { return thisDims.contains(dim); });
431+
}
432+
412433
//===----------------------------------------------------------------------===//
413434
// XeGPU_RangeAttr
414435
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)