Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 165 additions & 15 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecAttr || !vecAttr.isSplat() || !vecType)
if (!vecAttr || !vecType)
return failure();

xegpu::DistributeLayoutAttr layout =
Expand All @@ -733,22 +733,172 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

// Current limitation: constant of vector with single value.
// TODO: support more complex cases, e.g., vector with multiple values.
Attribute singleVal = vecAttr.getSplatValue<Attribute>();

auto newType = VectorType::get(sgShape, vecType.getElementType());
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
layout.dropSgLayoutAndData());
SmallVector<Value> newConsts(count, cstOp);
Location loc = op.getLoc();
auto eltType = vecType.getElementType();

rewriter.replaceOpWithMultiple(op, {newConsts});
return success();
auto setLayoutIfNeeded = [&](Value val) {
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
layout.dropSgLayoutAndData());
}
};

if (vecAttr.isSplat()) {
// Splat: single value for all subgroups
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
setLayoutIfNeeded(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we specialize for such a rare condition?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that we can avoid generating the IR (for all this computation) that we emit in the non-splat case.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a test for this branch?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but it would be a case like the one below. Do you want this added as a test?

%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 16]>} dense<[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]> : vector<1x16xindex>

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep

// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
setLayoutIfNeeded(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
// Non-splat constant
// Only supports 1D & 2D
// TODO: support other cases that require SLM access
if (!eltType.isIndex())
return rewriter.notifyMatchFailure(
op, "Unsupported element type for non-splat constant op.");

if (wgShape.size() > 2)
return rewriter.notifyMatchFailure(
op, "Only 1D & 2D vector constant supported");

SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
int64_t stride = 0;
int64_t rowStride = 0, colStride = 0;
if (wgShape.size() == 1) {
// 1D case: single stride
if (values.size() > 1) {
stride = cast<IntegerAttr>(values[1]).getInt() -
cast<IntegerAttr>(values[0]).getInt();
for (size_t i = 2; i < values.size(); ++i) {
int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
cast<IntegerAttr>(values[i - 1]).getInt();
if (diff != stride)
return rewriter.notifyMatchFailure(
op, "Non-constant stride in non-splat constant op.");
}
}
} else if (wgShape.size() == 2) {
// 2D case: row stride and column stride
int64_t rows = wgShape[0], cols = wgShape[1];
// Compute col stride (difference between adjacent elements within a row,
// i.e. when moving one column to the right)
if (cols > 1) {
colStride = cast<IntegerAttr>(values[1]).getInt() -
cast<IntegerAttr>(values[0]).getInt();
for (int64_t r = 0; r < rows; ++r) {
for (int64_t c = 1; c < cols; ++c) {
int64_t idx = r * cols + c;
int64_t prevIdx = r * cols + (c - 1);
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
cast<IntegerAttr>(values[prevIdx]).getInt();
if (diff != colStride)
return rewriter.notifyMatchFailure(
op, "Non-constant column stride in 2D constant op.");
}
}
}
// Compute row stride (difference between adjacent elements within a column,
// i.e. when moving one row down)
if (rows > 1) {
rowStride = cast<IntegerAttr>(values[cols]).getInt() -
cast<IntegerAttr>(values[0]).getInt();
for (int64_t c = 0; c < cols; ++c) {
for (int64_t r = 1; r < rows; ++r) {
int64_t idx = r * cols + c;
int64_t prevIdx = (r - 1) * cols + c;
int64_t diff = cast<IntegerAttr>(values[idx]).getInt() -
cast<IntegerAttr>(values[prevIdx]).getInt();
if (diff != rowStride)
return rewriter.notifyMatchFailure(
op, "Non-constant row stride in 2D constant op.");
}
}
}
}

// Determine the shape of the base tile for each subgroup.
SmallVector<int64_t> baseTileShape;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just use sgShape directly instead of introducing a new variable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, cleaning it up further

if (sgShape.size() == 1) {
baseTileShape.push_back(sgShape[0]);
} else if (sgShape.size() == 2) {
baseTileShape = sgShape;
} else {
return rewriter.notifyMatchFailure(
op, "Only 1D or 2D vector constant supported");
}

// Create a constant for the base tile.
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
SmallVector<Attribute> baseTileValues;
if (baseTileShape.size() == 2) {
int64_t rows = baseTileShape[0], cols = baseTileShape[1];
int64_t wgCols = wgShape[1];
for (int64_t r = 0; r < rows; ++r) {
for (int64_t c = 0; c < cols; ++c) {
baseTileValues.push_back(values[r * wgCols + c]);
}
}
} else {
// 1D case
for (int64_t i = 0; i < computeProduct(baseTileShape); ++i)
baseTileValues.push_back(values[i]);
}
auto tileAttr = DenseElementsAttr::get(
VectorType::get(baseTileShape, eltType), baseTileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);

// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();

auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
auto rowStrideConst =
rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
auto colStrideConst =
rewriter.create<arith::ConstantIndexOp>(loc, colStride);
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
Value mulOffset;
if (wgShape.size() == 1) {
// 1D: offset[0] * strideConst
mulOffset = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[0], strideConst);
} else if (wgShape.size() == 2) {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this just be `else`?

// 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
Value rowMul = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[0], rowStrideConst);
Value colMul = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[1], colStrideConst);
mulOffset = rewriter.create<arith::AddIOp>(
loc, rewriter.getIndexType(), rowMul, colMul);
}
// Broadcast to baseConstVec size
auto bcastOffset = rewriter.create<vector::BroadcastOp>(
loc, baseConstVec.getType(), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
setLayoutIfNeeded(bcastOffset);
setLayoutIfNeeded(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
return success();
}
}
};

Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,29 @@ gpu.module @test_distribution {
: vector<256x64xf32> to vector<256xf32>
gpu.return
}

