Skip to content

Commit f1509d2

Browse files
committed
Add pattern for broadcast
1 parent 79108da commit f1509d2

File tree

3 files changed

+89
-2
lines changed

3 files changed

+89
-2
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/Utils/IndexingUtils.h"
1717
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1818
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
19+
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021

2122
namespace mlir {
@@ -314,13 +315,63 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
314315
}
315316
};
316317

318+
/// This pattern transforms vector.broadcast ops to work at subgroup level.
319+
/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
320+
struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
321+
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
322+
323+
LogicalResult
324+
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
325+
ConversionPatternRewriter &rewriter) const override {
326+
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
327+
if (!resultType)
328+
return rewriter.notifyMatchFailure(op, "Result is not a vector type");
329+
330+
// Only handle broadcasts to vectors with XeGPU layout attribute
331+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
332+
if (!layout || !layout.getSgLayout())
333+
return rewriter.notifyMatchFailure(
334+
op, "Result does not have a valid layout attribute for subgroup distribution");
335+
336+
// Extract sgShape from layout
337+
SmallVector<int64_t> sgShape;
338+
if (auto sgDataAttr = layout.getSgData()) {
339+
sgShape = llvm::to_vector_of<int64_t>(sgDataAttr.asArrayRef());
340+
} else {
341+
auto sgLayoutArr = layout.getSgLayout();
342+
ArrayRef<int64_t> shape = resultType.getShape();
343+
sgShape.reserve(shape.size());
344+
for (size_t i = 0; i < shape.size(); ++i) {
345+
assert(sgLayoutArr[i] != 0 && "sgLayout elements must be non-zero");
346+
sgShape.push_back(shape[i] / sgLayoutArr[i]);
347+
}
348+
}
349+
350+
VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
351+
SmallVector<Value> newBroadcasts;
352+
353+
// The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
354+
for (Value unused : adaptor.getOperands().front()) {
355+
// All subgroups get the same broadcasted value
356+
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
357+
op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
358+
xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
359+
newBroadcasts.push_back(newBroadcast.getResult());
360+
}
361+
362+
rewriter.replaceOpWithMultiple(op, {newBroadcasts});
363+
return success();
364+
}
365+
};
366+
317367
} // namespace
318368

319369
namespace mlir {
320370
namespace xegpu {
321371
// Registers all workgroup-to-subgroup distribution patterns, including the
// vector.broadcast pattern, on the given pattern set.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
  patterns
      .add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
           WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
           WgToSgVectorBroadcastOp>(patterns.getContext());
}
326377
} // namespace xegpu
@@ -369,6 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
369420
return isLegal(layout);
370421
});
371422

423+
// vector.broadcast needs rewriting only while its vector result still
// carries a layout with subgroup distribution info; non-vector results are
// always legal, everything else defers to the shared isLegal(layout) check.
target.addDynamicallyLegalOp<vector::BroadcastOp>(
    [=](vector::BroadcastOp op) -> bool {
      auto vecTy = dyn_cast<VectorType>(op.getResult().getType());
      if (!vecTy)
        return true;
      return isLegal(xegpu::getLayoutAttr(op.getResult()));
    });
430+
372431
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
373432

374433
xegpu::populateXeGPUWgToSgDistributePatterns(patterns);

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,18 @@ gpu.module @test_round_robin_assignment {
102102
: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
103103
gpu.return
104104
}
105+
106+
// Round-robin case: 24x1 source split across sg_layout [4, 1] with
// sg_data [2, 1]; the broadcast result uses sg_data [2, 4].
// CHECK-LABEL: test_broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @test_broadcast(%src: memref<24x1xf32>) {
  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
    -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
  %load = xegpu.load_nd %tdesc
    : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
    -> vector<24x1xf32>
  %broadcast = vector.broadcast %load
    {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
    : vector<24x1xf32> to vector<24x8xf32>
  gpu.return
}
105119
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,18 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
169169
: vector<24x32xf32>, vector<32x24xf32> -> vector<24x24xf32>
170170
gpu.return
171171
}
172-
}
172+
173+
// Even-split case: 24x1 source split across sg_layout [2, 1] with
// sg_data [12, 1]; the broadcast result uses sg_data [12, 8].
// CHECK-LABEL: test_broadcast
// CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
gpu.func @test_broadcast(%src: memref<24x1xf32>) {
  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
    -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
  %load = xegpu.load_nd %tdesc
    : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
    -> vector<24x1xf32>
  %broadcast = vector.broadcast %load
    {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
    : vector<24x1xf32> to vector<24x8xf32>
  gpu.return
}
186+
}

0 commit comments

Comments
 (0)