68 changes: 66 additions & 2 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};

/// This pattern transforms vector.broadcast ops to work at subgroup level.
struct WgToSgVectorBroadcastOp
: public OpConversionPattern<vector::BroadcastOp> {
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType resultType = op.getResult().getType();
ArrayRef<int64_t> wgShape = resultType.getShape();

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();

Review comment (Contributor): It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe it is partially because of the limitation of LayoutAttr. It would be better to add a check.
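For illustration only (not part of the patch): a rank-expanding broadcast of the kind this comment refers to, which the rank check below deliberately rejects so the op is left untouched at workgroup level:

    // Source rank 1, result rank 2 -- valid for vector.broadcast, but skipped by this pattern.
    %b = vector.broadcast %v : vector<32xf32> to vector<8x32xf32>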

// TODO: Currently only supports cases where the source and result ranks
// are the same.
auto srcType =
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
Review comment (@chencha3, Jul 21, 2025): can adaptor.getSource() be used here and later instead of using adaptor.getOperands().front()?

if (!srcType || srcType.getRank() != resultType.getRank())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
VectorType newResultType =
VectorType::get(sgShape, resultType.getElementType());

// Check if the output layout is distributable
SmallVector<int64_t> sgLayout;
if (auto sgLayoutAttr = layout.getSgLayout())
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
else
return failure();

if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
return failure();

// Check if the srcShape has unit dim in dimensions being broadcasted,
// and the other dimensions are the same as the destination type
// TODO: Generalize it
auto srcShape = srcType.getShape();
for (size_t i = 0; i < srcShape.size(); ++i) {
Review comment (Contributor): It seems this check duplicates the check in the broadcast verifier, unless there are cases where the source vector, e.g., vector<32x1x1xf32>, can be distributed to a vector, e.g., <8x2x1>.
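One hypothetical configuration (assumed layouts, not taken from this patch or its tests) where the per-subgroup check would not be subsumed by the op verifier, which only sees the workgroup-level shapes:

    // Workgroup-level broadcast vector<24x1xf32> -> vector<24x8xf32> verifies fine, but:
    //   source per-subgroup shape:  2x1   (source layout with sg_data = [2, 1])
    //   result per-subgroup shape: 12x8   (result layout with sg_data = [12, 8])
    // srcShape[0] = 2 is neither 1 nor equal to sgShape[0] = 12, so the loop above
    // returns failure() and the op is not distributed.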

if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
return failure();
}

SmallVector<Value> newBroadcastOps;
for (auto operand : adaptor.getOperands().front()) {
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
op.getLoc(), newResultType, operand);
xegpu::setLayoutAttr(newBroadcast->getResult(0),
layout.dropSgLayoutAndData());
newBroadcastOps.push_back(newBroadcast.getResult());
}

rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
return success();
}
};
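As a reading aid (shapes and layouts borrowed from the broadcast_dim1 test further below, not a new case), the intended effect of the pattern is roughly:

    // Workgroup level, before the pass:
    %b = vector.broadcast %load
        {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
        : vector<24x1xf32> to vector<24x8xf32>

    // Subgroup level, after the pass: one broadcast per distribution unit,
    // with sg_layout/sg_data dropped from the result layout:
    %b_sg = vector.broadcast %load_sg
        {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
        : vector<12x1xf32> to vector<12x8xf32>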

// This pattern transforms elementwise ops to work at subgroup level.
struct WgToSgElementwiseOp : public ConversionPattern {
WgToSgElementwiseOp(MLIRContext *ctx)
@@ -473,8 +532,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp>(
patterns.getContext());
UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
WgToSgVectorBroadcastOp>(patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -581,6 +640,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::BroadcastOp>(
[=](vector::BroadcastOp op) -> bool {
return isLegal(xegpu::getLayoutAttr(op.getResult()));
});
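With this registration, a vector.broadcast whose result layout still carries an sg_layout is treated as workgroup-level and must be rewritten, while broadcasts without an sg_layout are already legal and left alone. The transformation is driven the same way as the updated tests below, with a RUN line along these lines (flag name assumed from the existing XeGPU test files, which are not shown in this diff):

    // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s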

target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
[=](Operation *op) -> std::optional<bool> {
// Only handle elementwise mappable ops
19 changes: 18 additions & 1 deletion mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}

// CHECK-LABEL: broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
// CHECK-COUNT-3: vector.broadcast {{.*}}
// CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
// CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
// CHECK-NOT: vector.broadcast
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
gpu.return
}
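Editorial note on the CHECK-COUNT-3 above, derived from the shapes already in the test: the source vector<24x1xf32> carries sg_layout = [4, 1] and sg_data = [2, 1], so round-robin distribution hands each subgroup

    24 / (4 * 2) = 3

pieces of shape 2x1, and the pattern emits one subgroup-level vector.broadcast per piece.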

gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}

}
35 changes: 33 additions & 2 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,38 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}

// CHECK-LABEL: broadcast_dim1
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
-> vector<24x1xf32>
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
// CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
: vector<24x1xf32> to vector<24x8xf32>
gpu.return
}
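Restating the shapes already in the test above: the source vector<24x1xf32> with sg_layout = [2, 1] and sg_data = [12, 1] yields exactly one 12x1 piece per subgroup (24 / (2 * 12) = 1 in dim 0), and the result layout's sg_data = [12, 8] gives the 12x8 subgroup-level result that the CHECK lines expect.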

// CHECK-LABEL: broadcast_dim0
// CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
-> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
-> vector<1x32xf32>
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
// CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
%broadcast = vector.broadcast %load
{layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
: vector<1x32xf32> to vector<12x32xf32>
gpu.return
}

gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +327,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}


}
