Skip to content

Commit af87214

Browse files
authored
[MLIR][XeGPU] Add pattern for arith.constant for wg to sg distribution (#151977)
1 parent 2796336 commit af87214

File tree

2 files changed

+59
-6
lines changed

2 files changed

+59
-6
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern
647647
}
648648
};
649649

650+
// Conversion pattern that distributes a workgroup-level arith.constant into
// subgroup-level constants.
//
// The pattern only matches when all of the following hold:
//   * the constant's value is a splat DenseElementsAttr,
//   * the constant's result type is a VectorType, and
//   * the result carries an xegpu layout attribute that has an sg_layout.
// On a match it builds a single splat constant of the subgroup shape and
// replaces the original op with `count` references to that one value.
651+
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
652+
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
653+
654+
// Returns failure() (leaving the op untouched) when the preconditions listed
// on the struct are not met, success() after the op has been replaced.
// `adaptor` is not read by this implementation.
LogicalResult
655+
matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor,
656+
ConversionPatternRewriter &rewriter) const override {
657+
// Both dyn_casts yield null on mismatch; the combined check below rejects
// non-dense values, non-splat values and non-vector result types.
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
658+
auto vecType = dyn_cast<VectorType>(op.getType());
659+
if (!vecAttr || !vecAttr.isSplat() || !vecType)
660+
return failure();
661+
662+
// A layout with an sg_layout is required to know how to distribute.
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
663+
if (!layout || !layout.getSgLayout())
664+
return failure();
665+
666+
// getSgShapeAndCount is defined elsewhere in this file; it is used here to
// obtain the per-subgroup vector shape and the number of replacement values.
ArrayRef<int64_t> wgShape = vecType.getShape();
667+
SmallVector<int64_t> sgShape;
668+
int count;
669+
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
670+
671+
// Current limitation: constant of vector with single value.
672+
// TODO: support more complex cases, e.g., vector with multiple values.
673+
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
674+
675+
// Rebuild the splat with the subgroup shape, keeping the element type.
auto newType = VectorType::get(sgShape, vecType.getElementType());
676+
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
677+
auto cstOp =
678+
rewriter.create<arith::ConstantOp>(op.getLoc(), newType, sgAttr);
679+
// Attach the layout with sg_layout/sg_data removed; if nothing remains
// (null result) the new constant is left without a layout attribute.
if (auto newLayout = layout.dropSgLayoutAndData())
680+
xegpu::setLayoutAttr(cstOp->getResult(0), newLayout);
681+
// Every replacement value is the same splat, so the single new constant is
// repeated `count` times rather than creating `count` separate ops.
SmallVector<Value> newConsts(count, cstOp);
682+
683+
rewriter.replaceOpWithMultiple(op, {newConsts});
684+
return success();
685+
}
686+
};
687+
650688
} // namespace
651689

652690
namespace mlir {
653691
namespace xegpu {
654692
// Adds the workgroup-to-subgroup distribution patterns to `patterns`, each
// constructed with the context obtained from `patterns.getContext()`.
// In this diff hunk the lines following a `-` gutter marker are the previous
// registration list and the lines following a `+` marker are its replacement,
// which is the same list with WgToSgArithConstantOp appended.
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
655-
patterns
656-
.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
657-
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
658-
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
659-
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
660-
patterns.getContext());
693+
patterns.add<WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
694+
WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp,
695+
WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern,
696+
WgToSgElementwiseOp, WgToSgVectorBroadcastOp,
697+
WgToSgConvertLayoutOp, WgToSgArithConstantOp>(
698+
patterns.getContext());
661699
}
662700
} // namespace xegpu
663701
} // namespace mlir
@@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
769807
return isLegal(xegpu::getLayoutAttr(op.getResult()));
770808
});
771809

810+
target.addDynamicallyLegalOp<arith::ConstantOp>(
811+
[=](arith::ConstantOp op) -> bool {
812+
auto vecType = dyn_cast<VectorType>(op.getType());
813+
if (!vecType)
814+
return true;
815+
return isLegal(xegpu::getLayoutAttr(op.getResult()));
816+
});
817+
772818
target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
773819
[=](xegpu::ConvertLayoutOp op) -> bool {
774820
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,4 +365,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) {
365365
} {sg_id_range = #xegpu.range<[3, 19]>}
366366
gpu.return
367367
}
368+
369+
// Test for distributing a splat arith.constant: the input is a 256x128 f32
// splat annotated with sg_layout = [8, 4] and sg_data = [32, 32]; the expected
// output is a splat constant of the same value with the 32x32 sg_data shape.
// CHECK-LABEL: distribute_constant
370+
gpu.func @distribute_constant() {
371+
// CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32>
372+
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} dense<1.0> : vector<256x128xf32>
373+
gpu.return
374+
}
368375
}

0 commit comments

Comments
 (0)