Skip to content

Commit 425d677

Browse files
committed
add check
1 parent 1d17537 commit 425d677

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,15 @@ struct WgToSgVectorBroadcastOp
357357
VectorType newResultType =
358358
VectorType::get(sgShape, resultType.getElementType());
359359

360+
// Check if the srcShape has unit dim in dimensions being broadcasted,
361+
// and the other dimensions are the same as the destination type
362+
// TODO: Generalize it
363+
auto srcShape = srcType.getShape();
364+
for (size_t i = 0; i < srcShape.size(); ++i) {
365+
if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
366+
return failure();
367+
}
368+
360369
SmallVector<Value> newBroadcastOps;
361370
for (auto operand : adaptor.getOperands().front()) {
362371
auto newBroadcast = rewriter.create<vector::BroadcastOp>(

0 commit comments

Comments
 (0)