@@ -762,13 +762,12 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
762
762
return success ();
763
763
} else {
764
764
// Non-splat constant
765
- // Only supports 1D & 2D (with one unit dim)
765
+ // Only supports 1D & 2D
766
766
// TODO: support other cases that require SLM access
767
767
if (!eltType.isIndex ())
768
768
return rewriter.notifyMatchFailure (
769
769
op, " Unsupported element type for non-splat constant op." );
770
770
771
- SmallVector<int64_t > sgLayout = layout.getEffectiveSgLayoutAsInt ();
772
771
if (wgShape.size () > 2 )
773
772
return rewriter.notifyMatchFailure (
774
773
op, " Only 1D & 2D vector constant supported" );
@@ -792,9 +791,6 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
792
791
} else if (wgShape.size () == 2 ) {
793
792
// 2D case: row stride and column stride
794
793
int64_t rows = wgShape[0 ], cols = wgShape[1 ];
795
- if (values.size () != static_cast <size_t >(rows * cols))
796
- return rewriter.notifyMatchFailure (
797
- op, " Mismatch between vector shape and constant values size." );
798
794
// Compute col stride (stride between elements in a column)
799
795
if (cols > 1 ) {
800
796
colStride = cast<IntegerAttr>(values[1 ]).getInt () -
@@ -840,25 +836,20 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
840
836
op, " Only 1D or 2D vector constant supported" );
841
837
}
842
838
843
- // Compute the number of elements in the base tile.
844
- int64_t baseTileElemCount = 1 ;
845
- for (int64_t d : baseTileShape)
846
- baseTileElemCount *= d;
847
-
848
839
// Create a constant for the base tile.
849
840
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
850
841
SmallVector<Attribute> baseTileValues;
851
842
if (baseTileShape.size () == 2 ) {
852
843
int64_t rows = baseTileShape[0 ], cols = baseTileShape[1 ];
853
- int64_t wgRows = wgShape[ 0 ], wgCols = wgShape[1 ];
844
+ int64_t wgCols = wgShape[1 ];
854
845
for (int64_t r = 0 ; r < rows; ++r) {
855
846
for (int64_t c = 0 ; c < cols; ++c) {
856
847
baseTileValues.push_back (values[r * wgCols + c]);
857
848
}
858
849
}
859
850
} else {
860
851
// 1D case
861
- for (int64_t i = 0 ; i < baseTileElemCount ; ++i)
852
+ for (int64_t i = 0 ; i < computeProduct (baseTileShape) ; ++i)
862
853
baseTileValues.push_back (values[i]);
863
854
}
864
855
auto tileAttr = DenseElementsAttr::get (
@@ -874,24 +865,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
874
865
return failure ();
875
866
876
867
auto strideConst = rewriter.create <arith::ConstantIndexOp>(loc, stride);
877
- auto strideConstRow =
868
+ auto rowStrideConst =
878
869
rewriter.create <arith::ConstantIndexOp>(loc, rowStride);
879
- auto strideConstCol =
870
+ auto colStrideConst =
880
871
rewriter.create <arith::ConstantIndexOp>(loc, colStride);
881
872
SmallVector<Value> newConstOps;
882
873
for (auto offsets : *sgOffsets) {
883
874
// Multiply offset with stride, broadcast it and add to baseConstVec
884
875
Value mulOffset;
885
- if (baseTileShape .size () == 1 ) {
876
+ if (wgShape .size () == 1 ) {
886
877
// 1D: offset[0] * strideConst
887
878
mulOffset = rewriter.create <arith::MulIOp>(
888
879
loc, rewriter.getIndexType (), offsets[0 ], strideConst);
889
- } else if (baseTileShape .size () == 2 ) {
890
- // 2D: offset[0]*strideConstRow + offset[1]*strideConstCol
880
+ } else if (wgShape .size () == 2 ) {
881
+ // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
891
882
Value rowMul = rewriter.create <arith::MulIOp>(
892
- loc, rewriter.getIndexType (), offsets[0 ], strideConstRow );
883
+ loc, rewriter.getIndexType (), offsets[0 ], rowStrideConst );
893
884
Value colMul = rewriter.create <arith::MulIOp>(
894
- loc, rewriter.getIndexType (), offsets[1 ], strideConstCol );
885
+ loc, rewriter.getIndexType (), offsets[1 ], colStrideConst );
895
886
mulOffset = rewriter.create <arith::AddIOp>(
896
887
loc, rewriter.getIndexType (), rowMul, colMul);
897
888
}
0 commit comments