@@ -338,8 +338,6 @@ struct WgToSgVectorBroadcastOp
338338 ConversionPatternRewriter &rewriter) const override {
339339 VectorType resultType = op.getResult ().getType ();
340340 ArrayRef<int64_t > wgShape = resultType.getShape ();
341- if (!resultType)
342- return failure ();
343341
344342 xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
345343 if (!layout || !layout.getSgLayout ())
@@ -348,17 +346,17 @@ struct WgToSgVectorBroadcastOp
348346 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
349347 VectorType newResultType =
350348 VectorType::get (sgShape, resultType.getElementType ());
351- SmallVector<Value> newBroadcasts;
352349
350+ SmallVector<Value> newBroadcastOps;
353351 for (size_t i = 0 ; i < adaptor.getOperands ().front ().size (); ++i) {
354352 auto newBroadcast = rewriter.create <vector::BroadcastOp>(
355353 op.getLoc (), newResultType, adaptor.getOperands ().front ()[i]);
356354 xegpu::setLayoutAttr (newBroadcast->getResult (0 ),
357355 layout.dropSgLayoutAndData ());
358- newBroadcasts .push_back (newBroadcast.getResult ());
356+ newBroadcastOps .push_back (newBroadcast.getResult ());
359357 }
360358
361- rewriter.replaceOpWithMultiple (op, {newBroadcasts });
359+ rewriter.replaceOpWithMultiple (op, {newBroadcastOps });
362360 return success ();
363361 }
364362};
@@ -556,11 +554,7 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
556554
557555 target.addDynamicallyLegalOp <vector::BroadcastOp>(
558556 [=](vector::BroadcastOp op) -> bool {
559- auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
560- if (!resultType)
561- return true ;
562- xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
563- return isLegal (layout);
557+ return isLegal (xegpu::getLayoutAttr (op.getResult ()));
564558 });
565559
566560 target.addDynamicallyLegalOp <UnrealizedConversionCastOp>(
0 commit comments