@@ -825,35 +825,21 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
825
825
}
826
826
}
827
827
828
- // Determine the shape of the base tile for each subgroup.
829
- SmallVector<int64_t > baseTileShape;
830
- if (sgShape.size () == 1 ) {
831
- baseTileShape.push_back (sgShape[0 ]);
832
- } else if (sgShape.size () == 2 ) {
833
- baseTileShape = sgShape;
834
- } else {
835
- return rewriter.notifyMatchFailure (
836
- op, " Only 1D or 2D vector constant supported" );
837
- }
838
-
839
828
// Create a constant for the base tile.
840
829
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830
+ // For 1D case, extract the first sgShape[0] elements.
841
831
SmallVector<Attribute> baseTileValues;
842
- if (baseTileShape.size () == 2 ) {
843
- int64_t rows = baseTileShape[0 ], cols = baseTileShape[1 ];
844
- int64_t wgCols = wgShape[1 ];
845
- for (int64_t r = 0 ; r < rows; ++r) {
846
- for (int64_t c = 0 ; c < cols; ++c) {
847
- baseTileValues.push_back (values[r * wgCols + c]);
848
- }
832
+ int cols = sgShape[sgShape.size () - 1 ];
833
+ int64_t wgCols = wgShape[sgShape.size () - 1 ];
834
+ int64_t rows = sgShape.size () == 1 ? 1 : sgShape[0 ];
835
+ for (int64_t r = 0 ; r < rows; ++r) {
836
+ for (int64_t c = 0 ; c < cols; ++c) {
837
+ baseTileValues.push_back (values[r * wgCols + c]);
849
838
}
850
- } else {
851
- // 1D case
852
- for (int64_t i = 0 ; i < computeProduct (baseTileShape); ++i)
853
- baseTileValues.push_back (values[i]);
854
839
}
855
- auto tileAttr = DenseElementsAttr::get (
856
- VectorType::get (baseTileShape, eltType), baseTileValues);
840
+
841
+ auto tileAttr = DenseElementsAttr::get (VectorType::get (sgShape, eltType),
842
+ baseTileValues);
857
843
auto baseConstVec = rewriter.create <arith::ConstantOp>(loc, tileAttr);
858
844
859
845
// Get subgroup id
@@ -864,27 +850,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
864
850
if (failed (sgOffsets))
865
851
return failure ();
866
852
867
- auto strideConst = rewriter. create <arith::ConstantIndexOp>(loc, stride) ;
868
- auto rowStrideConst =
869
- rewriter.create <arith::ConstantIndexOp>(loc, rowStride);
870
- auto colStrideConst =
871
- rewriter.create <arith::ConstantIndexOp>(loc, colStride);
853
+ SmallVector<Value, 2 > strideConsts ;
854
+ strideConsts. push_back (
855
+ rewriter.create <arith::ConstantIndexOp>(loc, rowStride)) ;
856
+ strideConsts. push_back (
857
+ rewriter.create <arith::ConstantIndexOp>(loc, colStride)) ;
872
858
SmallVector<Value> newConstOps;
859
+ Value mulOffset;
873
860
for (auto offsets : *sgOffsets) {
874
861
// Multiply offset with stride, broadcast it and add to baseConstVec
875
- Value mulOffset;
876
- if (wgShape.size () == 1 ) {
877
- // 1D: offset[0] * strideConst
878
- mulOffset = rewriter.create <arith::MulIOp>(
879
- loc, rewriter.getIndexType (), offsets[0 ], strideConst);
880
- } else if (wgShape.size () == 2 ) {
881
- // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
882
- Value rowMul = rewriter.create <arith::MulIOp>(
883
- loc, rewriter.getIndexType (), offsets[0 ], rowStrideConst);
884
- Value colMul = rewriter.create <arith::MulIOp>(
885
- loc, rewriter.getIndexType (), offsets[1 ], colStrideConst);
862
+ SmallVector<Value> muls;
863
+ for (size_t i = 0 ; i < strideConsts.size (); ++i) {
864
+ muls.push_back (rewriter.create <arith::MulIOp>(
865
+ loc, rewriter.getIndexType (), offsets[i], strideConsts[i]));
866
+ }
867
+ mulOffset = muls.front ();
868
+ if (muls.size () > 1 ) {
886
869
mulOffset = rewriter.create <arith::AddIOp>(
887
- loc, rewriter.getIndexType (), rowMul, colMul );
870
+ loc, rewriter.getIndexType (), mulOffset, muls[ 1 ] );
888
871
}
889
872
// Broadcast to baseConstVec size
890
873
auto bcastOffset = rewriter.create <vector::BroadcastOp>(
0 commit comments