File tree Expand file tree Collapse file tree 1 file changed +9
-2
lines changed
mlir/lib/Dialect/XeGPU/Transforms Expand file tree Collapse file tree 1 file changed +9
-2
lines changed Original file line number Diff line number Diff line change @@ -343,14 +343,21 @@ struct WgToSgVectorBroadcastOp
343343 if (!layout || !layout.getSgLayout ())
344344 return failure ();
345345
346+ // TODO: Currently only supports cases where the source and result ranks
347+ // are the same.
348+ auto srcType =
349+ dyn_cast<VectorType>(adaptor.getOperands ().front ()[0 ].getType ());
350+ if (!srcType || srcType.getRank () != resultType.getRank ())
351+ return failure ();
352+
346353 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
347354 VectorType newResultType =
348355 VectorType::get (sgShape, resultType.getElementType ());
349356
350357 SmallVector<Value> newBroadcastOps;
351- for (size_t i = 0 ; i < adaptor.getOperands ().front (). size (); ++i ) {
358+ for (auto operand : adaptor.getOperands ().front ()) {
352359 auto newBroadcast = rewriter.create <vector::BroadcastOp>(
353- op.getLoc (), newResultType, adaptor. getOperands (). front ()[i] );
360+ op.getLoc (), newResultType, operand );
354361 xegpu::setLayoutAttr (newBroadcast->getResult (0 ),
355362 layout.dropSgLayoutAndData ());
356363 newBroadcastOps.push_back (newBroadcast.getResult ());
You can’t perform that action at this time.
0 commit comments