Skip to content
49 changes: 47 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,46 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

/// This pattern transforms vector.broadcast ops to work at subgroup level.
/// The workgroup-level op is replaced by one vector.broadcast per converted
/// source value. Each new op produces a vector whose shape is the first
/// element returned by getSgShapeAndCount and is tagged with the layout
/// returned by layout.dropSgLayoutAndData().
/// Matching fails when the result carries no layout or no sg_layout, or when
/// the converted source is not a vector of the same rank as the result.
struct WgToSgVectorBroadcastOp
: public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// The result type of the original op provides the workgroup-level shape.
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();

// Without a layout that has an sg_layout there is nothing to distribute.
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me that the current implementation is assuming the rank of source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe it is partially because of the limitation of LayoutAttr. It would be better to add a check.

// TODO: Currently only supports cases where the source and result ranks
// are the same.
// Only the first converted value of the first operand is inspected; the
// remaining values are presumably of the same type (confirm with caller).
auto srcType =
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
Copy link
Contributor

@chencha3 chencha3 Jul 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can adaptor.getSource() be used here and later instead of using adaptor.getOperands().front() ?

// Reject non-vector sources and rank-changing broadcasts.
if (!srcType || srcType.getRank() != resultType.getRank())
return failure();

// Build the per-subgroup result type: same element type, subgroup shape.
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

// Create one subgroup-level broadcast for every converted source value.
SmallVector<Value> newBroadcastOps;
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, operand);
// Tag the new result with the layout minus its subgroup-level fields.
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}

// The single original result is replaced by the whole list of new results.
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
}
};

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
Expand Down Expand Up @@ -473,8 +513,8 @@ namespace xegpu {
/// Adds the workgroup-to-subgroup distribution patterns to `patterns`, each
/// constructed with the context obtained from patterns.getContext().
// NOTE(review): this listing appears to show both sides of a diff hunk. The
// two-line tail ending in `WgToSgElementwiseOp>(` looks like the pre-change
// text, and the tail that appends `WgToSgVectorBroadcastOp` looks like its
// replacement. Confirm against the actual file.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
patterns.getContext());
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
Expand Down Expand Up @@ -581,6 +621,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

// A vector.broadcast is legal exactly when isLegal accepts the layout
// attached to its result.
target.addDynamicallyLegalOp<vector::BroadcastOp>(
    [=](vector::BroadcastOp broadcastOp) -> bool {
      auto resultLayout = xegpu::getLayoutAttr(broadcastOp.getResult());
      return isLegal(resultLayout);
    });

target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
Expand Down
19 changes: 18 additions & 1 deletion mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}

// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
// The source is split into 24 / (sg_layout 4 x sg_data 2) = 3 slices per
// subgroup, so exactly three subgroup-level broadcasts are expected.
// The layout attribute and the operand/result types are part of the counted
// pattern itself, so each of the three matches is verified in full.
// CHECK-COUNT-3: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>} : vector<2x1xf32> to vector<2x4xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
gpu.return
}

gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
Expand Down Expand Up @@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}

}
19 changes: 17 additions & 2 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,22 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}

// Broadcast of a 24x1 vector to 24x8 with sg_layout = [2, 1]; the checks look
// for a subgroup-level broadcast from 12x1 to 12x8 that keeps only the lane
// fields of the layout.
// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast(%src: memref<24x1xf32>) {
%desc = xegpu.create_nd_tdesc %src[0, 0]
: memref<24x1xf32> -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%column = xegpu.load_nd %desc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>> -> vector<24x1xf32>
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
%wide = vector.broadcast %column
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>} : vector<24x1xf32> to vector<24x8xf32>
gpu.return
}

gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
Expand Down Expand Up @@ -295,6 +311,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}


}