1+ // RUN: mlir-opt %s | FileCheck %s
2+
3+ // CHECK-LABEL: test_collapse(
4+ func.func @test_collapse (%arg0: memref <1 x?xf32 , strided <[5 , 1 ]>>) {
5+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 ]] : memref <1 x?xf32 , strided <[5 , 1 ]>> into memref <?xf32 , strided <[1 ]>>
6+ return
7+ }
8+
9+ // CHECK-LABEL: test_collapse_5d_middle_dynamic(
10+ func.func @test_collapse_5d_middle_dynamic (%arg0: memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>>) {
11+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
12+ : memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>> into memref <?xf32 , strided <[?]>>
13+ return
14+ }
15+
16+ // CHECK-LABEL: test_collapse_5d_mostly_units(
17+ func.func @test_collapse_5d_mostly_units (%arg0: memref <1 x1 x1 x?x1 xf32 , strided <[320 , 80 , 16 , 2 , 1 ]>>) {
18+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
19+ : memref <1 x1 x1 x?x1 xf32 , strided <[320 , 80 , 16 , 2 , 1 ]>> into memref <?xf32 , strided <[2 ]>>
20+ return
21+ }
22+
23+ // CHECK-LABEL: test_partial_collapse_6d(
24+ func.func @test_partial_collapse_6d (%arg0: memref <1 x?x1 x1 x5 x1 xf32 , strided <[3360 , 420 , 140 , 35 , 7 , 1 ]>>) {
25+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 ], [4 , 5 ]]
26+ : memref <1 x?x1 x1 x5 x1 xf32 , strided <[3360 , 420 , 140 , 35 , 7 , 1 ]>> into memref <?x5 xf32 , strided <[420 , 7 ]>>
27+ return
28+ }
29+
30+ // CHECK-LABEL: test_collapse_5d_grouped(
31+ func.func @test_collapse_5d_grouped (%arg0: memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>>) {
32+ %collapse_shape = memref.collapse_shape %arg0 [[0 ], [1 , 2 , 3 , 4 ]]
33+ : memref <1 x5 x1 x?x1 xf32 , strided <[540 , 108 , 18 , 2 , 1 ]>> into memref <1 x?xf32 , strided <[540 , ?]>>
34+ return
35+ }
36+
37+ // CHECK-LABEL: test_collapse_all_units(
38+ func.func @test_collapse_all_units (%arg0: memref <1 x1 x1 x1 x1 xf32 , strided <[100 , 50 , 25 , 10 , 1 ]>>) {
39+ %collapse_shape = memref.collapse_shape %arg0 [[0 , 1 , 2 , 3 , 4 ]]
40+ : memref <1 x1 x1 x1 x1 xf32 , strided <[100 , 50 , 25 , 10 , 1 ]>> into memref <1 xf32 , strided <[100 ]>>
41+ return
42+ }
0 commit comments