Skip to content

Commit 2998c74

Browse files
authored
[mlir][xegpu] Add SIMT distribution support for GEMM transpose B case. (#155517)
This PR adds the features needed for supporting the GEMM with transpose B case. Summary of changes. 1). Add distribution logic for `vector.bitcast`, `vector.transpose` and `memref.extract_aligned_pointer_as_index` cases. 2). Add layout propagation support for `vector.shape_cast`, `vector.broadcast` and `vector.bitcast` 3). Incorporate slice attribute and `DistributeLayoutAttr` interface with the core logic in layout prop.
1 parent bedfee0 commit 2998c74

File tree

5 files changed

+717
-159
lines changed

5 files changed

+717
-159
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,54 @@ def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
235235
"FailureOr<SmallVector<SmallVector<Value>>>",
236236
"getOffsets",
237237
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>,
238+
InterfaceMethod</*desc=*/[{Check if this layout can be achieved by applying a transpose
239+
to some other layout according to given permutation of (0...n-1).}],
240+
/*retTy=*/"bool",
241+
/*methodName=*/"isTransposeOf",
242+
/*args=*/(ins "const xegpu::DistributeLayoutAttr&": $other, "ArrayRef<int64_t>": $perm),
243+
/*methodBody=*/[{
244+
if (!other)
245+
return false;
246+
if ($_self.getRank() != other.getRank() || perm.size() != static_cast<size_t>($_self.getRank()))
247+
return false;
248+
// Check if the permutation is valid
249+
if (!isPermutationVector(perm))
250+
return false;
251+
auto checkTranspose = [](ArrayRef<int64_t> dst, ArrayRef<int64_t> src, ArrayRef<int64_t> perm) {
252+
// If both `dst` and `src` are empty, conservatively return true
253+
// here because some layout fields can be empty.
254+
if (dst.empty() && src.empty())
255+
return true;
256+
for (const auto &ta : llvm::enumerate(perm)) {
257+
if (src[ta.index()] != dst[ta.value()])
258+
return false;
259+
}
260+
return true;
261+
};
262+
// Check sgLayout
263+
if (!checkTranspose($_self.getEffectiveSgLayoutAsInt(), other.getEffectiveSgLayoutAsInt(), perm))
264+
return false;
265+
// Check sgData
266+
if (!checkTranspose($_self.getEffectiveSgDataAsInt(), other.getEffectiveSgDataAsInt(), perm))
267+
return false;
268+
// Check instData
269+
if (!checkTranspose($_self.getEffectiveInstDataAsInt(), other.getEffectiveInstDataAsInt(), perm))
270+
return false;
271+
// Check laneLayout
272+
if (!checkTranspose($_self.getEffectiveLaneLayoutAsInt(), other.getEffectiveLaneLayoutAsInt(), perm))
273+
return false;
274+
// Check laneData
275+
if (!checkTranspose($_self.getEffectiveLaneDataAsInt(), other.getEffectiveLaneDataAsInt(), perm))
276+
return false;
277+
// Check order if both sides have order field.
278+
if ($_self.getOrder() && other.getOrder()) {
279+
auto thisOrderAsInt = llvm::to_vector_of<int64_t>($_self.getOrder().asArrayRef());
280+
auto otherOrderAsInt = llvm::to_vector_of<int64_t>(other.getOrder().asArrayRef());
281+
if (!checkTranspose(thisOrderAsInt, otherOrderAsInt, perm))
282+
return false;
283+
}
284+
return true;
285+
}]>,
238286
InterfaceMethod</*desc=*/[{Check if this layout is a slice of some other layout.}],
239287
/*retTy=*/"bool",
240288
/*methodName=*/"isSliceOf",

0 commit comments

Comments
 (0)