Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 52 additions & 24 deletions bqskit/utils/test/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Any
from typing import Sequence

from hypothesis.control import assume
from hypothesis.strategies import composite
from hypothesis.strategies import deferred
from hypothesis.strategies import dictionaries
Expand Down Expand Up @@ -168,7 +167,7 @@ def constant_unitary_gates(
def dagger_gates(
draw: Any,
radixes: Sequence[int] | int | None = None,
constant: bool = False,
constant: bool | None = False,
) -> DaggerGate:
"""Hypothesis strategy for generating `DaggerGate`'s."""
gate = draw(deferred(lambda: gates(radixes, constant)))
Expand All @@ -179,22 +178,23 @@ def dagger_gates(
def frozen_gates(
draw: Any,
radixes: Sequence[int] | int | None = None,
constant: bool = False,
constant: bool | None = False,
) -> FrozenParameterGate:
"""Hypothesis strategy for generating `FrozenParameterGate`'s."""
gate = draw(deferred(lambda: gates(radixes, False)))
if gate.num_params == 0:
return FrozenParameterGate(gate, {})
max_idx = gate.num_params
indices = integers(0, max_idx - 1)
values = floats(allow_nan=False, allow_infinity=False, width=16)
min_size = gate.num_params if constant else 0
values = floats(-2**16, 2**16, allow_nan=False, allow_infinity=False)
min_size = 0 if constant is None or constant is False else max_idx
max_size = max_idx if constant is None or constant is True else max_idx - 1
frozen_params = draw(
dictionaries(
indices,
values,
min_size=min_size,
max_size=max_idx,
max_size=max_size,
),
)
return FrozenParameterGate(gate, frozen_params)
Expand All @@ -204,7 +204,7 @@ def frozen_gates(
def tagged_gates(
draw: Any,
radixes: Sequence[int] | int | None = None,
constant: bool = False,
constant: bool | None = False,
) -> TaggedGate:
"""Hypothesis strategy for generating `TaggedGate`'s."""
gate = draw(deferred(lambda: gates(radixes, constant)))
Expand Down Expand Up @@ -237,24 +237,52 @@ def gates(

num_qudits = len(radixes)

gate = draw(
one_of(
just(IdentityGate(num_qudits, radixes)),
constant_unitary_gates(radixes),
just(PauliGate(num_qudits)),
constant_strategies = [
just(IdentityGate(num_qudits, radixes)),
constant_unitary_gates(radixes),
]

nonconstant_strategies = []
# the num_params checks for these could probably be done without
# instantiating a gate
if PauliGate(num_qudits).num_params <= 128:
nonconstant_strategies.append(just(PauliGate(num_qudits)))
if VariableUnitaryGate(num_qudits, radixes).num_params <= 128:
nonconstant_strategies.append(
just(VariableUnitaryGate(num_qudits, radixes)),
sampled_from(gate_instances),
dagger_gates(radixes, constant),
tagged_gates(radixes, constant),
frozen_gates(radixes, constant),
),
)
)

strategies = []
if constant is None or constant is True:
strategies.extend(constant_strategies)
if constant is None or constant is False:
strategies.extend(nonconstant_strategies)

filtered_instances = [
g for g in gate_instances
if g.is_constant() == constant
and sorted(g.radixes) == sorted(radixes)
]
if len(filtered_instances) > 0:
strategies.append(sampled_from(filtered_instances))
strategies += [
dagger_gates(radixes, constant),
tagged_gates(radixes, constant),
frozen_gates(radixes, constant),
]

strategy = one_of(strategies)

# these filters should be handled by the above code, but have been left
# in just in case
if constant is not None:
assume(gate.is_constant() == constant)
strategy = strategy.filter(lambda g: g.is_constant() == constant)
strategy.filter(
lambda g: sorted(g.radixes) == sorted(radixes)
and g.num_params <= 128,
)

assume(sorted(gate.radixes) == sorted(radixes))
assume(gate.num_params <= 128)
gate = draw(strategy)

return gate

Expand All @@ -269,7 +297,7 @@ def gates_and_params(
gate = draw(gates(radixes, constant))
params = draw(
lists(
floats(allow_nan=False, allow_infinity=False, width=16),
floats(-2**16, 2**16, allow_nan=False, allow_infinity=False),
min_size=gate.num_params,
max_size=gate.num_params,
),
Expand Down Expand Up @@ -297,7 +325,7 @@ def operations(
one_of([
lists(floats(), max_size=0),
lists(
floats(allow_nan=False, allow_infinity=False, width=16),
floats(-2**16, 2**16, allow_nan=False, allow_infinity=False),
min_size=gate.num_params,
max_size=gate.num_params,
),
Expand Down Expand Up @@ -349,7 +377,7 @@ def circuits(
gate_location = list(zip(*gate_idx_and_rdx))[0]
gate_radixes = list(zip(*gate_idx_and_rdx))[1]
gate = draw(gates(gate_radixes, constant))
params = floats(allow_nan=False, allow_infinity=False, width=16)
params = floats(-2**16, 2**16, allow_nan=False, allow_infinity=False)
num_params = gate.num_params
gate_params = draw(
lists(
Expand Down
Loading