@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
331
331
}
332
332
};
333
333
334
+ // / This pattern transforms vector.broadcast ops to work at subgroup level.
335
+ struct WgToSgVectorBroadcastOp
336
+ : public OpConversionPattern<vector::BroadcastOp> {
337
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
338
+
339
+ LogicalResult
340
+ matchAndRewrite (vector::BroadcastOp op, OneToNOpAdaptor adaptor,
341
+ ConversionPatternRewriter &rewriter) const override {
342
+ VectorType resultType = op.getResult ().getType ();
343
+ ArrayRef<int64_t > wgShape = resultType.getShape ();
344
+
345
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
346
+ if (!layout || !layout.getSgLayout ())
347
+ return failure ();
348
+
349
+ // TODO: Currently only supports cases where the source and result ranks
350
+ // are the same.
351
+ auto srcType =
352
+ dyn_cast<VectorType>(adaptor.getOperands ().front ()[0 ].getType ());
353
+ if (!srcType || srcType.getRank () != resultType.getRank ())
354
+ return failure ();
355
+
356
+ SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
357
+ VectorType newResultType =
358
+ VectorType::get (sgShape, resultType.getElementType ());
359
+
360
+ // Check if the output layout is distributable
361
+ SmallVector<int64_t > sgLayout;
362
+ if (auto sgLayoutAttr = layout.getSgLayout ())
363
+ sgLayout = llvm::to_vector_of<int64_t >(sgLayoutAttr.asArrayRef ());
364
+ else
365
+ return failure ();
366
+
367
+ if (!xegpu::XeGPUDialect::isEvenlyDistributable (wgShape, layout))
368
+ return failure ();
369
+
370
+ // Check if the srcShape has unit dim in dimensions being broadcasted,
371
+ // and the other dimensions are the same as the destination type
372
+ // TODO: Generalize it
373
+ auto srcShape = srcType.getShape ();
374
+ for (size_t i = 0 ; i < srcShape.size (); ++i) {
375
+ if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
376
+ return failure ();
377
+ }
378
+
379
+ SmallVector<Value> newBroadcastOps;
380
+ for (auto operand : adaptor.getOperands ().front ()) {
381
+ auto newBroadcast = rewriter.create <vector::BroadcastOp>(
382
+ op.getLoc (), newResultType, operand);
383
+ xegpu::setLayoutAttr (newBroadcast->getResult (0 ),
384
+ layout.dropSgLayoutAndData ());
385
+ newBroadcastOps.push_back (newBroadcast.getResult ());
386
+ }
387
+
388
+ rewriter.replaceOpWithMultiple (op, {newBroadcastOps});
389
+ return success ();
390
+ }
391
+ };
392
+
334
393
// This pattern transforms elementwise ops to work at subgroup level.
335
394
struct WgToSgElementwiseOp : public ConversionPattern {
336
395
WgToSgElementwiseOp (MLIRContext *ctx)
@@ -475,8 +534,8 @@ namespace xegpu {
475
534
/// Registers every workgroup-to-subgroup distribution pattern — the nd-op
/// patterns, the unrealized-cast cleanup, elementwise, and vector.broadcast —
/// into \p patterns. Insertion order matches the original single add<> call.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // Nd tensor-descriptor ops.
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp>(ctx);
  // Cast cleanup plus vector-dialect ops.
  patterns.add<UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
               WgToSgVectorBroadcastOp>(ctx);
}
481
540
} // namespace xegpu
482
541
} // namespace mlir
@@ -583,6 +642,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
583
642
return isLegal (layout);
584
643
});
585
644
645
  // vector.broadcast is legal once its result layout passes isLegal — i.e.
  // it no longer carries a layout that this pass still needs to distribute.
  target.addDynamicallyLegalOp<vector::BroadcastOp>(
      [=](vector::BroadcastOp op) -> bool {
        return isLegal(xegpu::getLayoutAttr(op.getResult()));
      });
649
+
586
650
target.addDynamicallyLegalDialect <math::MathDialect, arith::ArithDialect>(
587
651
[=](Operation *op) -> std::optional<bool > {
588
652
// Only handle elementwise mappable ops
0 commit comments