Skip to content

Commit 56b263b

Browse files
authored
[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass (#144417)
This PR adds transformation pattern for vector.broadcast op in xegpu-wg-to-sg-distribute pass
1 parent a7867fc commit 56b263b

File tree

3 files changed

+117
-5
lines changed

3 files changed

+117
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,65 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
331331
}
332332
};
333333

334+
/// This pattern transforms vector.broadcast ops to work at subgroup level.
335+
struct WgToSgVectorBroadcastOp
336+
: public OpConversionPattern<vector::BroadcastOp> {
337+
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
338+
339+
LogicalResult
340+
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
341+
ConversionPatternRewriter &rewriter) const override {
342+
VectorType resultType = op.getResult().getType();
343+
ArrayRef<int64_t> wgShape = resultType.getShape();
344+
345+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
346+
if (!layout || !layout.getSgLayout())
347+
return failure();
348+
349+
// TODO: Currently only supports cases where the source and result ranks
350+
// are the same.
351+
auto srcType =
352+
dyn_cast<VectorType>(adaptor.getOperands().front()[0].getType());
353+
if (!srcType || srcType.getRank() != resultType.getRank())
354+
return failure();
355+
356+
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
357+
VectorType newResultType =
358+
VectorType::get(sgShape, resultType.getElementType());
359+
360+
// Check if the output layout is distributable
361+
SmallVector<int64_t> sgLayout;
362+
if (auto sgLayoutAttr = layout.getSgLayout())
363+
sgLayout = llvm::to_vector_of<int64_t>(sgLayoutAttr.asArrayRef());
364+
else
365+
return failure();
366+
367+
if (!xegpu::XeGPUDialect::isEvenlyDistributable(wgShape, layout))
368+
return failure();
369+
370+
// Check if the srcShape has unit dim in dimensions being broadcasted,
371+
// and the other dimensions are the same as the destination type
372+
// TODO: Generalize it
373+
auto srcShape = srcType.getShape();
374+
for (size_t i = 0; i < srcShape.size(); ++i) {
375+
if (srcShape[i] != 1 && srcShape[i] != sgShape[i])
376+
return failure();
377+
}
378+
379+
SmallVector<Value> newBroadcastOps;
380+
for (auto operand : adaptor.getOperands().front()) {
381+
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
382+
op.getLoc(), newResultType, operand);
383+
xegpu::setLayoutAttr(newBroadcast->getResult(0),
384+
layout.dropSgLayoutAndData());
385+
newBroadcastOps.push_back(newBroadcast.getResult());
386+
}
387+
388+
rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
389+
return success();
390+
}
391+
};
392+
334393
// This pattern transforms elementwise ops to work at subgroup level.
335394
struct WgToSgElementwiseOp : public ConversionPattern {
336395
WgToSgElementwiseOp(MLIRContext *ctx)
@@ -475,8 +534,8 @@ namespace xegpu {
475534
// Registers all workgroup-to-subgroup distribution patterns, including the
// vector.broadcast pattern added by this change.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
               WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
               UnrealizedConversionCastOpPattern, WgToSgElementwiseOp,
               WgToSgVectorBroadcastOp>(patterns.getContext());
}
481540
} // namespace xegpu
482541
} // namespace mlir
@@ -583,6 +642,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
583642
return isLegal(layout);
584643
});
585644

645+
target.addDynamicallyLegalOp<vector::BroadcastOp>(
646+
[=](vector::BroadcastOp op) -> bool {
647+
return isLegal(xegpu::getLayoutAttr(op.getResult()));
648+
});
649+
586650
target.addDynamicallyLegalDialect<math::MathDialect, arith::ArithDialect>(
587651
[=](Operation *op) -> std::optional<bool> {
588652
// Only handle elementwise mappable ops

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
103103
gpu.return
104104
}
105105

106+
// CHECK-LABEL: broadcast
107+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
108+
gpu.func @broadcast(%src: memref<24x1xf32>) {
109+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
110+
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
111+
%load = xegpu.load_nd %tdesc
112+
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
113+
-> vector<24x1xf32>
114+
// CHECK-COUNT-3: vector.broadcast {{.*}}
115+
// CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
116+
// CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
117+
// CHECK-NOT: vector.broadcast
118+
%broadcast = vector.broadcast %load
119+
{layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
120+
: vector<24x1xf32> to vector<24x8xf32>
121+
gpu.return
122+
}
123+
106124
gpu.func @scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
107125
%c1 = arith.constant 1 : index
108126
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
197215
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
198216
gpu.return
199217
}
200-
201218
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,38 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
170170
gpu.return
171171
}
172172

173+
// CHECK-LABEL: broadcast_dim1
174+
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
175+
gpu.func @broadcast_dim1(%src: memref<24x1xf32>) {
176+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
177+
-> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
178+
%load = xegpu.load_nd %tdesc
179+
: !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
180+
-> vector<24x1xf32>
181+
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
182+
// CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
183+
%broadcast = vector.broadcast %load
184+
{layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
185+
: vector<24x1xf32> to vector<24x8xf32>
186+
gpu.return
187+
}
188+
189+
// CHECK-LABEL: broadcast_dim0
190+
// CHECK-SAME: %[[ARG_0:.*]]: memref<1x32xf32>
191+
gpu.func @broadcast_dim0(%src: memref<1x32xf32>) {
192+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<1x32xf32>
193+
-> !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
194+
%load = xegpu.load_nd %tdesc
195+
: !xegpu.tensor_desc<1x32xf32, #xegpu.layout<sg_layout = [1, 4], sg_data = [1, 8], lane_layout = [1, 8], lane_data = [1, 1]>>
196+
-> vector<1x32xf32>
197+
// CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [1, 8], lane_data = [1, 1]>}
198+
// CHECK-SAME: : vector<1x8xf32> to vector<12x8xf32>
199+
%broadcast = vector.broadcast %load
200+
{layout_result_0 = #xegpu.layout<sg_layout = [1, 4], sg_data = [12, 8], lane_layout = [1, 8], lane_data = [1, 1]>}
201+
: vector<1x32xf32> to vector<12x32xf32>
202+
gpu.return
203+
}
204+
173205
gpu.func @scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
174206
//CHECK: [[c0:%.+]] = arith.constant 0 : index
175207
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +327,5 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
295327
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
296328
gpu.return
297329
}
298-
299-
300330
}
331+

0 commit comments

Comments
 (0)