// Non-splat 32x1 index constant whose values step by 16 per row, annotated
// with sg_layout = [8, 1] and sg_data = [2, 1]. The directives below expect a
// single 2x1 base constant and two offset chains (RESULT1 and RESULT2), each
// formed as muli + muli -> addi -> vector.broadcast -> addi onto that base.
// NOTE(review): captures such as %[[C16:.*]] accept any SSA value; their names
// only hint at the constant that is expected there.
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[MAP4:.*]] = affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: %[[MAP5:.*]] = affine.apply #map5()[%[[SGID]]]
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[MAP4]], %[[C2:.*]]
// CHECK-DAG: %[[REMU1:.*]] = index.remu %[[MUL]], %[[C32:.*]]
// CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
// CHECK-DAG: %[[ADD16:.*]] = arith.addi %[[MUL]], %[[C16:.*]] : index
// CHECK-DAG: %[[REMU3:.*]] = index.remu %[[ADD16]], %[[C32:.*]]
// CHECK-DAG: %[[REMU4:.*]] = index.remu %[[MAP5]], %[[C1:.*]]
// CHECK-DAG: %[[STRIDE1:.*]] = arith.muli %[[REMU1]], %[[C16:.*]] : index
// CHECK-DAG: %[[STRIDE2:.*]] = arith.muli %[[REMU2]], %[[C0:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES1:.*]] = arith.addi %[[STRIDE1]], %[[STRIDE2]] : index
// CHECK-DAG: %[[BCAST1:.*]] = vector.broadcast %[[ADDSTRIDES1]] : index to vector<2x1xindex>
// CHECK-DAG: %[[RESULT1:.*]] = arith.addi %[[BASECST]], %[[BCAST1]] : vector<2x1xindex>
// CHECK-DAG: %[[STRIDE3:.*]] = arith.muli %[[REMU3]], %[[C16:.*]] : index
// CHECK-DAG: %[[STRIDE4:.*]] = arith.muli %[[REMU4]], %[[C0:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES2:.*]] = arith.addi %[[STRIDE3]], %[[STRIDE4]] : index
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[ADDSTRIDES2]] : index to vector<2x1xindex>
// CHECK-DAG: %[[RESULT2:.*]] = arith.addi %[[BASECST]], %[[BCAST2]] : vector<2x1xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
}
47 changes: 47 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -458,4 +458,51 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
gpu.return
}

