Skip to content

Commit dc3a250

Browse files
committed
use isTransposeOf
1 parent d0de396 commit dc3a250

File tree

3 files changed

+25
-50
lines changed

3 files changed

+25
-50
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,9 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
191191
InterfaceMethod<"Get the rank of attribute",
192192
"int64_t",
193193
"getRank">,
194+
InterfaceMethod<"Get the order field of the attribute as integer array",
195+
"DenseI32ArrayAttr",
196+
"getOrder">,
194197
InterfaceMethod<"Get the num of effective subgroups",
195198
"int64_t",
196199
"getNumSubgroups", (ins), [{
@@ -253,33 +256,40 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
253256
seen[ta.value()] = true;
254257
}
255258
auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
259+
// If both `dst` and `src` are empty, conservatively return true
260+
// here because some layout fields can be empty.
261+
if (dst.empty() && src.empty())
262+
return true;
256263
for (const auto &ta : llvm::enumerate(perm)) {
257264
if (src[ta.index()] != dst[ta.value()])
258265
return false;
259266
}
260267
return true;
261268
};
262-
// check sgLayout
269+
// Check sgLayout
263270
if (!checkTranspose($_self.getSgLayoutAsInt(), other.getSgLayoutAsInt(), perm))
264271
return false;
265-
// check sgData
272+
// Check sgData
266273
if (!checkTranspose($_self.getSgDataAsInt(), other.getSgDataAsInt(), perm))
267274
return false;
268-
// check instData
275+
// Check instData
269276
if (!checkTranspose($_self.getInstDataAsInt(), other.getInstDataAsInt(), perm))
270277
return false;
271-
// check laneLayout
278+
// Check laneLayout
272279
if (!checkTranspose($_self.getLaneLayoutAsInt(), other.getLaneLayoutAsInt(), perm))
273280
return false;
274-
// check laneData
281+
// Check laneData
275282
if (!checkTranspose($_self.getLaneDataAsInt(), other.getLaneDataAsInt(), perm))
276283
return false;
284+
// Check order if both sides have order field.
285+
if ($_self.getOrder() && other.getOrder()) {
286+
auto thisOrderAsInt = llvm::to_vector_of<int64_t>($_self.getOrder().asArrayRef());
287+
auto otherOrderAsInt = llvm::to_vector_of<int64_t>(other.getOrder().asArrayRef());
288+
if (!checkTranspose(thisOrderAsInt, otherOrderAsInt, perm))
289+
return false;
290+
}
277291
return true;
278-
}]>,
279-
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
280-
/*retTy=*/"bool",
281-
/*methodName=*/"isSliceOf",
282-
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other)>
292+
}]>
283293
];
284294
}
285295

@@ -481,9 +491,6 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
481491
FailureOr<SmallVector<SmallVector<Value>>>
482492
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
483493

484-
/// Check if this is slice of some other layout.
485-
bool isSliceOf(const xegpu::DistributeLayoutAttr &other) { return false; }
486-
487494
}];
488495

489496
let assemblyFormat = "`<` struct(params) `>`";
@@ -645,9 +652,6 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
645652
FailureOr<SmallVector<SmallVector<Value>>>
646653
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
647654

648-
/// Check if this is slice of some other layout.
649-
bool isSliceOf(const xegpu::DistributeLayoutAttr &other);
650-
651655
}];
652656

653657
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -410,26 +410,6 @@ SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
410410
shape);
411411
}
412412

413-
bool SliceAttr::isSliceOf(const xegpu::DistributeLayoutAttr &other) {
414-
auto flattenedThis = flatten();
415-
// If other is a LayoutAttr, just compare directly with parent of
416-
// flattenedThis.
417-
if (auto otherLayout = dyn_cast<xegpu::LayoutAttr>(other))
418-
return flattenedThis.getParent() == otherLayout;
419-
// If other is a SliceAttr, flatten it first before comparing.
420-
auto flattenedOther = dyn_cast<xegpu::SliceAttr>(other).flatten();
421-
// Both must have common parent LayoutAttr.
422-
if (flattenedThis.getParent() != flattenedOther.getParent())
423-
return false;
424-
// otherFlattened's sliced dims must be a subset of flattenedThis's sliced
425-
// dims.
426-
llvm::SmallDenseSet<int64_t> thisDims(
427-
flattenedThis.getDims().asArrayRef().begin(),
428-
flattenedThis.getDims().asArrayRef().end());
429-
return llvm::all_of(flattenedOther.getDims().asArrayRef(),
430-
[&](int64_t dim) { return thisDims.contains(dim); });
431-
}
432-
433413
//===----------------------------------------------------------------------===//
434414
// XeGPU_RangeAttr
435415
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1108,24 +1108,15 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
11081108
transposeOp,
11091109
"the source or result vector of the transpose op lacks layout "
11101110
"attribute");
1111-
SmallVector<int64_t> sourceLaneLayout = sourceLayout.getLaneLayoutAsInt();
1112-
SmallVector<int64_t> resultLaneLayout = resultLayout.getLaneLayoutAsInt();
1113-
SmallVector<int64_t> sourceLaneData = sourceLayout.getLaneDataAsInt();
1114-
SmallVector<int64_t> resultLaneData = resultLayout.getLaneDataAsInt();
1115-
if (sourceLaneLayout.size() != 2 || resultLaneLayout.size() != 2)
1111+
if (sourceLayout.getRank() != 2 || resultLayout.getRank() != 2)
11161112
return rewriter.notifyMatchFailure(
11171113
transposeOp, "the source or result vector of the transpose op "
11181114
"does not have 2D layout");
1119-
auto is2DTranspose = [](ArrayRef<int64_t> input, ArrayRef<int64_t> output) {
1120-
return input.size() == 2 && output.size() == 2 && input[0] == output[1] &&
1121-
input[1] == output[0];
1122-
};
1123-
1124-
if (!is2DTranspose(sourceLaneLayout, resultLaneLayout) ||
1125-
!is2DTranspose(sourceLaneData, resultLaneData))
1115+
ArrayRef<int64_t> perm = transposeOp.getPermutation();
1116+
if (!resultLayout.isTransposeOf(sourceLayout, perm))
11261117
return rewriter.notifyMatchFailure(
11271118
transposeOp,
1128-
"the source or result vector layouts must be transposes of each "
1119+
"the source or result vector layouts must be 2D transposes of each "
11291120
"other");
11301121
FailureOr<VectorType> distributedSourceTypeOrFailure =
11311122
getDistVecTypeBasedOnLaneLayout(sourceLayout,
@@ -1141,7 +1132,7 @@ struct VectorTransposeDistribution final : public gpu::WarpDistributionPattern {
11411132
rewriter.setInsertionPointAfter(newWarpOp);
11421133
auto newTransposeOp = vector::TransposeOp::create(
11431134
rewriter, newWarpOp.getLoc(), newWarpOp.getResult(newRetIndices[0]),
1144-
transposeOp.getPermutation());
1135+
perm);
11451136
Value distributedVal = newWarpOp.getResult(operandIdx);
11461137
rewriter.replaceAllUsesWith(distributedVal, newTransposeOp.getResult());
11471138
return success();

0 commit comments

Comments
 (0)