@@ -19,37 +19,6 @@ def log(*args):
1919 sys .stderr .flush ()
2020
2121
22- elemwise_boiler = """
23- func.func @main() -> f32 attributes {llvm.emit_c_interface} {
24- %v0 = arith.constant 0.0 : f32
25- %v1 = arith.constant 1.0 : f32
26- %v2 = arith.constant 2.0 : f32
27-
28- %lhs = memref.alloc() : memref<f32>
29- %rhs = memref.alloc() : memref<4x8xf32>
30- %O0 = memref.alloc() : memref<4x8xf32>
31- %O1 = memref.alloc() : memref<4x8xf32>
32- linalg.fill ins(%v1 : f32) outs(%lhs : memref<f32>)
33- linalg.fill ins(%v2 : f32) outs(%rhs : memref<4x8xf32>)
34- linalg.fill ins(%v0 : f32) outs(%O0 : memref<4x8xf32>)
35- linalg.fill ins(%v0 : f32) outs(%O1 : memref<4x8xf32>)
36-
37- call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
38- (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
39- call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
40- (memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
41-
42- %c0 = arith.constant 0 : index
43- %res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
44- %res1 = memref.load %O1[%c0, %c0] : memref<4x8xf32>
45-
46- %0 = arith.addf %res0, %res1 : f32
47-
48- // TODO: FFI-based solution to allow testing and printing with python code.
49- return %0 : f32
50- }
51- """
52-
5322fill_boiler = """
5423func.func @main() -> i32 attributes {llvm.emit_c_interface} {
5524 %O0 = memref.alloc() : memref<i32>
@@ -177,96 +146,6 @@ def transform(module, boilerplate):
177146 return mod
178147
179148
180- def test_elemwise_builtin ():
181- with Context () as ctx , Location .unknown ():
182- module = Module .create ()
183- f32 = F32Type .get ()
184- i8 = IntegerType .get_signless (8 )
185- with InsertionPoint (module .body ):
186-
187- @func .FuncOp .from_py_func (
188- MemRefType .get ((), f32 ),
189- MemRefType .get ((4 , 8 ), f32 ),
190- MemRefType .get ((4 , 8 ), f32 ),
191- )
192- def elemwise_exp_add_on_buffers (lhs , rhs , out ):
193- linalg .elemwise_unary (lhs , outs = [out ])
194- linalg .elemwise_binary (out , rhs , outs = [out ])
195-
196- @func .FuncOp .from_py_func (
197- MemRefType .get ((), f32 ),
198- MemRefType .get ((4 , 8 ), f32 ),
199- MemRefType .get ((4 , 8 ), f32 ),
200- )
201- def elemwise_log_mul_on_buffers (lhs , rhs , out ):
202- linalg .elemwise_unary (lhs , outs = [out ], fun = UnaryFn .log )
203- linalg .elemwise_binary (out , rhs , outs = [out ], fun = BinaryFn .mul )
204-
205- execution_engine = ExecutionEngine (transform (module , elemwise_boiler ))
206-
207- # TODO: FFI-based solution to allow testing and printing with python code.
208- # Prepare arguments: one result f32.
209- # Arguments must be passed as pointers.
210- c_float_p = ctypes .c_float * 1
211- res = c_float_p (- 1.0 )
212- execution_engine .invoke ("main" , res )
213-
214- log ("RESULT: " , res [0 ])
215- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
216- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
217- # CHECK: RESULT: 4.71828
218-
219-
220- test_elemwise_builtin ()
221-
222-
223- def test_elemwise_generic ():
224- with Context () as ctx , Location .unknown ():
225- module = Module .create ()
226- f32 = F32Type .get ()
227- i8 = IntegerType .get_signless (8 )
228- with InsertionPoint (module .body ):
229-
230- @func .FuncOp .from_py_func (
231- MemRefType .get ((), f32 ),
232- MemRefType .get ((4 , 8 ), f32 ),
233- MemRefType .get ((4 , 8 ), f32 ),
234- )
235- def elemwise_exp_add_on_buffers (lhs , rhs , out ):
236- linalg .elemwise_unary (lhs , outs = [out ], emit_generic = True )
237- linalg .elemwise_binary (out , rhs , outs = [out ], emit_generic = True )
238-
239- @func .FuncOp .from_py_func (
240- MemRefType .get ((), f32 ),
241- MemRefType .get ((4 , 8 ), f32 ),
242- MemRefType .get ((4 , 8 ), f32 ),
243- )
244- def elemwise_log_mul_on_buffers (lhs , rhs , out ):
245- linalg .elemwise_unary (
246- lhs , outs = [out ], fun = UnaryFn .log , emit_generic = True
247- )
248- linalg .elemwise_binary (
249- out , rhs , outs = [out ], fun = BinaryFn .mul , emit_generic = True
250- )
251-
252- execution_engine = ExecutionEngine (transform (module , elemwise_boiler ))
253-
254- # TODO: FFI-based solution to allow testing and printing with python code.
255- # Prepare arguments: one result f32.
256- # Arguments must be passed as pointers.
257- c_float_p = ctypes .c_float * 1
258- res = c_float_p (- 1.0 )
259- execution_engine .invoke ("main" , res )
260-
261- log ("RESULT: " , res [0 ])
262- # elemwise_exp_add_on_buffers: exp(1.0) + 2.0 = 4.71828182846
263- # elemwise_log_mul_on_buffers: log(1.0) * 2.0 = 0.0
264- # CHECK: RESULT: 4.71828
265-
266-
267- test_elemwise_generic ()
268-
269-
270149def test_fill_builtin ():
271150 with Context () as ctx , Location .unknown ():
272151 module = Module .create ()
0 commit comments