// Non-splat 32x1 index constant whose values step by 16 per row, annotated
// with sg_layout = [32, 1] and sg_data = [1, 1]. The directives below expect a
// 1x1 base constant of value 0 and a single offset chain:
// muli + muli -> addi -> vector.broadcast -> addi onto that base.
// NOTE(review): captures such as %[[C16:.*]] accept any SSA value; their names
// only hint at the constant that is expected there.
// CHECK-LABEL: non_splat_constant
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1x1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: affine.apply #map5()[%[[SGID]]]
// CHECK-DAG: %[[IDY:.*]] = index.remu %{{.*}}, %[[C32:.*]]
// CHECK-DAG: %[[IDX:.*]] = index.remu %{{.*}}, %[[C1:.*]]
// CHECK-DAG: %[[STRIDECOL:.*]] = arith.muli %[[IDY]], %[[C16:.*]] : index
// CHECK-DAG: %[[STRIDEROW:.*]] = arith.muli %[[IDX]], %[[C0:.*]] : index
// CHECK-DAG: %[[ADDSTRIDES:.*]] = arith.addi %[[STRIDECOL]], %[[STRIDEROW]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDSTRIDES]] : index to vector<1x1xindex>
// CHECK-DAG: arith.addi %[[CST]], %[[BCAST]] : vector<1x1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}

// Non-splat 8x8 index constant with both dimensions non-unit: values step by
// 8 from one row to the next and by 16 from one column to the next. It is
// annotated with sg_layout = [4, 4] and sg_data = [2, 2]. The directives below
// expect a 2x2 base constant and one offset chain:
// muli + muli -> addi -> vector.broadcast -> addi onto that base.
// NOTE(review): captures such as %[[C8:.*]] accept any SSA value; their names
// only hint at the constant that is expected there.
// CHECK-LABEL: non_splat_constant_2D_non_unit_dim
gpu.func @non_splat_constant_2D_non_unit_dim() {
// CHECK-DAG: %[[BASECST:.*]] = arith.constant dense<{{.*}}> : vector<2x2xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[IDY:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK-DAG: %[[IDX:.*]] = affine.apply #map1()[%[[SGID]]]
// CHECK-DAG: %[[MULY:.*]] = index.mul %[[IDY]], %[[C2:.*]]
// CHECK-DAG: %[[C2_2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MULX:.*]] = index.mul %[[IDX]], %[[C2:.*]]
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[MULY]], %[[C8:.*]]
// CHECK-DAG: %[[C8_2:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[MULX]], %[[C8:.*]]
// CHECK-DAG: %[[MUL5:.*]] = arith.muli %[[REMU_Y]], %[[C8:.*]] : index
// CHECK-DAG: %[[MUL6:.*]] = arith.muli %[[REMU_X]], %[[C16:.*]] : index
// CHECK-DAG: %[[ADDIDX:.*]] = arith.addi %[[MUL5]], %[[MUL6]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[ADDIDX]] : index to vector<2x2xindex>
// CHECK-DAG: %[[ADDCST:.*]] = arith.addi %[[BASECST]], %[[BCAST]] : vector<2x2xindex>
%cst_8x8 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2]>} dense<[
[0, 16, 32, 48, 64, 80, 96, 112],
[8, 24, 40, 56, 72, 88, 104, 120],
[16, 32, 48, 64, 80, 96, 112, 128],
[24, 40, 56, 72, 88, 104, 120, 136],
[32, 48, 64, 80, 96, 112, 128, 144],
[40, 56, 72, 88, 104, 120, 136, 152],
[48, 64, 80, 96, 112, 128, 144, 160],
[56, 72, 88, 104, 120, 136, 152, 168]
]> : vector<8x8xindex>
gpu.return
}
}