1- // RUN: mlir-opt %s --test-vector-scan-lowering | FileCheck %s
1+ // RUN: mlir-opt %s -split-input-file - -test-vector-scan-lowering | FileCheck %s
22
33// CHECK-LABEL: func @scan1d_inc
44// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
@@ -18,6 +18,20 @@ func.func @scan1d_inc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi
1818 return %0#0 , %0#1 : vector <2 xi32 >, vector <i32 >
1919}
2020
21+ // -----
22+
23+ // Reducing scalable dims is not yet supported!
24+
25+ // CHECK-LABEL: func @scan1d_inc_scalable
26+ // CHECK: vector.scan
27+ func.func @scan1d_inc_scalable (%arg0 : vector <[2 ]xi32 >, %arg1 : vector <i32 >) -> (vector <[2 ]xi32 >, vector <i32 >) {
28+ %0:2 = vector.scan <add >, %arg0 , %arg1 {inclusive = true , reduction_dim = 0 } :
29+ vector <[2 ]xi32 >, vector <i32 >
30+ return %0#0 , %0#1 : vector <[2 ]xi32 >, vector <i32 >
31+ }
32+
33+ // -----
34+
2135// CHECK-LABEL: func @scan1d_exc
2236// CHECK-SAME: %[[ARG0:.*]]: vector<2xi32>,
2337// CHECK-SAME: %[[ARG1:.*]]: vector<i32>
@@ -36,6 +50,20 @@ func.func @scan1d_exc(%arg0 : vector<2xi32>, %arg1 : vector<i32>) -> (vector<2xi
3650 return %0#0 , %0#1 : vector <2 xi32 >, vector <i32 >
3751}
3852
53+ // -----
54+
55+ // Rducing scalable dims is not yet supported!
56+
57+ // CHECK-LABEL: func @scan1d_exc_scalable
58+ // CHECK: vector.scan
59+ func.func @scan1d_exc_scalable (%arg0 : vector <[2 ]xi32 >, %arg1 : vector <i32 >) -> (vector <[2 ]xi32 >, vector <i32 >) {
60+ %0:2 = vector.scan <add >, %arg0 , %arg1 {inclusive = false , reduction_dim = 0 } :
61+ vector <[2 ]xi32 >, vector <i32 >
62+ return %0#0 , %0#1 : vector <[2 ]xi32 >, vector <i32 >
63+ }
64+
65+ // -----
66+
3967// CHECK-LABEL: func @scan2d_mul_dim0
4068// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
4169// CHECK-SAME: %[[ARG1:.*]]: vector<3xi32>
@@ -53,6 +81,27 @@ func.func @scan2d_mul_dim0(%arg0 : vector<2x3xi32>, %arg1 : vector<3xi32>) -> (v
5381 return %0#0 , %0#1 : vector <2 x3 xi32 >, vector <3 xi32 >
5482}
5583
84+ // -----
85+
86+ // CHECK-LABEL: func @scan2d_mul_dim0_scalable
87+ // CHECK-SAME: %[[ARG0:.*]]: vector<2x[3]xi32>,
88+ // CHECK-SAME: %[[ARG1:.*]]: vector<[3]xi32>
89+ // CHECK: %[[A:.*]] = arith.constant dense<0> : vector<2x[3]xi32>
90+ // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32>
91+ // CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32>
92+ // CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [1, 0], sizes = [1, 3], strides = [1, 1]} : vector<2x[3]xi32> to vector<1x[3]xi32>
93+ // CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<1x[3]xi32>
94+ // CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [1, 0], strides = [1, 1]} : vector<1x[3]xi32> into vector<2x[3]xi32>
95+ // CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<1x[3]xi32> to vector<[3]xi32>
96+ // CHECK: return %[[F]], %[[G]] : vector<2x[3]xi32>, vector<[3]xi32>
97+ func.func @scan2d_mul_dim0_scalable (%arg0 : vector <2 x[3 ]xi32 >, %arg1 : vector <[3 ]xi32 >) -> (vector <2 x[3 ]xi32 >, vector <[3 ]xi32 >) {
98+ %0:2 = vector.scan <mul >, %arg0 , %arg1 {inclusive = true , reduction_dim = 0 } :
99+ vector <2 x[3 ]xi32 >, vector <[3 ]xi32 >
100+ return %0#0 , %0#1 : vector <2 x[3 ]xi32 >, vector <[3 ]xi32 >
101+ }
102+
103+ // -----
104+
56105// CHECK-LABEL: func @scan2d_mul_dim1
57106// CHECK-SAME: %[[ARG0:.*]]: vector<2x3xi32>,
58107// CHECK-SAME: %[[ARG1:.*]]: vector<2xi32>
@@ -73,6 +122,30 @@ func.func @scan2d_mul_dim1(%arg0 : vector<2x3xi32>, %arg1 : vector<2xi32>) -> (v
73122 return %0#0 , %0#1 : vector <2 x3 xi32 >, vector <2 xi32 >
74123}
75124
125+ // -----
126+
127+ // CHECK-LABEL: func @scan2d_mul_dim1_scalable
128+ // CHECK-SAME: %[[ARG0:.*]]: vector<[2]x3xi32>,
129+ // CHECK-SAME: %[[ARG1:.*]]: vector<[2]xi32>
130+ // CHECK: %[[A:.*]] = arith.constant dense<0> : vector<[2]x3xi32>
131+ // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
132+ // CHECK: %[[C:.*]] = vector.insert_strided_slice %[[B]], %[[A]] {offsets = [0, 0], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
133+ // CHECK: %[[D:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 1], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
134+ // CHECK: %[[E:.*]] = arith.muli %[[B]], %[[D]] : vector<[2]x1xi32>
135+ // CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[C]] {offsets = [0, 1], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
136+ // CHECK: %[[G:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 2], sizes = [2, 1], strides = [1, 1]} : vector<[2]x3xi32> to vector<[2]x1xi32>
137+ // CHECK: %[[H:.*]] = arith.muli %[[E]], %[[G]] : vector<[2]x1xi32>
138+ // CHECK: %[[I:.*]] = vector.insert_strided_slice %[[H]], %[[F]] {offsets = [0, 2], strides = [1, 1]} : vector<[2]x1xi32> into vector<[2]x3xi32>
139+ // CHECK: %[[J:.*]] = vector.shape_cast %[[H]] : vector<[2]x1xi32> to vector<[2]xi32>
140+ // CHECK: return %[[I]], %[[J]] : vector<[2]x3xi32>, vector<[2]xi32>
141+ func.func @scan2d_mul_dim1_scalable (%arg0 : vector <[2 ]x3 xi32 >, %arg1 : vector <[2 ]xi32 >) -> (vector <[2 ]x3 xi32 >, vector <[2 ]xi32 >) {
142+ %0:2 = vector.scan <mul >, %arg0 , %arg1 {inclusive = true , reduction_dim = 1 } :
143+ vector <[2 ]x3 xi32 >, vector <[2 ]xi32 >
144+ return %0#0 , %0#1 : vector <[2 ]x3 xi32 >, vector <[2 ]xi32 >
145+ }
146+
147+ // -----
148+
76149// CHECK-LABEL: func @scan3d_mul_dim1
77150// CHECK-SAME: %[[ARG0:.*]]: vector<4x2x3xf32>,
78151// CHECK-SAME: %[[ARG1:.*]]: vector<4x3xf32>
@@ -89,3 +162,22 @@ func.func @scan3d_mul_dim1(%arg0 : vector<4x2x3xf32>, %arg1 : vector<4x3xf32>) -
89162 vector <4 x2 x3 xf32 >, vector <4 x3 xf32 >
90163 return %0#0 , %0#1 : vector <4 x2 x3 xf32 >, vector <4 x3 xf32 >
91164}
165+
166+ // -----
167+
168+ // CHECK-LABEL: func @scan3d_mul_dim1_scalable
169+ // CHECK-SAME: %[[ARG0:.*]]: vector<4x2x[3]xf32>,
170+ // CHECK-SAME: %[[ARG1:.*]]: vector<4x[3]xf32>
171+ // CHECK: %[[A:.*]] = arith.constant dense<0.000000e+00> : vector<4x2x[3]xf32>
172+ // CHECK: %[[B:.*]] = vector.extract_strided_slice %[[ARG0]] {offsets = [0, 0, 0], sizes = [4, 1, 3], strides = [1, 1, 1]} : vector<4x2x[3]xf32> to vector<4x1x[3]xf32>
173+ // CHECK: %[[C:.*]] = vector.shape_cast %[[ARG1]] : vector<4x[3]xf32> to vector<4x1x[3]xf32>
174+ // CHECK: %[[D:.*]] = vector.insert_strided_slice %[[C]], %[[A]] {offsets = [0, 0, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32>
175+ // CHECK: %[[E:.*]] = arith.mulf %[[C]], %[[B]] : vector<4x1x[3]xf32>
176+ // CHECK: %[[F:.*]] = vector.insert_strided_slice %[[E]], %[[D]] {offsets = [0, 1, 0], strides = [1, 1, 1]} : vector<4x1x[3]xf32> into vector<4x2x[3]xf32>
177+ // CHECK: %[[G:.*]] = vector.shape_cast %[[E]] : vector<4x1x[3]xf32> to vector<4x[3]xf32>
178+ // CHECK: return %[[F]], %[[G]] : vector<4x2x[3]xf32>, vector<4x[3]xf32>
179+ func.func @scan3d_mul_dim1_scalable (%arg0 : vector <4 x2 x[3 ]xf32 >, %arg1 : vector <4 x[3 ]xf32 >) -> (vector <4 x2 x[3 ]xf32 >, vector <4 x[3 ]xf32 >) {
180+ %0:2 = vector.scan <mul >, %arg0 , %arg1 {inclusive = false , reduction_dim = 1 } :
181+ vector <4 x2 x[3 ]xf32 >, vector <4 x[3 ]xf32 >
182+ return %0#0 , %0#1 : vector <4 x2 x[3 ]xf32 >, vector <4 x[3 ]xf32 >
183+ }
0 commit comments