@@ -316,22 +316,21 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
316316};
317317
318318/// This pattern transforms vector.broadcast ops to work at subgroup level.
319- // / It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
320- struct WgToSgVectorBroadcastOp : public OpConversionPattern <vector::BroadcastOp> {
319+ struct WgToSgVectorBroadcastOp
320+ : public OpConversionPattern<vector::BroadcastOp> {
321321 using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
322322
323323 LogicalResult
324324 matchAndRewrite (vector::BroadcastOp op, OneToNOpAdaptor adaptor,
325325 ConversionPatternRewriter &rewriter) const override {
326326 auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
327327 if (!resultType)
328- return rewriter. notifyMatchFailure (op, " Result is not a vector type " );
328+ return failure ( );
329329
330330 // Only handle broadcasts to vectors with XeGPU layout attribute
331331 xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
332332 if (!layout || !layout.getSgLayout ())
333- return rewriter.notifyMatchFailure (
334- op, " Result does not have a valid layout attribute for subgroup distribution" );
333+ return failure ();
335334
336335 // Extract sgShape from layout
337336 SmallVector<int64_t > sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
347346 }
348347 }
349348
350- VectorType newResultType = VectorType::get (sgShape, resultType.getElementType ());
349+ VectorType newResultType =
350+ VectorType::get (sgShape, resultType.getElementType ());
351351 SmallVector<Value> newBroadcasts;
352352
353- // The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
354- for (Value unused : adaptor. getOperands (). front ()) {
355- // All subgroups get the same broadcasted value
353+ // The operand is always a scalar or lower-rank vector, so just broadcast
354+ // for each subgroup
355+ for ( size_t i = 0 ; i < adaptor. getOperands (). front (). size (); ++i) {
356356 auto newBroadcast = rewriter.create <vector::BroadcastOp>(
357- op.getLoc (), newResultType, adaptor.getOperands ().front ()[0 ]);
358- xegpu::setLayoutAttr (newBroadcast->getResult (0 ), layout.dropSgLayoutAndData ());
357+ op.getLoc (), newResultType, adaptor.getOperands ().front ()[i]);
358+ xegpu::setLayoutAttr (newBroadcast->getResult (0 ),
359+ layout.dropSgLayoutAndData ());
359360 newBroadcasts.push_back (newBroadcast.getResult ());
360361 }
361362
@@ -371,8 +372,7 @@ namespace xegpu {
/// Registers the listed WgToSg* conversion patterns on `patterns`, passing the
/// pattern set's own context to each of them.
371372void populateXeGPUWgToSgDistributePatterns (RewritePatternSet &patterns) {
372373 patterns.add <WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
373374 WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
374- WgToSgVectorBroadcastOp>(
375- patterns.getContext ());
375+ WgToSgVectorBroadcastOp>(patterns.getContext ());
376376}
377377} // namespace xegpu
378378} // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
420420 return isLegal (layout);
421421 });
422422
423- target.addDynamicallyLegalOp <vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
424- auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
425- if (!resultType)
426- return true ;
427- auto layout = xegpu::getLayoutAttr (op.getResult ());
428- return isLegal (layout);
429- });
423+ target.addDynamicallyLegalOp <vector::BroadcastOp>(
424+ [=](vector::BroadcastOp op) -> bool {
425+ auto resultType = dyn_cast<VectorType>(op.getResult ().getType ());
426+ if (!resultType)
427+ return true ; // Non-vector results are always treated as legal.
428+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
429+ return isLegal (layout); // Legality is decided by the result's layout.
430+ });
430431
431432 target.markUnknownOpDynamicallyLegal ([](Operation *) { return true ; });
432433
0 commit comments