6
6
7
7
import pytest
8
8
from attr import dataclass
9
- from hypothesis import assume , given
9
+ from hypothesis import HealthCheck , assume , given , settings
10
10
11
+ from . import dtype_helpers as dh
11
12
from . import hypothesis_helpers as hh
12
13
from . import shape_helpers as sh
13
14
from . import xps
14
15
from ._array_module import mod as xp
15
16
from .stubs import category_to_funcs
16
17
17
- repr_to_value = {
18
- "NaN" : float ("nan" ),
19
- "+infinity" : float ("infinity" ),
20
- "infinity" : float ("infinity" ),
21
- "-infinity" : float ("-infinity" ),
22
- "+0" : 0.0 ,
23
- "0" : 0.0 ,
24
- "-0" : - 0.0 ,
25
- "+1" : 1.0 ,
26
- "1" : 1.0 ,
27
- "-1" : - 1.0 ,
28
- "+π/2" : math .pi / 2 ,
29
- "π/2" : math .pi / 2 ,
30
- "-π/2" : - math .pi / 2 ,
31
- }
18
+
19
+ def is_pos_zero (n : float ) -> bool :
20
+ return n == 0 and math .copysign (1 , n ) == 1
21
+
22
+
23
+ def is_neg_zero (n : float ) -> bool :
24
+ return n == 0 and math .copysign (1 , n ) == - 1
32
25
33
26
34
27
def make_eq (v : float ) -> Callable [[float ], bool ]:
35
28
if math .isnan (v ):
36
29
return math .isnan
30
+ if v == 0 :
31
+ if is_pos_zero (v ):
32
+ return is_pos_zero
33
+ else :
34
+ return is_neg_zero
37
35
38
36
def eq (i : float ) -> bool :
39
37
return i == v
@@ -42,6 +40,8 @@ def eq(i: float) -> bool:
42
40
43
41
44
42
def make_rough_eq (v : float ) -> Callable [[float ], bool ]:
43
+ assert math .isfinite (v ) # sanity check
44
+
45
45
def rough_eq (i : float ) -> bool :
46
46
return math .isclose (i , v , abs_tol = 0.01 )
47
47
@@ -73,21 +73,52 @@ def or_(i: float):
73
73
return or_
74
74
75
75
76
- r_value = re .compile (r"``([^\s]+)``" )
77
- r_approx_value = re .compile (
78
- rf"an implementation-dependent approximation to { r_value .pattern } "
79
- )
76
+ repr_to_value = {
77
+ "NaN" : float ("nan" ),
78
+ "infinity" : float ("infinity" ),
79
+ "0" : 0.0 ,
80
+ "1" : 1.0 ,
81
+ }
82
+
83
+ r_value = re .compile (r"([+-]?)(.+)" )
84
+ r_pi = re .compile (r"(\d?)π(?:/(\d))?" )
80
85
81
86
82
87
@dataclass
83
88
class ValueParseError (ValueError ):
84
89
value : str
85
90
86
91
87
- def parse_value (value : str ) -> float :
88
- if m := r_value .match (value ):
89
- return repr_to_value [m .group (1 )]
90
- raise ValueParseError (value )
92
+ def parse_value (s_value : str ) -> float :
93
+ assert not s_value .startswith ("``" ) and not s_value .endswith ("``" ) # sanity check
94
+ m = r_value .match (s_value )
95
+ if m is None :
96
+ raise ValueParseError (s_value )
97
+ if pi_m := r_pi .match (m .group (2 )):
98
+ value = math .pi
99
+ if numerator := pi_m .group (1 ):
100
+ value *= int (numerator )
101
+ if denominator := pi_m .group (2 ):
102
+ value /= int (denominator )
103
+ else :
104
+ value = repr_to_value [m .group (2 )]
105
+ if sign := m .group (1 ):
106
+ if sign == "-" :
107
+ value *= - 1
108
+ return value
109
+
110
+
111
+ r_inline_code = re .compile (r"``([^\s]+)``" )
112
+ r_approx_value = re .compile (
113
+ rf"an implementation-dependent approximation to { r_inline_code .pattern } "
114
+ )
115
+
116
+
117
+ def parse_inline_code (inline_code : str ) -> float :
118
+ if m := r_inline_code .match (inline_code ):
119
+ return parse_value (m .group (1 ))
120
+ else :
121
+ raise ValueParseError (inline_code )
91
122
92
123
93
124
class Result (NamedTuple ):
@@ -96,22 +127,24 @@ class Result(NamedTuple):
96
127
strict_check : bool
97
128
98
129
99
- def parse_result (result : str ) -> Result :
100
- if m := r_value .match (result ):
101
- repr_ = m .group (1 )
130
+ def parse_result (s_result : str ) -> Result :
131
+ match = None
132
+ if m := r_inline_code .match (s_result ):
133
+ match = m
102
134
strict_check = True
103
- elif m := r_approx_value .match (result ):
104
- repr_ = m . group ( 1 )
135
+ elif m := r_approx_value .match (s_result ):
136
+ match = m
105
137
strict_check = False
106
138
else :
107
- raise ValueParseError (result )
108
- value = repr_to_value [repr_ ]
139
+ raise ValueParseError (s_result )
140
+ value = parse_value (match .group (1 ))
141
+ repr_ = match .group (1 )
109
142
return Result (value , repr_ , strict_check )
110
143
111
144
112
145
r_special_cases = re .compile (
113
- r"\*\*Special [Cc]ases\*\*\n\n \s*"
114
- r"For floating-point operands,\n\n "
146
+ r"\*\*Special [Cc]ases\*\*\n+ \s*"
147
+ r"For floating-point operands,\n+ "
115
148
r"((?:\s*-\s*.*\n)+)"
116
149
)
117
150
r_case = re .compile (r"\s+-\s*(.*)\.\n?" )
@@ -148,7 +181,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
148
181
if m := pattern .search (case ):
149
182
* s_values , s_result = m .groups ()
150
183
try :
151
- values = [parse_value (v ) for v in s_values ]
184
+ values = [parse_inline_code (v ) for v in s_values ]
152
185
except ValueParseError as e :
153
186
warn (f"value not machine-readable: '{ e .value } '" )
154
187
break
@@ -166,7 +199,56 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
166
199
return condition_to_result
167
200
168
201
202
+ binary_pattern_to_condition_factory : Dict [Pattern , Callable ] = {
203
+ re .compile (
204
+ "If ``x1_i`` is (.+) and ``x2_i`` is (.+), the result is (.+)"
205
+ ): lambda v1 , v2 : lambda i1 , i2 : make_eq (v1 )(i1 )
206
+ and make_eq (v2 )(i2 ),
207
+ }
208
+
209
+
210
+ def parse_binary_docstring (docstring : str ) -> Dict [Callable , Result ]:
211
+ match = r_special_cases .search (docstring )
212
+ if match is None :
213
+ return {}
214
+ cases = match .group (1 ).split ("\n " )[:- 1 ]
215
+ condition_to_result = {}
216
+ for line in cases :
217
+ if m := r_case .match (line ):
218
+ case = m .group (1 )
219
+ else :
220
+ warn (f"line not machine-readable: '{ line } '" )
221
+ continue
222
+ for pattern , make_cond in binary_pattern_to_condition_factory .items ():
223
+ if m := pattern .search (case ):
224
+ * s_values , s_result = m .groups ()
225
+ try :
226
+ values = [parse_inline_code (v ) for v in s_values ]
227
+ except ValueParseError as e :
228
+ warn (f"value not machine-readable: '{ e .value } '" )
229
+ break
230
+ cond = make_cond (* values )
231
+ if (
232
+ "atan2" in docstring
233
+ and is_pos_zero (values [0 ])
234
+ and is_neg_zero (values [1 ])
235
+ ):
236
+ breakpoint ()
237
+ try :
238
+ result = parse_result (s_result )
239
+ except ValueParseError as e :
240
+ warn (f"result not machine-readable: '{ e .value } '" )
241
+ break
242
+ condition_to_result [cond ] = result
243
+ break
244
+ else :
245
+ if not r_remaining_case .search (case ):
246
+ warn (f"case not machine-readable: '{ case } '" )
247
+ return condition_to_result
248
+
249
+
169
250
unary_params = []
251
+ binary_params = []
170
252
for stub in category_to_funcs ["elementwise" ]:
171
253
if stub .__doc__ is None :
172
254
warn (f"{ stub .__name__ } () stub has no docstring" )
@@ -193,7 +275,10 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
193
275
warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
194
276
continue
195
277
if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
196
- pass # TODO
278
+ if condition_to_result := parse_binary_docstring (stub .__doc__ ):
279
+ p = pytest .param (stub .__name__ , func , condition_to_result , id = stub .__name__ )
280
+ binary_params .append (p )
281
+ continue
197
282
else :
198
283
warn (
199
284
f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
@@ -209,7 +294,7 @@ def parse_unary_docstring(docstring: str) -> Dict[Callable, Result]:
209
294
210
295
@pytest .mark .parametrize ("func_name, func, condition_to_result" , unary_params )
211
296
@given (x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )))
212
- def test_unary_special_cases (func_name , func , condition_to_result , x ):
297
+ def test_unary (func_name , func , condition_to_result , x ):
213
298
res = func (x )
214
299
good_example = False
215
300
for idx in sh .ndindex (res .shape ):
@@ -238,3 +323,44 @@ def test_unary_special_cases(func_name, func, condition_to_result, x):
238
323
)
239
324
break
240
325
assume (good_example )
326
+
327
+
328
+ @pytest .mark .parametrize ("func_name, func, condition_to_result" , binary_params )
329
+ @given (
330
+ * hh .two_mutual_arrays (
331
+ dtypes = dh .float_dtypes ,
332
+ two_shapes = hh .mutually_broadcastable_shapes (2 , min_side = 1 ),
333
+ )
334
+ )
335
+ @settings (suppress_health_check = [HealthCheck .filter_too_much ]) # TODO: remove
336
+ def test_binary (func_name , func , condition_to_result , x1 , x2 ):
337
+ res = func (x1 , x2 )
338
+ good_example = False
339
+ for l_idx , r_idx , o_idx in sh .iter_indices (x1 .shape , x2 .shape , res .shape ):
340
+ l = float (x1 [l_idx ])
341
+ r = float (x2 [r_idx ])
342
+ for cond , result in condition_to_result .items ():
343
+ if cond (l , r ):
344
+ good_example = True
345
+ out = float (res [o_idx ])
346
+ f_left = f"{ sh .fmt_idx ('x1' , l_idx )} ={ l } "
347
+ f_right = f"{ sh .fmt_idx ('x2' , r_idx )} ={ r } "
348
+ f_out = f"{ sh .fmt_idx ('out' , o_idx )} ={ out } "
349
+ if result .strict_check :
350
+ msg = (
351
+ f"{ f_out } , but should be { result .repr_ } [{ func_name } ()]\n "
352
+ f"{ f_left } , { f_right } "
353
+ )
354
+ if math .isnan (result .value ):
355
+ assert math .isnan (out ), msg
356
+ else :
357
+ assert out == result .value , msg
358
+ else :
359
+ assert math .isfinite (result .value ) # sanity check
360
+ assert math .isclose (out , result .value , abs_tol = 0.1 ), (
361
+ f"{ f_out } , but should be roughly { result .repr_ } ={ result .value } "
362
+ f"[{ func_name } ()]\n "
363
+ f"{ f_left } , { f_right } "
364
+ )
365
+ break
366
+ assume (good_example )
0 commit comments