Skip to content

Commit 81fd298

Browse files
committed
Relax checks in vector.transpose
1 parent e3905a4 commit 81fd298

File tree

1 file changed

+10
-20
lines changed

1 file changed

+10
-20
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1239,37 +1239,27 @@ struct WgToSgVectorTransposeOp
12391239
if (!layout || !layout.isForWorkgroup())
12401240
return failure();
12411241

1242-
xegpu::DistributeLayoutAttr sourceLayout =
1243-
xegpu::getDistributeLayoutAttr(op.getVector());
1244-
if (!sourceLayout || !sourceLayout.isForWorkgroup())
1245-
return failure();
1246-
1247-
SmallVector<int64_t> sourceSgLayout =
1248-
sourceLayout.getEffectiveSgLayoutAsInt();
12491242
SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1250-
DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
12511243
DenseI32ArrayAttr resultOrder = layout.getOrder();
12521244

1253-
if (!sourceOrder || !resultOrder) {
1254-
return rewriter.notifyMatchFailure(
1255-
op, "Both source and result must have order attributes");
1245+
bool is1DTranspose = (llvm::count_if(resultSgLayout, [](int64_t dim) {
1246+
return dim > 1;
1247+
}) <= 1);
1248+
1249+
if (!is1DTranspose) {
1250+
if (!resultOrder) {
1251+
return rewriter.notifyMatchFailure(
1252+
op, "Multi-dimensional Transposes must have order attributes");
1253+
}
12561254
}
12571255

12581256
ArrayRef<int64_t> permutation = op.getPermutation();
12591257
size_t permutationSize = permutation.size();
1260-
if (sourceSgLayout.size() != permutationSize ||
1261-
resultSgLayout.size() != permutationSize) {
1258+
if (resultSgLayout.size() != permutationSize) {
12621259
return rewriter.notifyMatchFailure(
12631260
op, "Layouts and permutation must have the same rank");
12641261
}
12651262

1266-
// Check that sgLayout, sgData & order are properly transposed for source
1267-
// and result
1268-
if (!layout.isTransposeOf(sourceLayout, permutation))
1269-
return rewriter.notifyMatchFailure(
1270-
op, "Result layout is not a valid transpose of source layout "
1271-
"according to permutation");
1272-
12731263
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
12741264
VectorType newResultType =
12751265
VectorType::get(sgShape, resultType.getElementType());

0 commit comments

Comments
 (0)