diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index be82cda574f1e..d430cf7eb3dfc 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1239,37 +1239,27 @@ struct WgToSgVectorTransposeOp if (!layout || !layout.isForWorkgroup()) return failure(); - xegpu::DistributeLayoutAttr sourceLayout = - xegpu::getDistributeLayoutAttr(op.getVector()); - if (!sourceLayout || !sourceLayout.isForWorkgroup()) - return failure(); - - SmallVector sourceSgLayout = - sourceLayout.getEffectiveSgLayoutAsInt(); SmallVector resultSgLayout = layout.getEffectiveSgLayoutAsInt(); - DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder(); DenseI32ArrayAttr resultOrder = layout.getOrder(); - if (!sourceOrder || !resultOrder) { - return rewriter.notifyMatchFailure( - op, "Both source and result must have order attributes"); + bool is1DTranspose = (llvm::count_if(resultSgLayout, [](int64_t dim) { + return dim > 1; + }) <= 1); + + if (!is1DTranspose) { + if (!resultOrder) { + return rewriter.notifyMatchFailure( + op, "Multi-dimensional Transposes must have order attributes"); + } } ArrayRef permutation = op.getPermutation(); size_t permutationSize = permutation.size(); - if (sourceSgLayout.size() != permutationSize || - resultSgLayout.size() != permutationSize) { + if (resultSgLayout.size() != permutationSize) { return rewriter.notifyMatchFailure( op, "Layouts and permutation must have the same rank"); } - // Check that sgLayout, sgData & order are properly transposed for source - // and result - if (!layout.isTransposeOf(sourceLayout, permutation)) - return rewriter.notifyMatchFailure( - op, "Result layout is not a valid transpose of source layout " - "according to permutation"); - SmallVector sgShape = getSgShapeAndCount(wgShape, layout).first; VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());