Skip to content
124 changes: 109 additions & 15 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
ConversionPatternRewriter &rewriter) const override {
auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue());
auto vecType = dyn_cast<VectorType>(op.getType());
if (!vecAttr || !vecAttr.isSplat() || !vecType)
if (!vecAttr || !vecType)
return failure();

xegpu::DistributeLayoutAttr layout =
Expand All @@ -733,22 +733,116 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
int count;
std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout);

// Current limitation: constant of vector with single value.
// TODO: support more complex cases, e.g., vector with multiple values.
Attribute singleVal = vecAttr.getSplatValue<Attribute>();

auto newType = VectorType::get(sgShape, vecType.getElementType());
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty())
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
layout.dropSgLayoutAndData());
SmallVector<Value> newConsts(count, cstOp);
Location loc = op.getLoc();
auto eltType = vecType.getElementType();

rewriter.replaceOpWithMultiple(op, {newConsts});
return success();
auto setLayoutIfNeeded = [&](Value val) {
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
!layout.getEffectiveInstDataAsInt().empty()) {
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
layout.dropSgLayoutAndData());
}
};

if (vecAttr.isSplat()) {
// Splat: single value for all subgroups
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
setLayoutIfNeeded(cstOp->getResult(0));
rewriter.replaceOp(op, cstOp);
return success();
} else if (sgShape == wgShape) { // if the entire vector is shared by all
// subgroups, don't distribute
auto newConstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
setLayoutIfNeeded(newConstOp->getResult(0));
rewriter.replaceOp(op, newConstOp);
return success();
} else {
// Non-splat constant
// Only supports 1D & 2D (with one unit dim)
// TODO: support other cases that require SLM access
if (!eltType.isIndex())
return rewriter.notifyMatchFailure(
op, "Unsupported element type for non-splat constant op.");

SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
if (wgShape.size() > 2)
return rewriter.notifyMatchFailure(
op, "Only 1D & 2D vector constant supported");

// allow 2D vector/distributions with one unit dim
auto hasTwoNonUnitDims = [](ArrayRef<int64_t> dims) {
return dims.size() == 2 && dims[0] != 1 && dims[1] != 1;
};
if (hasTwoNonUnitDims(wgShape) || hasTwoNonUnitDims(sgLayout))
return rewriter.notifyMatchFailure(
op, "2D vector/distribution only supported with 1 unit dim");

int64_t nonUnitDim = 0;
if (wgShape.size() == 2)
nonUnitDim = wgShape[0] != 1 ? 0 : 1;

SmallVector<Attribute> values(vecAttr.getValues<Attribute>());
int64_t stride = 0;
if (values.size() > 1) {
stride = cast<IntegerAttr>(values[1]).getInt() -
cast<IntegerAttr>(values[0]).getInt();
for (size_t i = 2; i < values.size(); ++i) {
int64_t diff = cast<IntegerAttr>(values[i]).getInt() -
cast<IntegerAttr>(values[i - 1]).getInt();
if (diff != stride)
return rewriter.notifyMatchFailure(
op, "Non-constant stride in non-splat constant op.");
}
}

int sgData = 1;
if (sgShape.size() == 1) {
sgData = static_cast<int>(sgShape[0]);
} else if (sgShape.size() == 2) {
sgData = static_cast<int>(sgShape[0] != 1 ? sgShape[0] : sgShape[1]);
} else {
return rewriter.notifyMatchFailure(
op, "Only 1D or 2D vector constant supported");
}

// Create a constant for the base tile
SmallVector<Attribute> baseTileValues;
for (int i = 0; i < sgData; ++i)
baseTileValues.push_back(values[i]);
auto tileAttr = DenseElementsAttr::get(VectorType::get({sgData}, eltType),
baseTileValues);
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);

// Get subgroup id
Value sgId =
gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr);

auto sgOffsets = layout.getOffsets(rewriter, loc, sgId, wgShape);
if (failed(sgOffsets))
return failure();

auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
SmallVector<Value> newConstOps;
for (auto offsets : *sgOffsets) {
// Multiply offset with stride, broadcast it and add to baseConstVec
Value mulOffset = rewriter.create<arith::MulIOp>(
loc, rewriter.getIndexType(), offsets[nonUnitDim], strideConst);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code is written specifically for one condition (2D with one unit dim, or 1D). If we have to relax that condition, the code will need a total rewrite.
Can we make it more generic, e.g. by having two strides for the 2D case? Then here you would just compute the linear offset before adding it to baseConstVec.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed it to support 2D vectors. The high-level logic is to have two strides, rowStride and colStride,
compute the value as rowOffset * rowStride + colOffset * colStride, then broadcast it to the size of baseConstVec and add it to baseConstVec.

auto bcastOffset = rewriter.create<vector::BroadcastOp>(
loc, VectorType::get({sgData}, rewriter.getIndexType()), mulOffset);
auto finalConst =
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
setLayoutIfNeeded(baseConstVec);
setLayoutIfNeeded(bcastOffset);
setLayoutIfNeeded(finalConst);
newConstOps.push_back(finalConst);
}
rewriter.replaceOpWithMultiple(op, {newConstOps});
return success();
}
}
};

Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -98,4 +98,39 @@ gpu.module @test_distribution {
: vector<256x64xf32> to vector<256xf32>
gpu.return
}

// Non-splat constant test where the 32x1 index constant (values 0, 16, ..., 496)
// carries sg_layout = [8, 1] and sg_data = [2, 1]. The checks expect a single
// base-tile constant dense<[0, 16]> : vector<2xindex> and two results: each one
// wraps a row offset with index.remu by 32, multiplies it by 16, broadcasts the
// product to vector<2xindex>, and adds it to the base tile.
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<[0, 16]> : vector<2xindex>
// CHECK-DAG: %[[SG_ID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C1_0:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[AFF1:.*]] = affine.apply #map4()[%[[SG_ID]]]
// CHECK-DAG: %[[AFF2:.*]] = affine.apply #map5()[%[[SG_ID]]]
// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[MUL:.*]] = index.mul %[[AFF1]], %[[C2]]
// CHECK-DAG: %[[C1_1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C0_2:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[REM:.*]] = index.remu %[[MUL]], %[[C32]]
// CHECK-DAG: %[[C1_3:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[REM2:.*]] = index.remu %[[AFF2]], %[[C1_3]]
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[C0_4:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[MUL]], %[[C16]] : index
// CHECK-DAG: %[[C32_5:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[REM3:.*]] = index.remu %[[ADD]], %[[C32_5]]
// CHECK-DAG: %[[C1_6:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[REM4:.*]] = index.remu %[[AFF2]], %[[C1_6]]
// CHECK-DAG: %[[C16_7:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM]], %[[C16_7]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL2]] : index to vector<2xindex>
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<2xindex>
// CHECK-DAG: %[[MUL3:.*]] = arith.muli %[[REM3]], %[[C16_7]] : index
// CHECK-DAG: %[[BCAST2:.*]] = vector.broadcast %[[MUL3]] : index to vector<2xindex>
// CHECK-DAG: %[[ADD3:.*]] = arith.addi %[[CST]], %[[BCAST2]] : vector<2xindex>
%cst_2 = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [8, 1], sg_data = [2, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
}
19 changes: 19 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -458,4 +458,23 @@ gpu.module @test_distribution {
%broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
gpu.return
}

// Non-splat constant test where the 32x1 index constant (values 0, 16, ..., 496)
// carries sg_layout = [32, 1] and sg_data = [1, 1]. The checks expect a base-tile
// constant dense<0> : vector<1xindex> and a single result: the row id wrapped
// with index.remu by 32, multiplied by 16, broadcast to vector<1xindex>, and
// added to the base tile.
// CHECK-LABEL: non_splat_constant
gpu.func @non_splat_constant() {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0> : vector<1xindex>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[C32:.*]] = arith.constant 32 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[IDY:.*]] = affine.apply #map4()[%[[SGID]]]
// CHECK-DAG: %[[IDX:.*]] = affine.apply #map5()[%[[SGID]]]
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[REMU_Y:.*]] = index.remu %[[IDY]], %[[C32]]
// CHECK-DAG: %[[REMU_X:.*]] = index.remu %[[IDX]], %[[C1]]
// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index
// CHECK-DAG: %[[MUL:.*]] = arith.muli %[[REMU_Y]], %[[C16]] : index
// CHECK-DAG: %[[BCAST:.*]] = vector.broadcast %[[MUL]] : index to vector<1xindex>
// CHECK-DAG: %[[ADD:.*]] = arith.addi %[[CST]], %[[BCAST]] : vector<1xindex>
%cst = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32, 1], sg_data = [1, 1]>} dense<[[0], [16], [32], [48], [64], [80], [96], [112], [128], [144], [160], [176], [192], [208], [224], [240], [256], [272], [288], [304], [320], [336], [352], [368], [384], [400], [416], [432], [448], [464], [480], [496]]> : vector<32x1xindex>
gpu.return
}
}