Skip to content

Commit 3c7873b

Browse files
authored
[MLIR][XeGPU] Distribute non-splat constant from wg to sg (#161416)
This PR distributes non-splat constants from workgroup (wg) to subgroup (sg) level. The current pattern has limitations and avoids cases that require SLM (shared local memory) access.
1 parent 4e53067 commit 3c7873b

File tree

3 files changed

+223
-15
lines changed

3 files changed

+223
-15
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 132 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
720720
ConversionPatternRewriter &rewriter) const override {
721721
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
722722
auto vecType = dyn_cast<VectorType>(op.getType());
723-
if (!vecAttr || !vecAttr.isSplat() || !vecType)
723+
if (!vecAttr || !vecType)
724724
return failure();
725725

726726
xegpu::DistributeLayoutAttr layout =
@@ -733,22 +733,139 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
733733
int count;
734734
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);
735735

736-
// Current limitation: constant of vector with single value.
737-
// TODO: support more complex cases, e.g., vector with multiple values.
738-
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
739-
740736
auto newType = VectorType::get(sgShape, vecType.getElementType());
741-
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
742-
auto cstOp =
743-
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
744-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
745-
!layout.getEffectiveInstDataAsInt().empty())
746-
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
747-
layout.dropSgLayoutAndData());
748-
SmallVector<Value> newConsts(count, cstOp);
737+
Location loc = op.getLoc();
738+
auto eltType = vecType.getElementType();
749739

