We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 1d17537 commit 425d677Copy full SHA for 425d677
mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -357,6 +357,15 @@ struct WgToSgVectorBroadcastOp
357
VectorType newResultType =
358
VectorType::get(sgShape, resultType.getElementType());
359
360
+ // Check if the srcShape has unit dim in dimensions being broadcasted,
361
+ // and the other dimensions are the same as the destination type
362
+ // TODO: Generalize it
363
+ auto srcShape = srcType.getShape();
364
+ for (size_t i = 0; i < srcShape.size(); ++i) {
365
+ if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
366
+ return failure();
367
+ }
368
+
369
SmallVector<Value> newBroadcastOps;
370
for (auto operand : adaptor.getOperands().front()) {
371
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
0 commit comments