Skip to content

Commit 7d29933

Browse files
committed
[mlir][arith] Add integration tests for addi emulation
This includes tests with the exact expected values and comparison-based tests. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D134321
1 parent 897a79f commit 7d29933

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
// Check that the wide integer addition emulation produces the same result as
2+
// wide addition. Emulate i16 ops with i8 ops.
3+
4+
// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
5+
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
6+
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
7+
// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
8+
// RUN: FileCheck %s --match-full-lines
9+
10+
// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=8" \
11+
// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
12+
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
13+
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
14+
// RUN: --shared-libs=%mlir_lib_dir/libmlir_c_runner_utils%shlibext | \
15+
// RUN: FileCheck %s --match-full-lines
16+
17+
// Ops in this function *only* will be emulated using i8 types.
18+
func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) {
19+
%res = arith.addi %lhs, %rhs : i16
20+
return %res : i16
21+
}
22+
23+
func.func @check_addi(%lhs : i16, %rhs : i16) -> () {
24+
%res = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16)
25+
vector.print %res : i16
26+
return
27+
}
28+
29+
func.func @entry() {
30+
%cst0 = arith.constant 0 : i16
31+
%cst1 = arith.constant 1 : i16
32+
%cst_1 = arith.constant -1 : i16
33+
%cst_3 = arith.constant -3 : i16
34+
35+
%cst13 = arith.constant 13 : i16
36+
%cst37 = arith.constant 37 : i16
37+
%cst42 = arith.constant 42 : i16
38+
39+
%cst256 = arith.constant 256 : i16
40+
%cst_i16_max = arith.constant 32767 : i16
41+
%cst_i16_min = arith.constant -32768 : i16
42+
43+
// CHECK: 0
44+
func.call @check_addi(%cst0, %cst0) : (i16, i16) -> ()
45+
// CHECK-NEXT: 1
46+
func.call @check_addi(%cst0, %cst1) : (i16, i16) -> ()
47+
// CHECK-NEXT: 2
48+
func.call @check_addi(%cst1, %cst1) : (i16, i16) -> ()
49+
// CHECK-NEXT: 0
50+
func.call @check_addi(%cst1, %cst_1) : (i16, i16) -> ()
51+
// CHECK-NEXT: -2
52+
func.call @check_addi(%cst_1, %cst_1) : (i16, i16) -> ()
53+
// CHECK-NEXT: -2
54+
func.call @check_addi(%cst1, %cst_3) : (i16, i16) -> ()
55+
56+
// CHECK-NEXT: 26
57+
func.call @check_addi(%cst13, %cst13) : (i16, i16) -> ()
58+
// CHECK-NEXT: 50
59+
func.call @check_addi(%cst13, %cst37) : (i16, i16) -> ()
60+
// CHECK-NEXT: 79
61+
func.call @check_addi(%cst37, %cst42) : (i16, i16) -> ()
62+
63+
// CHECK-NEXT: 255
64+
func.call @check_addi(%cst_1, %cst256) : (i16, i16) -> ()
65+
// CHECK-NEXT: 269
66+
func.call @check_addi(%cst256, %cst13) : (i16, i16) -> ()
67+
// CHECK-NEXT: 293
68+
func.call @check_addi(%cst256, %cst37) : (i16, i16) -> ()
69+
// CHECK-NEXT: 253
70+
func.call @check_addi(%cst256, %cst_3) : (i16, i16) -> ()
71+
72+
// CHECK-NEXT: -32756
73+
func.call @check_addi(%cst13, %cst_i16_max) : (i16, i16) -> ()
74+
// CHECK-NEXT: -32731
75+
func.call @check_addi(%cst_i16_min, %cst37) : (i16, i16) -> ()
76+
77+
// CHECK-NEXT: -2
78+
func.call @check_addi(%cst_i16_max, %cst_i16_max) : (i16, i16) -> ()
79+
// CHECK-NEXT: -32755
80+
func.call @check_addi(%cst_i16_min, %cst13) : (i16, i16) -> ()
81+
// CHECK-NEXT: 0
82+
func.call @check_addi(%cst_i16_min, %cst_i16_min) : (i16, i16) -> ()
83+
84+
return
85+
}

mlir/test/Integration/Dialect/Arithmetic/CPU/test-wide-int-emulation-compare-results-i16.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,53 @@ func.func @xhash(%i : i16) -> (i16) {
6262
return %res : i16
6363
}
6464

65+
//===----------------------------------------------------------------------===//
66+
// Test arith.addi
67+
//===----------------------------------------------------------------------===//
68+
69+
// Ops in this function will be emulated using i8 ops.
70+
func.func @emulate_addi(%lhs : i16, %rhs : i16) -> (i16) {
71+
%res = arith.addi %lhs, %rhs : i16
72+
return %res : i16
73+
}
74+
75+
// Performs both wide and emulated `arith.muli`, and checks that the results
76+
// match.
77+
func.func @check_addi(%lhs : i16, %rhs : i16) -> () {
78+
%wide = arith.addi %lhs, %rhs : i16
79+
%emulated = func.call @emulate_addi(%lhs, %rhs) : (i16, i16) -> (i16)
80+
func.call @check_results(%lhs, %rhs, %wide, %emulated) : (i16, i16, i16, i16) -> ()
81+
return
82+
}
83+
84+
// Checks that `arith.addi` is emulated properly by sampling the input space.
85+
// In total, this test function checks 500 * 500 = 250k input pairs.
86+
func.func @test_addi() -> () {
87+
%idx0 = arith.constant 0 : index
88+
%idx1 = arith.constant 1 : index
89+
%idx500 = arith.constant 500 : index
90+
91+
%cst0 = arith.constant 0 : i16
92+
%cst1 = arith.constant 1 : i16
93+
94+
scf.for %lhs_idx = %idx0 to %idx500 step %idx1 iter_args(%lhs = %cst0) -> (i16) {
95+
%arg_lhs = func.call @xhash(%lhs) : (i16) -> (i16)
96+
97+
scf.for %rhs_idx = %idx0 to %idx500 step %idx1 iter_args(%rhs = %cst0) -> (i16) {
98+
%arg_rhs = func.call @xhash(%rhs) : (i16) -> (i16)
99+
func.call @check_addi(%arg_lhs, %arg_rhs) : (i16, i16) -> ()
100+
101+
%rhs_next = arith.addi %rhs, %cst1 : i16
102+
scf.yield %rhs_next : i16
103+
}
104+
105+
%lhs_next = arith.addi %lhs, %cst1 : i16
106+
scf.yield %lhs_next : i16
107+
}
108+
109+
return
110+
}
111+
65112
//===----------------------------------------------------------------------===//
66113
// Test arith.muli
67114
//===----------------------------------------------------------------------===//
@@ -161,6 +208,7 @@ func.func @test_shrui() -> () {
161208
//===----------------------------------------------------------------------===//
162209

163210
func.func @entry() {
211+
func.call @test_addi() : () -> ()
164212
func.call @test_muli() : () -> ()
165213
func.call @test_shrui() : () -> ()
166214
return

0 commit comments

Comments
 (0)