@@ -773,54 +773,40 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
773
773
op, " Only 1D & 2D vector constant supported" );
774
774
775
775
SmallVector<Attribute> values (vecAttr.getValues <Attribute>());
776
- int64_t stride = 0 ;
777
776
int64_t rowStride = 0 , colStride = 0 ;
778
- if (wgShape.size () == 1 ) {
779
- // 1D case: single stride
780
- if (values.size () > 1 ) {
781
- stride = cast<IntegerAttr>(values[1 ]).getInt () -
782
- cast<IntegerAttr>(values[0 ]).getInt ();
783
- for (size_t i = 2 ; i < values.size (); ++i) {
784
- int64_t diff = cast<IntegerAttr>(values[i]).getInt () -
785
- cast<IntegerAttr>(values[i - 1 ]).getInt ();
786
- if (diff != stride)
777
+ int64_t rows = wgShape.size () == 1 ? 1 : wgShape[0 ];
778
+ int64_t cols = wgShape.size () == 1 ? wgShape[0 ] : wgShape[1 ];
779
+
780
+ // Compute colStride and rowStride, and check for constant strides.
781
+ if (cols > 1 ) {
782
+ colStride = cast<IntegerAttr>(values[1 ]).getInt () -
783
+ cast<IntegerAttr>(values[0 ]).getInt ();
784
+ }
785
+ if (rows > 1 ) {
786
+ rowStride = cast<IntegerAttr>(values[cols]).getInt () -
787
+ cast<IntegerAttr>(values[0 ]).getInt ();
788
+ }
789
+
790
+ for (int64_t r = 0 ; r < rows; ++r) {
791
+ for (int64_t c = 0 ; c < cols; ++c) {
792
+ int64_t idx = r * cols + c;
793
+ // Check column stride (skip first column)
794
+ if (c > 0 && cols > 1 ) {
795
+ int64_t prevIdx = r * cols + (c - 1 );
796
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
797
+ cast<IntegerAttr>(values[prevIdx]).getInt ();
798
+ if (diff != colStride)
787
799
return rewriter.notifyMatchFailure (
788
- op, " Non-constant stride in non-splat constant op." );
800
+ op, " Non-constant column stride in constant op." );
789
801
}
790
- }
791
- } else if (wgShape.size () == 2 ) {
792
- // 2D case: row stride and column stride
793
- int64_t rows = wgShape[0 ], cols = wgShape[1 ];
794
- // Compute col stride (stride between elements in a column)
795
- if (cols > 1 ) {
796
- colStride = cast<IntegerAttr>(values[1 ]).getInt () -
797
- cast<IntegerAttr>(values[0 ]).getInt ();
798
- for (int64_t r = 0 ; r < rows; ++r) {
799
- for (int64_t c = 1 ; c < cols; ++c) {
800
- int64_t idx = r * cols + c;
801
- int64_t prevIdx = r * cols + (c - 1 );
802
- int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
803
- cast<IntegerAttr>(values[prevIdx]).getInt ();
804
- if (diff != colStride)
805
- return rewriter.notifyMatchFailure (
806
- op, " Non-constant column stride in 2D constant op." );
807
- }
808
- }
809
- }
810
- // Compute row stride (stride between elements in a row)
811
- if (rows > 1 ) {
812
- rowStride = cast<IntegerAttr>(values[cols]).getInt () -
813
- cast<IntegerAttr>(values[0 ]).getInt ();
814
- for (int64_t c = 0 ; c < cols; ++c) {
815
- for (int64_t r = 1 ; r < rows; ++r) {
816
- int64_t idx = r * cols + c;
817
- int64_t prevIdx = (r - 1 ) * cols + c;
818
- int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
819
- cast<IntegerAttr>(values[prevIdx]).getInt ();
820
- if (diff != rowStride)
821
- return rewriter.notifyMatchFailure (
822
- op, " Non-constant row stride in 2D constant op." );
823
- }
802
+ // Check row stride (skip first row)
803
+ if (r > 0 && rows > 1 ) {
804
+ int64_t prevIdx = (r - 1 ) * cols + c;
805
+ int64_t diff = cast<IntegerAttr>(values[idx]).getInt () -
806
+ cast<IntegerAttr>(values[prevIdx]).getInt ();
807
+ if (diff != rowStride)
808
+ return rewriter.notifyMatchFailure (
809
+ op, " Non-constant row stride in constant op." );
824
810
}
825
811
}
826
812
}
@@ -829,12 +815,11 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
829
815
// For 2D case, extract the top-left sgShape[0] x sgShape[1] submatrix.
830
816
// For 1D case, extract the first sgShape[0] elements.
831
817
SmallVector<Attribute> baseTileValues;
832
- int cols = sgShape[sgShape.size () - 1 ];
833
- int64_t wgCols = wgShape[sgShape.size () - 1 ];
834
- int64_t rows = sgShape.size () == 1 ? 1 : sgShape[0 ];
835
- for (int64_t r = 0 ; r < rows; ++r) {
836
- for (int64_t c = 0 ; c < cols; ++c) {
837
- baseTileValues.push_back (values[r * wgCols + c]);
818
+ int baseTileCols = sgShape[sgShape.size () - 1 ];
819
+ int64_t baseTileRows = sgShape.size () == 1 ? 1 : sgShape[0 ];
820
+ for (int64_t r = 0 ; r < baseTileRows; ++r) {
821
+ for (int64_t c = 0 ; c < baseTileCols; ++c) {
822
+ baseTileValues.push_back (values[r * cols + c]);
838
823
}
839
824
}
840
825
0 commit comments