Skip to content

Commit d5bcfcb

Browse files
committed
MAINT: generate random positive definite matrices, not just stacks of identities
1 parent 9db56ec commit d5bcfcb

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from . import xps
2121
from ._array_module import _UndefinedStub
2222
from ._array_module import bool as bool_dtype
23-
from ._array_module import broadcast_to, eye, float32, float64, full
23+
from ._array_module import broadcast_to, float32, float64, full
2424
from .stubs import category_to_funcs
2525
from .pytest_helpers import nargs
2626
from .typing import Array, DataType, Scalar, Shape
@@ -314,16 +314,20 @@ def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.
314314

315315
@composite
316316
def positive_definite_matrices(draw, dtypes=floating_dtypes):
317-
# For now just generate stacks of identity matrices
318-
# TODO: Generate arbitrary positive definite matrices, for instance, by
319-
# using something like
320-
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
321317
base_shape = draw(shapes())
322318
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
323319
shape = base_shape + (n, n)
324320
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
325321
dtype = draw(dtypes)
326-
return broadcast_to(eye(n, dtype=dtype), shape)
322+
323+
import numpy as np
324+
rng = np.random.default_rng(1234567)
325+
a = rng.uniform(size=(n, n))
326+
q, r = np.linalg.qr(a)
327+
arr = q.T @ (np.diag(3 + rng.uniform(size=n))) @ q
328+
arr = xp.asarray(arr, dtype=dtype)
329+
330+
return broadcast_to(arr, shape)
327331

328332
@composite
329333
def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):

0 commit comments

Comments
 (0)