Skip to content

Commit 692ae9e

Browse files
committed
add check
1 parent 2c97ee7 commit 692ae9e

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,21 @@ struct WgToSgVectorBroadcastOp
343343
if (!layout || !layout.getSgLayout())
344344
return failure();
345345

346+
// TODO: Currently only supports cases where the source and result ranks
347+
// are the same.
348+
auto srcType =
349+
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
350+
if (!srcType || srcType.getRank() != resultType.getRank())
351+
return failure();
352+
346353
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
347354
VectorType newResultType =
348355
VectorType::get(sgShape, resultType.getElementType());
349356

350357
SmallVector<Value> newBroadcastOps;
351-
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
358+
for (auto operand : adaptor.getOperands().front()) {
352359
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
353-
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
360+
op.getLoc(), newResultType, operand);
354361
xegpu::setLayoutAttr(newBroadcast->getResult(0),
355362
layout.dropSgLayoutAndData());
356363
newBroadcastOps.push_back(newBroadcast.getResult());

0 commit comments

Comments
 (0)