750-
rewriter.replaceOpWithMultiple(op, {newConsts});
751-
return success();
740+
auto setLayoutIfNeeded = [&](Value val) {
741+
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
742+
!layout.getEffectiveInstDataAsInt().empty()) {
743+
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
744+
layout.dropSgLayoutAndData());
745+
}
746+
};
747+
748+
if (vecAttr.isSplat()) {
749+
// Splat: single value for all subgroups
750+
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
751+
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
752+
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
753+
setLayoutIfNeeded(cstOp->getResult(0));
754+
rewriter.replaceOp(op, cstOp);
755+
return success();
756+
} else if (sgShape == wgShape) { // if the entire vector is shared by all
757+
// subgroups, don't distribute
758+
auto newConstOp =
759+
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
760+
setLayoutIfNeeded(newConstOp->getResult(0));
761+
rewriter.replaceOp(op, newConstOp);
762+
return success();
763+
} else {
764+
// Non-splat constant
765+
// Only supports 1D & 2D
766+
// TODO: support other cases that require SLM access
767+
if (!eltType.isIndex())
768+
return rewriter.notifyMatchFailure(
769+
op, "Unsupported element type for non-splat constant op.");
770+
771+
if (wgShape.size() > 2)
772+
return rewriter.notifyMatchFailure(
773+
op, "Only 1D & 2D vector constant supported");
774+
775+
SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
776+
int64_t rowStride = 0, colStride = 0;
777+
int64_t rows = wgShape.size() == 1 ? 1 : wgShape[0];
778+
int64_t cols = wgShape.size() == 1 ? wgShape[0] : wgShape[1];
779+
780+
// Compute colStride and rowStride, and check for constant strides.
781+
if (cols > 1) {
782+
colStride = cast<IntegerAttr>(values[1]).getInt() -
783+
cast<IntegerAttr>(values[0]).getInt();
784+
}
785+
if (rows > 1) {
786+
rowStride = cast<IntegerAttr>(values[cols]).getInt() -
787+
cast<IntegerAttr>(values[0]).getInt();
788+
}
789+
790+
for (int64_t r = 0; r < rows; ++r) {
791+
for (int64_t c = 0; c < cols; ++c) {
792+
int64_t idx = r * cols + c;
793+
// Check column stride
794+
if (c > 0 && cols > 1) {
795+
int64_t prevIdx = r * cols + (c - 1);
796+
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
797+
cast<IntegerAttr>(values[prevIdx]).getInt();
798+
if (diff != colStride)
799+
return rewriter.notifyMatchFailure(
800+
op, "Non-constant column stride in constant op.");
801+
}
802+
// Check row stride
803+
if (r > 0 && rows > 1) {
804+
int64_t prevIdx = (r - 1) * cols + c;
805+
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
806+
cast<IntegerAttr>(values[prevIdx]).getInt();
807+
if (diff != rowStride)
808+
return rewriter.notifyMatchFailure(
809+
op, "Non-constant row stride in constant op.");
810+
}
811+
}
812+
}
813+
814+
// Create a constant for the base tile.
815+
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
816+
// For 1D case, extract the first sgShape[0] elements.
817+
SmallVector<Attribute> baseTileValues;
818+
int baseTileCols = sgShape[sgShape.size() - 1];
819+
int64_t baseTileRows = sgShape.size() == 1 ? 1 : sgShape[0];
820+
for (int64_t r = 0; r < baseTileRows; ++r) {
821+
for (int64_t c = 0; c < baseTileCols; ++c) {
822+
baseTileValues.push_back(values[r * cols + c]);
823+
}
824+
}
825+
826+
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
827+
baseTileValues);
828+
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
829+
830+
// Get subgroup id
831+
Value sgId =
832+
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);
833+
834+
auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
835+
if (failed(sgOffsets))
836+
return failure();
837+
838+
SmallVector<Value, 2> strideConsts;
839+
strideConsts.push_back(
840+
rewriter.create<arith::ConstantIndexOp>(loc, colStride));
841+
if (rows > 1)
842+
strideConsts.insert(
843+
strideConsts.begin(),
844+
rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
845+
846+
SmallVector<Value> newConstOps;
847+
for (auto offsets : *sgOffsets) {
848+
// Multiply offset with stride, broadcast it and add to baseConstVec
849+
Value mulOffset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
850+
for (size_t i = 0; i < strideConsts.size(); ++i) {
851+
Value mul = rewriter.create<arith::MulIOp>(
852+
loc, rewriter.getIndexType(), offsets[i], strideConsts[i]);
853+
mulOffset = rewriter.create<arith::AddIOp>(
854+
loc, rewriter.getIndexType(), mulOffset, mul);
855+
}
856+
// Broadcast to baseConstVec size
857+
auto bcastOffset = rewriter.create<vector::BroadcastOp>(
858+
loc, baseConstVec.getType(), mulOffset);
859+
auto finalConst =
860+
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
861+
setLayoutIfNeeded(baseConstVec);
862+
setLayoutIfNeeded(bcastOffset);
863+
setLayoutIfNeeded(finalConst);
864+
newConstOps.push_back(finalConst);
865+
}
866+
rewriter.replaceOpWithMultiple(op, {newConstOps});
867+
return success();
868+
}
752869
}
753870
};
754871

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,31 @@ gpu.module @test_distribution {
9898
: vector<256x64xf32> to vector<256xf32>
9999
gpu.return
100100
}
101+
102+
gpu.func @non_splat_constant() {
103+
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
104+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
105+
// CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
106+
// CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
107+
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
108+
// CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
109+
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
110+
// CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
111+
// CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
112+
// CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
113+
// CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
114+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[STRIDE1]] : index
115+
// CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
116+
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[ADDSTRIDES]], %[[STRIDE2]] : index
117+
// CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
118+
// CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
119+
// CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
120+
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[C0:.*]], %[[STRIDE3]] : index
121+
// CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
122+
// CHECK-DAG: %[[ADDSTRIDES3:.*]] = arith.addi %[[ADDSTRIDES2]], %[[STRIDE4]] : index
123+
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES3]] : index to vector<2x1xindex>
124+
// CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
125+
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
126+
gpu.return
127+
}
101128
}

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -463,4 +463,68 @@ gpu.module @test_distribution {
463463
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
464464
gpu.return
465465
}
466+
467+
// CHECK-LABEL: non_splat_constant_2D
468+
gpu.func @non_splat_constant_2D() {
469+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
470+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
471+
// CHECK-DAG: affine.apply #map4()[%[[SGID]]]
472+
// CHECK-DAG: affine.apply #map5()[%[[SGID]]]
473+
// CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
474+
// CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
475+
// CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
476+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[STRIDECOL]] : index
477+
// CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
478+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[STRIDEROW]] : index
479+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
480+
// CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
481+
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
482+
gpu.return
483+
}
484+
485+
// CHECK-LABEL: non_splat_constant_2D_non_unit_dim
486+
gpu.func @non_splat_constant_2D_non_unit_dim() {
487+
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}} : vector<2x2xindex>
488+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
489+
// CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
490+
// CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
491+
// CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
492+
// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
493+
// CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
494+
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
495+
// CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
496+
// CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
497+
// CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
498+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[C0:.*]], %[[MUL5]] : index
499+
// CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
500+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[ADD]], %[[MUL6]] : index
501+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<2x2xindex>
502+
// CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
503+
%cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
504+
[0, 16, 32, 48, 64, 80, 96, 112],
505+
[8, 24, 40, 56, 72, 88, 104, 120],
506+
[16, 32, 48, 64, 80, 96, 112, 128],
507+
[24, 40, 56, 72, 88, 104, 120, 136],
508+
[32, 48, 64, 80, 96, 112, 128, 144],
509+
[40, 56, 72, 88, 104, 120, 136, 152],
510+
[48, 64, 80, 96, 112, 128, 144, 160],
511+
[56, 72, 88, 104, 120, 136, 152, 168]
512+
]> : vector<8x8xindex>
513+
gpu.return
514+
}
515+
516+
// CHECK-LABEL: non_splat_constant
517+
gpu.func @non_splat_constant() {
518+
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
519+
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
520+
// CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C32:.*]]
521+
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU]], %[[C16:.*]] : index
522+
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[C0:.*]], %[[MUL]] : index
523+
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1xindex>
524+
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
525+
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [1]>} dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256, 272, 288, 304, 320, 336, 352, 368, 384, 400, 416, 432, 448, 464, 480, 496]> : vector<32xindex>
526+
// CHECK: arith.constant dense<{{\[}}[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]{{\]}}> : vector<1x16xindex>
527+
%cst_1 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>
528+
gpu.return
529+
}
466530
}

0 commit comments

Comments
 (0)