Skip to content

Commit f155fe1

Browse files
committed
fixing ndarray tests
1 parent a0630b9 commit f155fe1

File tree

5 files changed

+131
-151
lines changed

5 files changed

+131
-151
lines changed

include/imex/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
#include "mlir/Pass/Pass.h"
1919

20+
#include <imex/Conversion/ArithToVC/ArithToVC.h>
2021
#include <imex/Conversion/DropRegions/DropRegions.h>
2122
#include <imex/Conversion/GPUToGPUX/GPUToGPUX.h>
2223
#include <imex/Conversion/GPUToSPIRV/GPUToSPIRVPass.h>

test/Conversion/NDArrayToLinalg/NDArrayToLinalg.mlir

Lines changed: 29 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,5 @@
11
// RUN: imex-opt --split-input-file --convert-ndarray-to-linalg %s -verify-diagnostics -o -| FileCheck %s
22

3-
// -----
4-
func.func @test_subview(%arg0: tensor<?xi64>) -> tensor<?xi64> {
5-
%c0 = arith.constant 0 : index
6-
%c3 = arith.constant 3 : index
7-
%0 = ndarray.subview %arg0[%c0][%c3][%c3] : tensor<?xi64> to tensor<?xi64>
8-
return %0 : tensor<?xi64>
9-
}
10-
// CHECK-LABEL: @test_subview
11-
// CHECK-SAME: ([[V:%.*]]: tensor<?xi64>) -> tensor<?xi64> {
12-
// CHECK-NEXT: [[C0:%.*]] = arith.constant
13-
// CHECK-NEXT: [[C1:%.*]] = arith.constant
14-
// CHECK-NEXT: [[V0:%.*]] = bufferization.to_memref [[V]] : tensor<?xi64> to memref<?xi64, strided<[?], offset: ?>>
15-
// CHECK-NEXT: [[S0:%.*]] = memref.subview [[V0]][[[C0]]] [[[C1]]] [[[C1]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
16-
// CHECK-NEXT: [[V1:%.*]] = bufferization.to_tensor [[S0]] restrict writable : memref<?xi64, strided<[?], offset: ?>>
17-
// CHECK-NEXT: return [[V1]] : tensor<?xi64>
18-
193
// -----
204
func.func @test_linspace(%arg0: i64, %arg1: i64, %arg2: index) -> tensor<?xindex> {
215
%0 = ndarray.linspace %arg0 %arg1 %arg2 false : (i64, i64, index) -> tensor<?xindex>
@@ -72,42 +56,6 @@ func.func @test_reshape2(%arg0: index) -> tensor<?x?xi64> {
7256
// CHECK: tensor.reshape
7357
// CHECK-SAME: -> tensor<?x?xi64>
7458

75-
// -----
76-
func.func @test_insert_slice(%arg0: tensor<?xi64>, %arg1: tensor<?xi64>) {
77-
%i0 = arith.constant 0 : index
78-
%i1 = arith.constant 1 : index
79-
%i3 = arith.constant 3 : index
80-
ndarray.insert_slice %arg1 into %arg0[%i0] [%i3] [%i1] : tensor<?xi64> into tensor<?xi64>
81-
return
82-
}
83-
// CHECK-LABEL: @test_insert_slice
84-
// CHECK-SAME: ([[V:%.*]]: tensor<?xi64>, [[VV:%.*]]: tensor<?xi64>) {
85-
// CHECK-NEXT: [[C0:%.*]] = arith.constant
86-
// CHECK-NEXT: [[C1:%.*]] = arith.constant
87-
// CHECK-NEXT: [[C3:%.*]] = arith.constant
88-
// CHECK-NEXT: [[V0:%.*]] = bufferization.to_memref [[VV]]
89-
// CHECK-NEXT: [[V1:%.*]] = bufferization.to_memref [[V]]
90-
// CHECK-NEXT: [[SV:%.*]] = memref.subview [[V1]][[[C0]]] [[[C3]]] [[[C1]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
91-
// CHECK: memref.copy [[V0]], [[SV]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
92-
93-
// -----
94-
func.func @test_insert_slice_scalar(%arg0: tensor<?xi64>, %arg1: tensor<i64>) {
95-
%i0 = arith.constant 0 : index
96-
%i1 = arith.constant 1 : index
97-
%i3 = arith.constant 3 : index
98-
ndarray.insert_slice %arg1 into %arg0[%i0] [%i3] [%i1] : tensor<i64> into tensor<?xi64>
99-
return
100-
}
101-
// CHECK-LABEL: @test_insert_slice_scalar
102-
// CHECK-SAME: ([[V:%.*]]: tensor<?xi64>, [[VV:%.*]]: tensor<i64>) {
103-
// CHECK-NEXT: [[C0:%.*]] = arith.constant
104-
// CHECK-NEXT: [[C1:%.*]] = arith.constant
105-
// CHECK-NEXT: [[C3:%.*]] = arith.constant
106-
// CHECK-NEXT: [[V0:%.*]] = bufferization.to_memref [[VV]]
107-
// CHECK-NEXT: [[V1:%.*]] = bufferization.to_memref [[V]]
108-
// CHECK-NEXT: [[SV:%.*]] = memref.subview [[V1]][[[C0]]] [[[C3]]] [[[C1]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
109-
// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins([[V0]] : memref<i64, strided<[], offset: ?>>) outs([[SV]] : memref<?xi64, strided<[?], offset: ?>>)
110-
11159
// -----
11260
#GPUENV = #ndarray.envs<#region.gpu_env<device = "g">>
11361
func.func @test_env() -> (tensor<16x16xf32, #GPUENV>, tensor<256xf32, #GPUENV>) attributes {llvm.emit_c_interface} {
@@ -170,41 +118,36 @@ func.func @test_env() -> (tensor<16x16xf32, #GPUENV>, tensor<256xf32, #GPUENV>)
170118
// COM: CHECK-SAME: memref<?xi64, strided<[?], offset: ?>>
171119

172120
// -----
173-
func.func @test_copy(%a: !ndarray.ndarray<?xi64>) -> !ndarray.ndarray<?xi64> {
174-
%0 = ndarray.copy %a: !ndarray.ndarray<?xi64> -> !ndarray.ndarray<?xi64>
175-
%1 = ndarray.copy %0: !ndarray.ndarray<?xi64> -> !ndarray.ndarray<?xi64, #region.gpu_env<device = "XeGPU">>
176-
%2 = ndarray.copy %1: !ndarray.ndarray<?xi64, #region.gpu_env<device = "XeGPU">> -> !ndarray.ndarray<?xi64>
177-
return %0 : !ndarray.ndarray<?xi64>
121+
func.func @test_copy(%a: tensor<?xi64>) -> tensor<?xi64> {
122+
%0 = ndarray.copy %a: tensor<?xi64> -> tensor<?xi64>
123+
%1 = ndarray.copy %0: tensor<?xi64> -> tensor<?xi64, #region.gpu_env<device = "XeGPU">>
124+
%2 = ndarray.copy %1: tensor<?xi64, #region.gpu_env<device = "XeGPU">> -> tensor<?xi64>
125+
return %0 : tensor<?xi64>
178126
}
179-
// CHECK-LABEL: func.func @test_copy
180-
// CHECK-NEXT: bufferization.to_tensor
181-
// CHECK-NEXT: arith.constant 0 : index
182-
// CHECK-NEXT: tensor.dim
183-
// CHECK-NEXT: memref.alloc
184-
// CHECK-NEXT: bufferization.to_memref
185-
// CHECK-NEXT: region.env_region "protect_copy_op"
186-
// CHECK-NEXT: memref.copy
187-
// CHECK-NEXT: }
188-
// CHECK-NEXT: bufferization.to_tensor
189-
// CHECK-NEXT: bufferization.to_memref
190-
// CHECK-NEXT: arith.constant 0 : index
191-
// CHECK-NEXT: tensor.dim
192-
// CHECK-NEXT: memref.alloc
193-
// CHECK-NEXT: bufferization.to_memref
194-
// CHECK-NEXT: region.env_region "gpu_copy_op"
195-
// CHECK-NEXT: memref.copy
196-
// CHECK-NEXT: }
197-
// CHECK-NEXT: bufferization.to_tensor
198-
// CHECK-NEXT: arith.constant 0 : index
199-
// CHECK-NEXT: tensor.dim
200-
// CHECK-NEXT: memref.alloc
201-
// CHECK-NEXT: bufferization.to_memref
202-
// CHECK-NEXT: region.env_region "gpu_copy_op"
203-
// CHECK-NEXT: memref.copy
204-
// CHECK-NEXT: }
205-
// CHECK-NEXT: bufferization.to_tensor
206-
// CHECK-NEXT: return
207-
// CHECK-SAME: memref<?xi64, strided<[?], offset: ?>>
127+
// CHECK-LABEL: func.func @test_copy(
128+
// CHECK-SAME: [[varg0:%.*]]: tensor<?xi64>) -> tensor<?xi64> {
129+
// CHECK-NEXT: [[vc0:%.*]] = arith.constant 0 : index
130+
// CHECK-NEXT: [[vdim:%.*]] = tensor.dim [[varg0]], [[vc0]] : tensor<?xi64>
131+
// CHECK-NEXT: [[valloc:%.*]] = memref.alloc([[vdim]]) {alignment = 8 : i64} : memref<?xi64>
132+
// CHECK-NEXT: [[v0:%.*]] = bufferization.to_memref [[varg0]] : tensor<?xi64> to memref<?xi64, strided<[?], offset: ?>>
133+
// CHECK-NEXT: region.env_region "protect_copy_op" {
134+
// CHECK-NEXT: memref.copy [[v0]], [[valloc]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64>
135+
// CHECK: [[v1:%.*]] = bufferization.to_tensor [[valloc]] restrict writable : memref<?xi64> to tensor<?xi64>
136+
// CHECK-NEXT: [[vc0_0:%.*]] = arith.constant 0 : index
137+
// CHECK-NEXT: [[vdim_1:%.*]] = tensor.dim [[v1]], [[vc0_0]] : tensor<?xi64>
138+
// CHECK-NEXT: [[valloc_2:%.*]] = memref.alloc([[vdim_1]]) {alignment = 8 : i64} : memref<?xi64>
139+
// CHECK-NEXT: [[v2:%.*]] = bufferization.to_memref [[v1]] : tensor<?xi64> to memref<?xi64, strided<[?], offset: ?>>
140+
// CHECK-NEXT: region.env_region "protect_copy_op" {
141+
// CHECK-NEXT: memref.copy [[v2]], [[valloc_2]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64>
142+
// CHECK: [[v3:%.*]] = bufferization.to_tensor [[valloc_2]] restrict writable : memref<?xi64> to tensor<?xi64, #region.gpu_env<device = "XeGPU">>
143+
// CHECK-NEXT: [[vc0_3:%.*]] = arith.constant 0 : index
144+
// CHECK-NEXT: [[vdim_4:%.*]] = tensor.dim [[v3]], [[vc0_3]] : tensor<?xi64, #region.gpu_env<device = "XeGPU">>
145+
// CHECK-NEXT: [[valloc_5:%.*]] = memref.alloc([[vdim_4]]) {alignment = 8 : i64} : memref<?xi64>
146+
// CHECK-NEXT: [[v4:%.*]] = bufferization.to_memref [[v3]] : tensor<?xi64, #region.gpu_env<device = "XeGPU">> to memref<?xi64, strided<[?], offset: ?>>
147+
// CHECK-NEXT: region.env_region "protect_copy_op" {
148+
// CHECK-NEXT: memref.copy [[v4]], [[valloc_5]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64>
149+
// CHECK: [[v5:%.*]] = bufferization.to_tensor [[valloc_5]] restrict writable : memref<?xi64> to tensor<?xi64>
150+
// CHECK-NEXT: return [[v1]] : tensor<?xi64>
208151

209152
// -----
210153
func.func @test_delete(%arg0: tensor<?xi64>) {

test/Dialect/NDArray/Extensions/mesh-spmdization.mlir

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -49,44 +49,47 @@ func.func @test_cast_elemtypeop(%arg0: tensor<1024x1024xi64>) -> tensor<1024x102
4949
func.func @test_linspace() -> tensor<?xi64> {
5050
%c0 = arith.constant 0 : i64
5151
%c10 = arith.constant 10 : i64
52-
// CHECK: [[vcst_0:%.*]] = arith.constant 4.000000e+00 : f64
53-
// CHECK-NEXT: [[vcst:%.*]] = arith.constant 3.000000e+00 : f64
54-
// CHECK-NEXT: [[vcst_1:%.*]] = arith.constant 7.000000e+00 : f64
55-
// CHECK-NEXT: [[v0:%.*]] = ndarray.linspace [[vcst_0]] [[vcst_1]] [[vcst]] false : (f64, f64, f64) -> tensor<?xi64>
52+
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant 4.000000e+00 : f64
53+
// CHECK-DAG: [[vcst:%.*]] = arith.constant 3 : index
54+
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant 7.000000e+00 : f64
55+
// CHECK: [[v0:%.*]] = ndarray.linspace [[vcst_0]] [[vcst_1]] [[vcst]] false : (f64, f64, index) -> tensor<3xi64>
5656
%0 = ndarray.linspace %c0 %c10 %c10 false : (i64, i64, i64) -> tensor<?xi64>
5757
%s = mesh.sharding @mesh4 split_axes = [[0]] : !mesh.sharding
5858
%1 = mesh.shard %0 to %s : tensor<?xi64>
59-
// CHECK-NEXT: return [[v0]] : tensor<?xi64>
59+
// CHECK: [[cast:%.*]] = tensor.cast [[v0]] : tensor<3xi64> to tensor<?xi64>
60+
// CHECK-NEXT: return [[cast]] : tensor<?xi64>
6061
return %1 : tensor<?xi64>
6162
}
6263

6364
// CHECK-LABEL: @test_linspace_halos
6465
func.func @test_linspace_halos() -> tensor<?xi64> {
6566
%c0 = arith.constant 0 : i64
6667
%c10 = arith.constant 10 : i64
67-
// CHECK: [[vcst:%.*]] = arith.constant 3.000000e+00 : f64
68-
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 7.000000e+00 : f64
69-
// CHECK-NEXT: [[vcst_1:%.*]] = arith.constant 1.000000e+01 : f64
70-
// CHECK-NEXT: [[v0:%.*]] = ndarray.linspace [[vcst]] [[vcst_1]] [[vcst_0]] false : (f64, f64, f64) -> tensor<?xi64>
68+
// CHECK-DAG: [[vcst:%.*]] = arith.constant 3.000000e+00 : f64
69+
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant 7 : index
70+
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant 1.000000e+01 : f64
71+
// CHECK: [[v0:%.*]] = ndarray.linspace [[vcst]] [[vcst_1]] [[vcst_0]] false : (f64, f64, index) -> tensor<7xi64>
7172
%0 = ndarray.linspace %c0 %c10 %c10 false : (i64, i64, i64) -> tensor<?xi64>
7273
%s = mesh.sharding @mesh4 split_axes = [[0]] halo_sizes = [1, 3]: !mesh.sharding
7374
%1 = mesh.shard %0 to %s : tensor<?xi64>
74-
// CHECK-NEXT: return [[v0]] : tensor<?xi64>
75+
// CHECK: [[cast:%.*]] = tensor.cast [[v0]] : tensor<7xi64> to tensor<?xi64>
76+
// CHECK-NEXT: return [[cast]] : tensor<?xi64>
7577
return %1 : tensor<?xi64>
7678
}
7779

7880
// CHECK-LABEL: @test_linspace_offsets
7981
func.func @test_linspace_offsets() -> tensor<?xi64> {
8082
%c0 = arith.constant 0 : i64
8183
%c10 = arith.constant 10 : i64
82-
// CHECK: [[vcst:%.*]] = arith.constant 1.000000e+00 : f64
83-
// CHECK-NEXT: [[vcst_0:%.*]] = arith.constant 5.000000e+00 : f64
84-
// CHECK-NEXT: [[vcst_1:%.*]] = arith.constant 6.000000e+00 : f64
85-
// CHECK-NEXT: [[v0:%.*]] = ndarray.linspace [[vcst_0]] [[vcst_1]] [[vcst]] false : (f64, f64, f64) -> tensor<?xi64>
84+
// CHECK-DAG: [[vcst:%.*]] = arith.constant 1 : index
85+
// CHECK-DAG: [[vcst_0:%.*]] = arith.constant 5.000000e+00 : f64
86+
// CHECK-DAG: [[vcst_1:%.*]] = arith.constant 6.000000e+00 : f64
87+
// CHECK-NEXT: [[v0:%.*]] = ndarray.linspace [[vcst_0]] [[vcst_1]] [[vcst]] false : (f64, f64, index) -> tensor<1xi64>
8688
%0 = ndarray.linspace %c0 %c10 %c10 false : (i64, i64, i64) -> tensor<?xi64>
8789
%s = mesh.sharding @mesh4 split_axes = [[0]] sharded_dims_offsets = [0, 0, 5, 6, 10]: !mesh.sharding
8890
%1 = mesh.shard %0 to %s : tensor<?xi64>
89-
// CHECK-NEXT: return [[v0]] : tensor<?xi64>
91+
// CHECK: [[cast:%.*]] = tensor.cast [[v0]] : tensor<1xi64> to tensor<?xi64>
92+
// CHECK-NEXT: return [[cast]] : tensor<?xi64>
9093
return %1 : tensor<?xi64>
9194
}
9295

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// RUN: imex-opt --split-input-file --one-shot-bufferize="bufferize-function-boundaries=1" %s -verify-diagnostics -o -| FileCheck %s
2+
3+
// -----
4+
func.func @test_subview(%arg0: tensor<?xi64>) -> tensor<?xi64> {
5+
%c0 = arith.constant 0 : index
6+
%c3 = arith.constant 3 : index
7+
%0 = ndarray.subview %arg0[%c0][%c3][%c3] : tensor<?xi64> to tensor<?xi64>
8+
return %0 : tensor<?xi64>
9+
}
10+
// CHECK-LABEL: func.func @test_subview(
11+
// CHECK-SAME: [[varg0:%.*]]: memref<?xi64, strided<[?], offset: ?>>) -> memref<?xi64, strided<[?], offset: ?>> {
12+
// CHECK-NEXT: [[vc0:%.*]] = arith.constant 0 : index
13+
// CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
14+
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc0]]] [[[vc3]]] [[[vc3]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
15+
// CHECK-NEXT: return [[vsubview]] : memref<?xi64, strided<[?], offset: ?>>
16+
17+
18+
// -----
19+
func.func @test_insert_slice(%arg0: tensor<?xi64>, %arg1: tensor<?xi64>) {
20+
%i0 = arith.constant 0 : index
21+
%i1 = arith.constant 1 : index
22+
%i3 = arith.constant 3 : index
23+
ndarray.insert_slice %arg1 into %arg0[%i0] [%i3] [%i1] : tensor<?xi64> into tensor<?xi64>
24+
return
25+
}
26+
// CHECK-LABEL: func.func @test_insert_slice(
27+
// CHECK-SAME: [[varg0:%.*]]: memref<?xi64, strided<[?], offset: ?>>, [[varg1:%.*]]: memref<?xi64, strided<[?], offset: ?>>) {
28+
// CHECK-NEXT: [[vc0:%.*]] = arith.constant 0 : index
29+
// CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
30+
// CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
31+
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc0]]] [[[vc3]]] [[[vc1]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
32+
// CHECK-NEXT: memref.copy [[varg1]], [[vsubview]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
33+
// CHECK-NEXT: return
34+
35+
36+
// -----
37+
func.func @test_insert_slice_scalar(%arg0: tensor<?xi64>, %arg1: tensor<i64>) {
38+
%i0 = arith.constant 0 : index
39+
%i1 = arith.constant 1 : index
40+
%i3 = arith.constant 3 : index
41+
ndarray.insert_slice %arg1 into %arg0[%i0] [%i3] [%i1] : tensor<i64> into tensor<?xi64>
42+
return
43+
}
44+
// CHECK-LABEL: func.func @test_insert_slice_scalar(
45+
// CHECK-SAME: [[varg0:%.*]]: memref<?xi64, strided<[?], offset: ?>>, [[varg1:%.*]]: memref<i64, strided<[], offset: ?>>) {
46+
// CHECK-NEXT: [[vc0:%.*]] = arith.constant 0 : index
47+
// CHECK-NEXT: [[vc1:%.*]] = arith.constant 1 : index
48+
// CHECK-NEXT: [[vc3:%.*]] = arith.constant 3 : index
49+
// CHECK-NEXT: [[vsubview:%.*]] = memref.subview [[varg0]][[[vc0]]] [[[vc3]]] [[[vc1]]] : memref<?xi64, strided<[?], offset: ?>> to memref<?xi64, strided<[?], offset: ?>>
50+
// CHECK-NEXT: linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel"]} ins([[varg1]] : memref<i64, strided<[], offset: ?>>) outs([[vsubview]] : memref<?xi64, strided<[?], offset: ?>>) {
51+
// CHECK-NEXT: ^bb0([[vin:%.*]]: i64, [[vout:%.*]]: i64):
52+
// CHECK-NEXT: linalg.yield [[vin]] : i64
53+
// CHECK: return

0 commit comments

Comments
 (0)