Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from . import xps
from ._array_module import _UndefinedStub
from ._array_module import bool as bool_dtype
from ._array_module import broadcast_to, eye, float32, float64, full
from ._array_module import broadcast_to, float32, float64, full
from .stubs import category_to_funcs
from .pytest_helpers import nargs
from .typing import Array, DataType, Scalar, Shape
Expand Down Expand Up @@ -314,16 +314,20 @@ def symmetric_matrices(draw, dtypes=real_floating_dtypes, finite=True, bound=10.

@composite
def positive_definite_matrices(draw, dtypes=floating_dtypes):
# For now just generate stacks of identity matrices
# TODO: Generate arbitrary positive definite matrices, for instance, by
# using something like
# https://github.com/scikit-learn/scikit-learn/blob/844b4be24/sklearn/datasets/_samples_generator.py#L1351.
base_shape = draw(shapes())
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
shape = base_shape + (n, n)
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
dtype = draw(dtypes)
return broadcast_to(eye(n, dtype=dtype), shape)

import numpy as np
rng = np.random.default_rng(1234567)
a = rng.uniform(size=(n, n))
q, r = np.linalg.qr(a)
arr = q.T @ (np.diag(3 + rng.uniform(size=n))) @ q
arr = xp.asarray(arr, dtype=dtype)

return broadcast_to(arr, shape)

@composite
def invertible_matrices(draw, dtypes=floating_dtypes, stack_shapes=shapes()):
Expand Down
Loading