@@ -174,10 +174,24 @@ def oneway_broadcastable_shapes(draw) -> OnewayBroadcastableShapes:
174
174
return OnewayBroadcastableShapes (input_shape , result_shape )
175
175
176
176
177
+ # Use these instead of xps.scalar_dtypes, etc. because it skips dtypes from
178
+ # ARRAY_API_TESTS_SKIP_DTYPES
179
+ all_dtypes = sampled_from (_sorted_dtypes )
180
+ int_dtypes = sampled_from (dh .int_dtypes )
181
+ uint_dtypes = sampled_from (dh .uint_dtypes )
182
+ real_dtypes = sampled_from (dh .real_dtypes )
183
+ # Warning: The hypothesis "floating_dtypes" is what we call
184
+ # "real_floating_dtypes"
185
+ floating_dtypes = sampled_from (dh .all_float_dtypes )
186
+ real_floating_dtypes = sampled_from (dh .real_float_dtypes )
187
+ numeric_dtypes = sampled_from (dh .numeric_dtypes )
188
+ # Note: this always returns complex dtypes, even if api_version < 2022.12
189
+ complex_dtypes = sampled_from (dh .complex_dtypes )
190
+
177
191
def all_floating_dtypes () -> SearchStrategy [DataType ]:
178
- strat = xps . floating_dtypes ()
192
+ strat = floating_dtypes
179
193
if api_version >= "2022.12" :
180
- strat |= xps . complex_dtypes ()
194
+ strat |= complex_dtypes
181
195
return strat
182
196
183
197
@@ -236,7 +250,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
236
250
237
251
@composite
238
252
def finite_matrices (draw , shape = matrix_shapes ()):
239
- return draw (arrays (dtype = xps . floating_dtypes () ,
253
+ return draw (arrays (dtype = floating_dtypes ,
240
254
shape = shape ,
241
255
elements = dict (allow_nan = False ,
242
256
allow_infinity = False )))
@@ -245,7 +259,7 @@ def finite_matrices(draw, shape=matrix_shapes()):
245
259
# Should we set a max_value here?
246
260
_rtol_float_kw = dict (allow_nan = False , allow_infinity = False , min_value = 0 )
247
261
rtols = one_of (floats (** _rtol_float_kw ),
248
- arrays (dtype = xps . floating_dtypes () ,
262
+ arrays (dtype = real_floating_dtypes ,
249
263
shape = rtol_shared_matrix_shapes .map (lambda shape : shape [:- 2 ]),
250
264
elements = _rtol_float_kw ))
251
265
@@ -280,9 +294,9 @@ def mutually_broadcastable_shapes(
280
294
281
295
two_mutually_broadcastable_shapes = mutually_broadcastable_shapes (2 )
282
296
283
- # Note: This should become hermitian_matrices when complex dtypes are added
297
+ # TODO: Add support for complex Hermitian matrices
284
298
@composite
285
- def symmetric_matrices (draw , dtypes = xps . floating_dtypes () , finite = True , bound = 10. ):
299
+ def symmetric_matrices (draw , dtypes = real_floating_dtypes , finite = True , bound = 10. ):
286
300
shape = draw (square_matrix_shapes )
287
301
dtype = draw (dtypes )
288
302
if not isinstance (finite , bool ):
@@ -297,7 +311,7 @@ def symmetric_matrices(draw, dtypes=xps.floating_dtypes(), finite=True, bound=10
297
311
return H
298
312
299
313
@composite
300
- def positive_definite_matrices (draw , dtypes = xps . floating_dtypes () ):
314
+ def positive_definite_matrices (draw , dtypes = floating_dtypes ):
301
315
# For now just generate stacks of identity matrices
302
316
# TODO: Generate arbitrary positive definite matrices, for instance, by
303
317
# using something like
@@ -310,7 +324,7 @@ def positive_definite_matrices(draw, dtypes=xps.floating_dtypes()):
310
324
return broadcast_to (eye (n , dtype = dtype ), shape )
311
325
312
326
@composite
313
- def invertible_matrices (draw , dtypes = xps . floating_dtypes () , stack_shapes = shapes ()):
327
+ def invertible_matrices (draw , dtypes = floating_dtypes , stack_shapes = shapes ()):
314
328
# For now, just generate stacks of diagonal matrices.
315
329
stack_shape = draw (stack_shapes )
316
330
n = draw (integers (0 , SQRT_MAX_ARRAY_SIZE // max (math .prod (stack_shape ), 1 )),)
@@ -344,7 +358,7 @@ def two_broadcastable_shapes(draw):
344
358
sqrt_sizes = integers (0 , SQRT_MAX_ARRAY_SIZE )
345
359
346
360
numeric_arrays = arrays (
347
- dtype = shared (xps . floating_dtypes () , key = 'dtypes' ),
361
+ dtype = shared (floating_dtypes , key = 'dtypes' ),
348
362
shape = shared (xps .array_shapes (), key = 'shapes' ),
349
363
)
350
364
@@ -388,7 +402,7 @@ def python_integer_indices(draw, sizes):
388
402
def integer_indices (draw , sizes ):
389
403
# Return either a Python integer or a 0-D array with some integer dtype
390
404
idx = draw (python_integer_indices (sizes ))
391
- dtype = draw (xps . integer_dtypes () | xps . unsigned_integer_dtypes () )
405
+ dtype = draw (int_dtypes | uint_dtypes )
392
406
m , M = dh .dtype_ranges [dtype ]
393
407
if m <= idx <= M :
394
408
return draw (one_of (just (idx ),
0 commit comments