Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/braket/aws/aws_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def __init__(
self._aws_session = self._get_session_and_initialize(aws_session or AwsSession())
self._ports = None
self._frames = None
if noise_model:
self._validate_device_noise_model_support(noise_model)
self._noise_model = noise_model
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation still makes sense here


def run(
Expand Down
11 changes: 11 additions & 0 deletions src/braket/devices/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,17 @@ def status(self) -> str:
"""
return self._status

def set_noise_model(self, noise_model: NoiseModel) -> None:
"""Set the noise model of the device.

Args:
noise_model (NoiseModel): The Braket noise model to apply to the circuit before
execution. Noise model can only be added to the devices that support noise
simulation.
"""
self._validate_device_noise_model_support(noise_model)
self._noise_model = noise_model

def _validate_device_noise_model_support(self, noise_model: NoiseModel) -> None:
supported_noises = {
SUPPORTED_NOISE_PRAGMA_TO_NOISE[pragma].__name__
Expand Down
2 changes: 0 additions & 2 deletions src/braket/devices/local_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,6 @@ def __init__(
status="AVAILABLE",
)
self._delegate = delegate
if noise_model:
self._validate_device_noise_model_support(noise_model)
self._noise_model = noise_model

def run(
Expand Down
24 changes: 18 additions & 6 deletions test/unit_tests/braket/devices/test_local_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,24 +670,36 @@ def noise_model():


@pytest.mark.parametrize("backend", ["dummy_oq3_dm"])
def test_valid_local_device_for_noise_model(backend, noise_model):
device = LocalSimulator(backend, noise_model=noise_model)
def test_set_noise_model(backend, noise_model):
device = LocalSimulator(backend)
device.set_noise_model(noise_model)
assert device._noise_model.instructions == [
NoiseModelInstruction(Noise.BitFlip(0.05), GateCriteria(Gate.H)),
NoiseModelInstruction(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)),
]


@pytest.mark.parametrize("backend", ["dummy_oq3"])
def test_invalid_local_device_for_noise_model(backend, noise_model):
def test_set_noise_model_invalid_device(backend, noise_model):
with pytest.raises(ValueError):
_ = LocalSimulator(backend, noise_model=noise_model)
device = LocalSimulator(backend)
device.set_noise_model(noise_model)


@pytest.mark.parametrize("backend", ["dummy_oq3_dm"])
def test_local_device_with_invalid_noise_model(backend, noise_model):
def test_set_noise_model_invalid_noise_model(backend, noise_model):
with pytest.raises(TypeError):
_ = LocalSimulator(backend, noise_model=Mock())
device = LocalSimulator(backend)
device.set_noise_model(Mock())


@pytest.mark.parametrize("backend", ["dummy_oq3_dm"])
def test_valid_local_device_for_noise_model(backend, noise_model):
device = LocalSimulator(backend, noise_model=noise_model)
assert device._noise_model.instructions == [
NoiseModelInstruction(Noise.BitFlip(0.05), GateCriteria(Gate.H)),
NoiseModelInstruction(Noise.TwoQubitDepolarizing(0.10), GateCriteria(Gate.CNot)),
]


@patch.object(DummyProgramDensityMatrixSimulator, "run")
Expand Down