@@ -79,6 +79,7 @@ def test_input(op) -> None:
79
79
pytest .param (pt .eq , id = "eq" ),
80
80
pytest .param (pt .neq , id = "neq" ),
81
81
pytest .param (pt .true_div , id = "true_div" ),
82
+ pytest .param (pt .int_div , id = "int_div" ),
82
83
],
83
84
)
84
85
def test_elemwise_two_inputs (op ) -> None :
@@ -90,6 +91,119 @@ def test_elemwise_two_inputs(op) -> None:
90
91
compare_mlx_and_py ([x , y ], out , [x_test , y_test ])
91
92
92
93
94
+ def test_int_div_specific () -> None :
95
+ """Test integer division with specific test cases"""
96
+ x = pt .vector ("x" )
97
+ y = pt .vector ("y" )
98
+ out = pt .int_div (x , y )
99
+
100
+ # Test with integers that demonstrate floor division behavior
101
+ x_test = mx .array ([7.0 , 8.0 , 9.0 , - 7.0 , - 8.0 ])
102
+ y_test = mx .array ([3.0 , 3.0 , 3.0 , 3.0 , 3.0 ])
103
+
104
+ compare_mlx_and_py ([x , y ], out , [x_test , y_test ])
105
+
106
+
107
+ def test_isnan () -> None :
108
+ """Test IsNan operation with various inputs including NaN values"""
109
+ x = pt .vector ("x" )
110
+ out = pt .isnan (x )
111
+
112
+ # Test with mix of normal values, NaN, and infinity
113
+ x_test = mx .array ([1.0 , np .nan , 3.0 , np .inf , - np .nan , 0.0 , - np .inf ])
114
+
115
+ compare_mlx_and_py ([x ], out , [x_test ])
116
+
117
+
118
+ def test_isnan_edge_cases () -> None :
119
+ """Test IsNan with edge cases"""
120
+ x = pt .scalar ("x" )
121
+ out = pt .isnan (x )
122
+
123
+ # Test individual cases
124
+ test_cases = [0.0 , np .nan , np .inf , - np .inf , 1e-10 , 1e10 ]
125
+
126
+ for test_val in test_cases :
127
+ x_test = test_val
128
+ compare_mlx_and_py ([x ], out , [x_test ])
129
+
130
+
131
+ def test_erfc () -> None :
132
+ """Test complementary error function"""
133
+ x = pt .vector ("x" )
134
+ out = pt .erfc (x )
135
+
136
+ # Test with various values including negative, positive, and zero
137
+ x_test = mx .array ([0.0 , 0.5 , 1.0 , - 0.5 , - 1.0 , 2.0 , - 2.0 , 0.1 ])
138
+
139
+ compare_mlx_and_py ([x ], out , [x_test ])
140
+
141
+
142
+ def test_erfc_extreme_values () -> None :
143
+ """Test erfc with extreme values"""
144
+ x = pt .vector ("x" )
145
+ out = pt .erfc (x )
146
+
147
+ # Test with larger values where erfc approaches 0 or 2
148
+ x_test = mx .array ([- 3.0 , - 2.5 , 2.5 , 3.0 ])
149
+
150
+ # Use relaxed tolerance for extreme values due to numerical precision differences
151
+ from functools import partial
152
+
153
+ relaxed_assert = partial (np .testing .assert_allclose , rtol = 1e-3 , atol = 1e-6 )
154
+
155
+ compare_mlx_and_py ([x ], out , [x_test ], assert_fn = relaxed_assert )
156
+
157
+
158
+ def test_erfcx () -> None :
159
+ """Test scaled complementary error function"""
160
+ x = pt .vector ("x" )
161
+ out = pt .erfcx (x )
162
+
163
+ # Test with positive values where erfcx is most numerically stable
164
+ x_test = mx .array ([0.0 , 0.5 , 1.0 , 1.5 , 2.0 , 2.5 ])
165
+
166
+ compare_mlx_and_py ([x ], out , [x_test ])
167
+
168
+
169
+ def test_erfcx_small_values () -> None :
170
+ """Test erfcx with small values"""
171
+ x = pt .vector ("x" )
172
+ out = pt .erfcx (x )
173
+
174
+ # Test with small values
175
+ x_test = mx .array ([0.001 , 0.01 , 0.1 , 0.2 ])
176
+
177
+ compare_mlx_and_py ([x ], out , [x_test ])
178
+
179
+
180
+ def test_softplus () -> None :
181
+ """Test softplus (log(1 + exp(x))) function"""
182
+ x = pt .vector ("x" )
183
+ out = pt .softplus (x )
184
+
185
+ # Test with normal range values
186
+ x_test = mx .array ([0.0 , 1.0 , 2.0 , - 1.0 , - 2.0 , 10.0 ])
187
+
188
+ compare_mlx_and_py ([x ], out , [x_test ])
189
+
190
+
191
+ def test_softplus_extreme_values () -> None :
192
+ """Test softplus with extreme values to verify numerical stability"""
193
+ x = pt .vector ("x" )
194
+ out = pt .softplus (x )
195
+
196
+ # Test with extreme values where different branches of the implementation are used
197
+ x_test = mx .array ([- 40.0 , - 50.0 , 20.0 , 30.0 , 35.0 , 50.0 ])
198
+
199
+ # Use relaxed tolerance for extreme values due to numerical precision differences
200
+ from functools import partial
201
+
202
+ relaxed_assert = partial (np .testing .assert_allclose , rtol = 1e-4 , atol = 1e-8 )
203
+
204
+ compare_mlx_and_py ([x ], out , [x_test ], assert_fn = relaxed_assert )
205
+
206
+
93
207
@pytest .mark .xfail (reason = "Argmax not implemented yet" )
94
208
def test_mlx_max_and_argmax ():
95
209
# Test that a single output of a multi-output `Op` can be used as input to
0 commit comments