@@ -39,13 +39,15 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
3939// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
4040// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
4141// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
42+ // CHECK-SAME: kind = #vector.kind<add>
4243// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
4344// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
4445
4546/// w == 1, kw == 0
4647// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
4748// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
4849// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
50+ // CHECK-SAME: kind = #vector.kind<add>
4951// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
5052// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
5153
@@ -61,6 +63,36 @@ func.func @conv1d_nwc_4x2x8_memref(%input: memref<4x6x3xf32>, %filter: memref<1x
6163
6264// -----
6365
66+ // This test is same as above but for i1 type with the only difference being that
67+ // the combining kind for `vector.contract` is `OR`.
68+ func.func @conv1d_nwc_4x2x8_memref_i1 (%input: memref <4 x6 x3 xi1 >, %filter: memref <1 x3 x8 xi1 >, %output: memref <4 x2 x8 xi1 >) {
69+ linalg.conv_1d_nwc_wcf
70+ {dilations = dense <1 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
71+ ins (%input , %filter : memref <4 x6 x3 xi1 >, memref <1 x3 x8 xi1 >)
72+ outs (%output : memref <4 x2 x8 xi1 >)
73+ return
74+ }
75+ // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
76+ // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
77+ // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
78+
79+ // CHECK: func @conv1d_nwc_4x2x8_memref_i1
80+ /// w == 0, kw == 0
81+ // CHECK: %[[CONTRACT_0:.+]] = vector.contract {
82+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
83+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
84+ // CHECK-SAME: kind = #vector.kind<or>
85+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
86+
87+ /// w == 1, kw == 0
88+ // CHECK: %[[CONTRACT_1:.+]] = vector.contract {
89+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
90+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
91+ // CHECK-SAME: kind = #vector.kind<or>
92+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
93+
94+ // -----
95+
6496// The i8i8i32 case is similar to f32 case, so checking one case is enough for
6597// test coverage.
6698func.func @conv1d_nwc_4x2x8_i8i8i32_memref (%input: memref <4 x6 x3 xi8 >, %filter: memref <1 x3 x8 xi8 >, %output: memref <4 x2 x8 xi32 >) {
@@ -299,13 +331,15 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
299331// CHECK: %[[CONTRACT_0:.+]] = vector.contract {
300332// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
301333// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
334+ // CHECK-SAME: kind = #vector.kind<add>
302335// CHECK-SAME: %[[V_INPUT_0]], %[[V_FILTER]], %[[V_OUTPUT_0]]
303336// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
304337
305338/// w == 1, kw == 0
306339// CHECK: %[[CONTRACT_1:.+]] = vector.contract {
307340// CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
308341// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
342+ // CHECK-SAME: kind = #vector.kind<add>
309343// CHECK-SAME: %[[V_INPUT_1]], %[[V_FILTER]], %[[V_OUTPUT_1]]
310344// CHECK-SAME: : vector<4x1x3xf32>, vector<3x8xf32> into vector<4x1x8xf32>
311345
@@ -324,6 +358,37 @@ func.func @conv1d_ncw_4x8x2_memref(%input: memref<4x3x6xf32>, %filter: memref<8x
324358
325359// -----
326360
361+ // This test is same as above but for i1 type with the only difference being that
362+ // the combining kind for `vector.contract` is `OR`.
363+ func.func @conv1d_ncw_4x8x2_memref_i1 (%input: memref <4 x3 x6 xi1 >, %filter: memref <8 x3 x1 xi1 >, %output: memref <4 x8 x2 xi1 >) {
364+ linalg.conv_1d_ncw_fcw
365+ {dilations = dense <1 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
366+ ins (%input , %filter : memref <4 x3 x6 xi1 >, memref <8 x3 x1 xi1 >)
367+ outs (%output : memref <4 x8 x2 xi1 >)
368+ return
369+ }
370+
371+ // CHECK: #[[INPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
372+ // CHECK: #[[FILTER_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
373+ // CHECK: #[[OUTPUT_MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
374+
375+ // CHECK: func @conv1d_ncw_4x8x2_memref_i1
376+ /// w == 0, kw == 0
377+ // CHECK: vector.contract {
378+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
379+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
380+ // CHECK-SAME: kind = #vector.kind<or>
381+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
382+
383+ /// w == 1, kw == 0
384+ // CHECK: vector.contract {
385+ // CHECK-SAME: indexing_maps = [#[[INPUT_MAP]], #[[FILTER_MAP]], #[[OUTPUT_MAP]]],
386+ // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "reduction"]
387+ // CHECK-SAME: kind = #vector.kind<or>
388+ // CHECK-SAME: : vector<4x1x3xi1>, vector<3x8xi1> into vector<4x1x8xi1>
389+
390+ // -----
391+
327392func.func @conv1d_ncw_4x8x2_memref (%input: memref <4 x3 x6 xf32 >, %filter: memref <8 x3 x2 xf32 >, %output: memref <4 x8 x2 xf32 >) {
328393 linalg.conv_1d_ncw_fcw
329394 {dilations = dense <2 > : tensor <1 xi64 >, strides = dense <3 > : tensor <1 xi64 >}
0 commit comments