Skip to content

Commit 1b8db0e

Browse files
committed
Clean up
1 parent 1b779b7 commit 1b8db0e

File tree

1 file changed

+10
-19
lines changed

1 file changed

+10
-19
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -762,13 +762,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
762762
return success();
763763
} else {
764764
// Non-splat constant
765-
// Only supports 1D & 2D (with one unit dim)
765+
// Only supports 1D & 2D
766766
// TODO: support other cases that require SLM access
767767
if (!eltType.isIndex())
768768
return rewriter.notifyMatchFailure(
769769
op, "Unsupported element type for non-splat constant op.");
770770

771-
SmallVector<int64_t> sgLayout = layout.getEffectiveSgLayoutAsInt();
772771
if (wgShape.size() > 2)
773772
return rewriter.notifyMatchFailure(
774773
op, "Only 1D & 2D vector constant supported");
@@ -792,9 +791,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
792791
} else if (wgShape.size() == 2) {
793792
// 2D case: row stride and column stride
794793
int64_t rows = wgShape[0], cols = wgShape[1];
795-
if (values.size() != static_cast<size_t>(rows * cols))
796-
return rewriter.notifyMatchFailure(
797-
op, "Mismatch between vector shape and constant values size.");
798794
// Compute col stride (difference between adjacent elements within a row)
799795
if (cols > 1) {
800796
colStride = cast<IntegerAttr>(values[1]).getInt() -
@@ -840,25 +836,20 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
840836
op, "Only 1D or 2D vector constant supported");
841837
}
842838

843-
// Compute the number of elements in the base tile.
844-
int64_t baseTileElemCount = 1;
845-
for (int64_t d : baseTileShape)
846-
baseTileElemCount *= d;
847-
848839
// Create a constant for the base tile.
849840
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
850841
SmallVector<Attribute> baseTileValues;
851842
if (baseTileShape.size() == 2) {
852843
int64_t rows = baseTileShape[0], cols = baseTileShape[1];
853-
int64_t wgRows = wgShape[0], wgCols = wgShape[1];
844+
int64_t wgCols = wgShape[1];
854845
for (int64_t r = 0; r < rows; ++r) {
855846
for (int64_t c = 0; c < cols; ++c) {
856847
baseTileValues.push_back(values[r * wgCols + c]);
857848
}
858849
}
859850
} else {
860851
// 1D case
861-
for (int64_t i = 0; i < baseTileElemCount; ++i)
852+
for (int64_t i = 0; i < computeProduct(baseTileShape); ++i)
862853
baseTileValues.push_back(values[i]);
863854
}
864855
auto tileAttr = DenseElementsAttr::get(
@@ -874,24 +865,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
874865
return failure();
875866

876867
auto strideConst = rewriter.create<arith::ConstantIndexOp>(loc, stride);
877-
auto strideConstRow =
868+
auto rowStrideConst =
878869
rewriter.create<arith::ConstantIndexOp>(loc, rowStride);
879-
auto strideConstCol =
870+
auto colStrideConst =
880871
rewriter.create<arith::ConstantIndexOp>(loc, colStride);
881872
SmallVector<Value> newConstOps;
882873
for (auto offsets : *sgOffsets) {
883874
// Multiply offset with stride, broadcast it and add to baseConstVec
884875
Value mulOffset;
885-
if (baseTileShape.size() == 1) {
876+
if (wgShape.size() == 1) {
886877
// 1D: offsets[0] * strideConst
887878
mulOffset = rewriter.create<arith::MulIOp>(
888879
loc, rewriter.getIndexType(), offsets[0], strideConst);
889-
} else if (baseTileShape.size() == 2) {
890-
// 2D: offset[0]*strideConstRow + offset[1]*strideConstCol
880+
} else if (wgShape.size() == 2) {
881+
// 2D: offsets[0]*rowStrideConst + offsets[1]*colStrideConst
891882
Value rowMul = rewriter.create<arith::MulIOp>(
892-
loc, rewriter.getIndexType(), offsets[0], strideConstRow);
883+
loc, rewriter.getIndexType(), offsets[0], rowStrideConst);
893884
Value colMul = rewriter.create<arith::MulIOp>(
894-
loc, rewriter.getIndexType(), offsets[1], strideConstCol);
885+
loc, rewriter.getIndexType(), offsets[1], colStrideConst);
895886
mulOffset = rewriter.create<arith::AddIOp>(
896887
loc, rewriter.getIndexType(), rowMul, colMul);
897888
}

0 commit comments

Comments
 (0)