Skip to content

Commit fabb419

Browse files
committed
Clean up
1 parent 1b8db0e commit fabb419

File tree

1 file changed

+24
-41
lines changed

1 file changed

+24
-41
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 24 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -825,35 +825,21 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
825825
}
826826
}
827827

828-
// Determine the shape of the base tile for each subgroup.
829-
SmallVector<int64_t> baseTileShape;
830-
if (sgShape.size() == 1) {
831-
baseTileShape.push_back(sgShape[0]);
832-
} else if (sgShape.size() == 2) {
833-
baseTileShape = sgShape;
834-
} else {
835-
return rewriter.notifyMatchFailure(
836-
op, "Only 1D or 2D vector constant supported");
837-
}
838-
839828
// Create a constant for the base tile.
840829
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830+
// For 1D case, extract the first sgShape[0] elements.
841831
SmallVector<Attribute> baseTileValues;
842-
if (baseTileShape.size() == 2) {
843-
int64_t rows = baseTileShape[0], cols = baseTileShape[1];
844-
int64_t wgCols = wgShape[1];
845-
for (int64_t r = 0; r < rows; ++r) {
846-
for (int64_t c = 0; c < cols; ++c) {
847-
baseTileValues.push_back(values[r * wgCols + c]);
848-
}
832+
int cols = sgShape[sgShape.size() - 1];
833+
int64_t wgCols = wgShape[sgShape.size() - 1];
834+
int64_t rows = sgShape.size() == 1 ? 1 : sgShape[0];
835+
for (int64_t r = 0; r < rows; ++r) {
836+
for (int64_t c = 0; c < cols; ++c) {
837+
baseTileValues.push_back(values[r * wgCols + c]);
849838
}
850-
} else {
851-
// 1D case
852-
for (int64_t i = 0; i < computeProduct(baseTileShape); ++i)
853-
baseTileValues.push_back(values[i]);
854839
}
855-
auto tileAttr = DenseElementsAttr::get(
856-
VectorType::get(baseTileShape, eltType), baseTileValues);
840+
841+
auto tileAttr = DenseElementsAttr::get(VectorType::get(sgShape, eltType),
842+
baseTileValues);
857843
auto baseConstVec = rewriter.create<arith::ConstantOp>(loc, tileAttr);
858844

859845
// Get subgroup id
@@ -864,27 +850,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
864850
if (failed(sgOffsets))
865851
return failure();
866852

867-
auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
868-
auto rowStrideConst =
869-
rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
870-
auto colStrideConst =
871-
rewriter.create<arith::ConstantIndexOp>(loc, colStride);
853+
SmallVector<Value, 2> strideConsts;
854+
strideConsts.push_back(
855+
rewriter.create<arith::ConstantIndexOp>(loc, rowStride));
856+
strideConsts.push_back(
857+
rewriter.create<arith::ConstantIndexOp>(loc, colStride));
872858
SmallVector<Value> newConstOps;
859+
Value mulOffset;
873860
for (auto offsets : *sgOffsets) {
874861
// Multiply offset with stride, broadcast it and add to baseConstVec
875-
Value mulOffset;
876-
if (wgShape.size() == 1) {
877-
// 1D: offset[0] * strideConst
878-
mulOffset = rewriter.create<arith::MulIOp>(
879-
loc, rewriter.getIndexType(), offsets[0], strideConst);
880-
} else if (wgShape.size() == 2) {
881-
// 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
882-
Value rowMul = rewriter.create<arith::MulIOp>(
883-
loc, rewriter.getIndexType(), offsets[0], rowStrideConst);
884-
Value colMul = rewriter.create<arith::MulIOp>(
885-
loc, rewriter.getIndexType(), offsets[1], colStrideConst);
862+
SmallVector<Value> muls;
863+
for (size_t i = 0; i < strideConsts.size(); ++i) {
864+
muls.push_back(rewriter.create<arith::MulIOp>(
865+
loc, rewriter.getIndexType(), offsets[i], strideConsts[i]));
866+
}
867+
mulOffset = muls.front();
868+
if (muls.size() > 1) {
886869
mulOffset = rewriter.create<arith::AddIOp>(
887-
loc, rewriter.getIndexType(), rowMul, colMul);
870+
loc, rewriter.getIndexType(), mulOffset, muls[1]);
888871
}
889872
// Broadcast to baseConstVec size
890873
auto bcastOffset = rewriter.create<vector::BroadcastOp>(

0 commit comments

Comments
 (0)