Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 10 additions & 20 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1239,37 +1239,27 @@ struct WgToSgVectorTransposeOp
if (!layout || !layout.isForWorkgroup())
return failure();

xegpu::DistributeLayoutAttr sourceLayout =
xegpu::getDistributeLayoutAttr(op.getVector());
if (!sourceLayout || !sourceLayout.isForWorkgroup())
return failure();

SmallVector<int64_t> sourceSgLayout =
sourceLayout.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
DenseI32ArrayAttr resultOrder = layout.getOrder();

if (!sourceOrder || !resultOrder) {
return rewriter.notifyMatchFailure(
op, "Both source and result must have order attributes");
bool is1DTranspose = (llvm::count_if(resultSgLayout, [](int64_t dim) {
return dim > 1;
}) <= 1);

if (!is1DTranspose) {
if (!resultOrder) {
return rewriter.notifyMatchFailure(
op, "Multi-dimensional Transposes must have order attributes");
}
}

ArrayRef<int64_t> permutation = op.getPermutation();
size_t permutationSize = permutation.size();
if (sourceSgLayout.size() != permutationSize ||
resultSgLayout.size() != permutationSize) {
if (resultSgLayout.size() != permutationSize) {
return rewriter.notifyMatchFailure(
op, "Layouts and permutation must have the same rank");
}

// Check that sgLayout, sgData & order are properly transposed for source
// and result
if (!layout.isTransposeOf(sourceLayout, permutation))
return rewriter.notifyMatchFailure(
op, "Result layout is not a valid transpose of source layout "
"according to permutation");

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());
Expand Down