@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    // TODO: Currently only supports cases where the source and result ranks
+    // are the same.
+    auto srcType =
+        dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
+    if (!srcType || srcType.getRank() != resultType.getRank())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    // Check if the output layout is distributable.
+    SmallVector<int64_t> sgLayout;
+    if (auto sgLayoutAttr = layout.getSgLayout())
+      sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
+    else
+      return failure();
+
+    if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
+      return failure();
+
+    // Check that the source shape has unit dims in the dimensions being
+    // broadcast and matches the subgroup result shape in all other dimensions.
+    // TODO: Generalize this check.
+    auto srcShape = srcType.getShape();
+    for (size_t i = 0; i < srcShape.size(); ++i) {
+      if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
+        return failure();
+    }
+
+    SmallVector<Value> newBroadcastOps;
+    for (auto operand : adaptor.getOperands().front()) {
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, operand);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newBroadcastOps.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+    return success();
+  }
+};
+
 // This pattern transforms elementwise ops to work at subgroup level.
 struct WgToSgElementwiseOp : public ConversionPattern {
   WgToSgElementwiseOp(MLIRContext *ctx)
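For context, a minimal before/after IR sketch of what this pattern is intended to do. The shapes, layout values, and the layout_result_0 attribute name below are illustrative assumptions, not taken verbatim from this patch:

  // Workgroup-level input (assumed): the broadcast result still carries a
  // subgroup layout, so the op is illegal and the pattern above fires.
  %b = vector.broadcast %src
         {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [4, 32]>}
         : vector<32x1xf32> to vector<32x32xf32>

  // Subgroup-level output (assumed): %src has already been distributed to a
  // vector<4x1xf32> per subgroup, each subgroup broadcasts to its [4, 32] tile,
  // and sg_layout/sg_data are dropped from the result layout.
  %b_sg = vector.broadcast %src_sg : vector<4x1xf32> to vector<4x32xf32>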
@@ -475,8 +534,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
-      patterns.getContext());
+               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
+               WgToSgVectorBroadcastOp>(patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -583,6 +642,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
         return isLegal(layout);
       });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
       [=](Operation *op) -> std::optional<bool> {
         // Only handle elementwise mappable ops
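As a rough sketch of what this legality rule means in IR (again with assumed shapes, layouts, and attribute name): a vector.broadcast whose result layout still carries sg_layout/sg_data stays illegal and keeps getting converted, while one that carries only lane-level fields, or no layout at all, is treated as legal:

  // Still illegal: the result layout carries subgroup-level fields.
  %0 = vector.broadcast %a {layout_result_0 = #xegpu.layout<
         sg_layout = [8, 1], sg_data = [4, 32], lane_layout = [1, 16], lane_data = [1, 1]>}
         : vector<32x1xf32> to vector<32x32xf32>

  // Legal after distribution: only lane-level fields remain.
  %1 = vector.broadcast %b {layout_result_0 = #xegpu.layout<
         lane_layout = [1, 16], lane_data = [1, 1]>}
         : vector<4x1xf32> to vector<4x32xf32>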