@@ -773,54 +773,40 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
773773 op, " Only 1D & 2D vector constant supported" );
774774
775775 SmallVector<Attribute> values (vecAttr.getValues <Attribute>());
776- int64_t stride = 0 ;
777776 int64_t rowStride = 0 , colStride = 0 ;
778- if (wgShape.size () == 1 ) {
779- // 1D case: single stride
780- if (values.size () > 1 ) {
781- stride = cast<IntegerAttr>(values[1 ]).getInt () -
782- cast<IntegerAttr>(values[0 ]).getInt ();
783- for (size_t i = 2 ; i < values.size (); ++i) {
784- int64_t diff = cast<IntegerAttr>(values[i]).getInt () -
785- cast<IntegerAttr>(values[i - 1 ]).getInt ();
786- if (diff != stride)
777+ int64_t rows = wgShape.size () == 1 ? 1 : wgShape[0 ];
778+ int64_t cols = wgShape.size () == 1 ? wgShape[0 ] : wgShape[1 ];
779+
780+ // Compute colStride and rowStride, and check for constant strides.
781+ if (cols > 1 ) {
782+ colStride = cast<IntegerAttr>(values[1 ]).getInt () -
783+ cast<IntegerAttr>(values[0 ]).getInt ();
784+ }
785+ if (rows > 1 ) {
786+ rowStride = cast<IntegerAttr>(values[cols]).getInt () -
787+ cast<IntegerAttr>(values[0 ]).getInt ();
788+ }
789+
790+ for (int64_t r = 0 ; r < rows; ++r) {
791+ for (int64_t c = 0 ; c < cols; ++c) {
792+ int64_t idx = r * cols + c;
793+ // Check column stride (skip first column)
794+ if (c > 0 && cols > 1 ) {
795+ int64_t prevIdx = r * cols + (c - 1 );
796+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
797+ cast<IntegerAttr>(values[prevIdx]).getInt ();
798+ if (diff != colStride)
787799 return rewriter.notifyMatchFailure (
788- op, " Non-constant stride in non-splat constant op." );
800+ op, " Non-constant column stride in constant op." );
789801 }
790- }
791- } else if (wgShape.size () == 2 ) {
792- // 2D case: row stride and column stride
793- int64_t rows = wgShape[0 ], cols = wgShape[1 ];
794- // Compute col stride (stride between elements in a column)
795- if (cols > 1 ) {
796- colStride = cast<IntegerAttr>(values[1 ]).getInt () -
797- cast<IntegerAttr>(values[0 ]).getInt ();
798- for (int64_t r = 0 ; r < rows; ++r) {
799- for (int64_t c = 1 ; c < cols; ++c) {
800- int64_t idx = r * cols + c;
801- int64_t prevIdx = r * cols + (c - 1 );
802- int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
803- cast<IntegerAttr>(values[prevIdx]).getInt ();
804- if (diff != colStride)
805- return rewriter.notifyMatchFailure (
806- op, " Non-constant column stride in 2D constant op." );
807- }
808- }
809- }
810- // Compute row stride (stride between elements in a row)
811- if (rows > 1 ) {
812- rowStride = cast<IntegerAttr>(values[cols]).getInt () -
813- cast<IntegerAttr>(values[0 ]).getInt ();
814- for (int64_t c = 0 ; c < cols; ++c) {
815- for (int64_t r = 1 ; r < rows; ++r) {
816- int64_t idx = r * cols + c;
817- int64_t prevIdx = (r - 1 ) * cols + c;
818- int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
819- cast<IntegerAttr>(values[prevIdx]).getInt ();
820- if (diff != rowStride)
821- return rewriter.notifyMatchFailure (
822- op, " Non-constant row stride in 2D constant op." );
823- }
802+ // Check row stride (skip first row)
803+ if (r > 0 && rows > 1 ) {
804+ int64_t prevIdx = (r - 1 ) * cols + c;
805+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
806+ cast<IntegerAttr>(values[prevIdx]).getInt ();
807+ if (diff != rowStride)
808+ return rewriter.notifyMatchFailure (
809+ op, " Non-constant row stride in constant op." );
824810 }
825811 }
826812 }
@@ -829,12 +815,11 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
829815 // For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830816 // For 1D case, extract the first sgShape[0] elements.
831817 SmallVector<Attribute> baseTileValues;
832- int cols = sgShape[sgShape.size () - 1 ];
833- int64_t wgCols = wgShape[sgShape.size () - 1 ];
834- int64_t rows = sgShape.size () == 1 ? 1 : sgShape[0 ];
835- for (int64_t r = 0 ; r < rows; ++r) {
836- for (int64_t c = 0 ; c < cols; ++c) {
837- baseTileValues.push_back (values[r * wgCols + c]);
818+ int baseTileCols = sgShape[sgShape.size () - 1 ];
819+ int64_t baseTileRows = sgShape.size () == 1 ? 1 : sgShape[0 ];
820+ for (int64_t r = 0 ; r < baseTileRows; ++r) {
821+ for (int64_t c = 0 ; c < baseTileCols; ++c) {
822+ baseTileValues.push_back (values[r * cols + c]);
838823 }
839824 }
840825
0 commit comments