Skip to content

Commit c5cd274

Browse files
committed
Add pattern for broadcast
1 parent f1509d2 commit c5cd274

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -316,22 +316,21 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
316316
};
317317

318318
/// This pattern transforms vector.broadcast ops to work at subgroup level.
319-
/// It splits the broadcast to match the subgroup shape and drops sgLayout/sgData.
320-
struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp> {
319+
struct WgToSgVectorBroadcastOp
320+
: public OpConversionPattern<vector::BroadcastOp> {
321321
using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
322322

323323
LogicalResult
324324
matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
325325
ConversionPatternRewriter &rewriter) const override {
326326
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
327327
if (!resultType)
328-
return rewriter.notifyMatchFailure(op, "Result is not a vector type");
328+
return failure();
329329

330330
// Only handle broadcasts to vectors with XeGPU layout attribute
331331
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
332332
if (!layout || !layout.getSgLayout())
333-
return rewriter.notifyMatchFailure(
334-
op, "Result does not have a valid layout attribute for subgroup distribution");
333+
return failure();
335334

336335
// Extract sgShape from layout
337336
SmallVector<int64_t> sgShape;
@@ -347,15 +346,17 @@ struct WgToSgVectorBroadcastOp : public OpConversionPattern<vector::BroadcastOp>
347346
}
348347
}
349348

350-
VectorType newResultType = VectorType::get(sgShape, resultType.getElementType());
349+
VectorType newResultType =
350+
VectorType::get(sgShape, resultType.getElementType());
351351
SmallVector<Value> newBroadcasts;
352352

353-
// The operand is always a scalar or lower-rank vector, so just broadcast for each subgroup
354-
for (Value unused : adaptor.getOperands().front()) {
355-
// All subgroups get the same broadcasted value
353+
// The operand is always a scalar or lower-rank vector, so just broadcast
354+
// for each subgroup
355+
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
356356
auto newBroadcast = rewriter.create<vector::BroadcastOp>(
357-
op.getLoc(), newResultType, adaptor.getOperands().front()[0]);
358-
xegpu::setLayoutAttr(newBroadcast->getResult(0), layout.dropSgLayoutAndData());
357+
op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
358+
xegpu::setLayoutAttr(newBroadcast->getResult(0),
359+
layout.dropSgLayoutAndData());
359360
newBroadcasts.push_back(newBroadcast.getResult());
360361
}
361362

@@ -371,8 +372,7 @@ namespace xegpu {
371372
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
372373
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
373374
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
374-
WgToSgVectorBroadcastOp>(
375-
patterns.getContext());
375+
WgToSgVectorBroadcastOp>(patterns.getContext());
376376
}
377377
} // namespace xegpu
378378
} // namespace mlir
@@ -420,13 +420,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
420420
return isLegal(layout);
421421
});
422422

423-
target.addDynamicallyLegalOp<vector::BroadcastOp>([=](vector::BroadcastOp op) -> bool {
424-
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
425-
if (!resultType)
426-
return true;
427-
auto layout = xegpu::getLayoutAttr(op.getResult());
428-
return isLegal(layout);
429-
});
423+
target.addDynamicallyLegalOp<vector::BroadcastOp>(
424+
[=](vector::BroadcastOp op) -> bool {
425+
auto resultType = dyn_cast<VectorType>(op.getResult().getType());
426+
if (!resultType)
427+
return true;
428+
xegpu::LayoutAttr = xegpu::getLayoutAttr(op.getResult());
429+
return isLegal(layout);
430+
});
430431

431432
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
432433

0 commit comments

Comments
 (0)