@@ -825,35 +825,21 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
825825 }
826826 }
827827
828- // Determine the shape of the base tile for each subgroup.
829- SmallVector<int64_t > baseTileShape;
830- if (sgShape.size () == 1 ) {
831- baseTileShape.push_back (sgShape[0 ]);
832- } else if (sgShape.size () == 2 ) {
833- baseTileShape = sgShape;
834- } else {
835- return rewriter.notifyMatchFailure (
836- op, " Only 1D or 2D vector constant supported" );
837- }
838-
839828 // Create a constant for the base tile.
840829 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830+ // For 1D case, extract the first sgShape[0] elements.
841831 SmallVector<Attribute> baseTileValues;
842- if (baseTileShape.size () == 2 ) {
843- int64_t rows = baseTileShape[0 ], cols = baseTileShape[1 ];
844- int64_t wgCols = wgShape[1 ];
845- for (int64_t r = 0 ; r < rows; ++r) {
846- for (int64_t c = 0 ; c < cols; ++c) {
847- baseTileValues.push_back (values[r * wgCols + c]);
848- }
832+ int cols = sgShape[sgShape.size () - 1 ];
833+ int64_t wgCols = wgShape[sgShape.size () - 1 ];
834+ int64_t rows = sgShape.size () == 1 ? 1 : sgShape[0 ];
835+ for (int64_t r = 0 ; r < rows; ++r) {
836+ for (int64_t c = 0 ; c < cols; ++c) {
837+ baseTileValues.push_back (values[r * wgCols + c]);
849838 }
850- } else {
851- // 1D case
852- for (int64_t i = 0 ; i < computeProduct (baseTileShape); ++i)
853- baseTileValues.push_back (values[i]);
854839 }
855- auto tileAttr = DenseElementsAttr::get (
856- VectorType::get (baseTileShape, eltType), baseTileValues);
840+
841+ auto tileAttr = DenseElementsAttr::get (VectorType::get (sgShape, eltType),
842+ baseTileValues);
857843 auto baseConstVec = rewriter.create <arith::ConstantOp>(loc, tileAttr);
858844
859845 // Get subgroup id
@@ -864,27 +850,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
864850 if (failed (sgOffsets))
865851 return failure ();
866852
867- auto strideConst = rewriter. create <arith::ConstantIndexOp>(loc, stride) ;
868- auto rowStrideConst =
869- rewriter.create <arith::ConstantIndexOp>(loc, rowStride);
870- auto colStrideConst =
871- rewriter.create <arith::ConstantIndexOp>(loc, colStride);
853+ SmallVector<Value, 2 > strideConsts ;
854+ strideConsts. push_back (
855+ rewriter.create <arith::ConstantIndexOp>(loc, rowStride)) ;
856+ strideConsts. push_back (
857+ rewriter.create <arith::ConstantIndexOp>(loc, colStride)) ;
872858 SmallVector<Value> newConstOps;
859+ Value mulOffset;
873860 for (auto offsets : *sgOffsets) {
874861 // Multiply offset with stride, broadcast it and add to baseConstVec
875- Value mulOffset;
876- if (wgShape.size () == 1 ) {
877- // 1D: offset[0] * strideConst
878- mulOffset = rewriter.create <arith::MulIOp>(
879- loc, rewriter.getIndexType (), offsets[0 ], strideConst);
880- } else if (wgShape.size () == 2 ) {
881- // 2D: offset[0]*rowStrideConst + offset[1]*colStrideConst
882- Value rowMul = rewriter.create <arith::MulIOp>(
883- loc, rewriter.getIndexType (), offsets[0 ], rowStrideConst);
884- Value colMul = rewriter.create <arith::MulIOp>(
885- loc, rewriter.getIndexType (), offsets[1 ], colStrideConst);
862+ SmallVector<Value> muls;
863+ for (size_t i = 0 ; i < strideConsts.size (); ++i) {
864+ muls.push_back (rewriter.create <arith::MulIOp>(
865+ loc, rewriter.getIndexType (), offsets[i], strideConsts[i]));
866+ }
867+ mulOffset = muls.front ();
868+ if (muls.size () > 1 ) {
886869 mulOffset = rewriter.create <arith::AddIOp>(
887- loc, rewriter.getIndexType (), rowMul, colMul );
870+ loc, rewriter.getIndexType (), mulOffset, muls[ 1 ] );
888871 }
889872 // Broadcast to baseConstVec size
890873 auto bcastOffset = rewriter.create <vector::BroadcastOp>(
0 commit comments