diff --git a/src/dodal/devices/oav/pin_image_recognition/utils.py b/src/dodal/devices/oav/pin_image_recognition/utils.py index d1f5fdc440d..72a16fcb4d0 100644 --- a/src/dodal/devices/oav/pin_image_recognition/utils.py +++ b/src/dodal/devices/oav/pin_image_recognition/utils.py @@ -14,6 +14,12 @@ class ScanDirections(Enum): REVERSE = -1 +""" +See https://opencv24-python-tutorials.readthedocs.io/en/latest/py_tutorials/py_imgproc/py_morphological_ops/py_morphological_ops.html +for description of functions below. +""" + + def identity(*args, **kwargs) -> Callable[[np.ndarray], np.ndarray]: return lambda arr: arr diff --git a/src/dodal/devices/scintillator.py b/src/dodal/devices/scintillator.py index c6c98491dd9..bfa4506b1e7 100644 --- a/src/dodal/devices/scintillator.py +++ b/src/dodal/devices/scintillator.py @@ -8,9 +8,10 @@ class InOut(StrictEnum): - """Currently Hyperion only needs to move the scintillator out for data collection.""" + """Moves scintillator in and out of the beam.""" - OUT = "Out" + OUT = "Out" # Out of beam + IN = "In" # In to beam UNKNOWN = "Unknown" @@ -45,27 +46,46 @@ def __init__( self._scintillator_out_yz_mm = [ float(beamline_parameters[f"scin_{axis}_SCIN_OUT"]) for axis in ("y", "z") ] + self._scintillator_in_yz_mm = [ + float(beamline_parameters[f"scin_{axis}_SCIN_IN"]) for axis in ("y", "z") + ] self._yz_tolerance_mm = [ float(beamline_parameters[f"scin_{axis}_tolerance"]) for axis in ("y", "z") ] super().__init__(name) - def _get_selected_position(self, y: float, z: float) -> InOut: - current_pos = [y, z] - if all( + def _check_position(self, current_pos: list[float], pos_to_check: list[float]): + return all( isclose(axis_pos, axis_in_beam, abs_tol=axis_tolerance) for axis_pos, axis_in_beam, axis_tolerance in zip( current_pos, - self._scintillator_out_yz_mm, + pos_to_check, self._yz_tolerance_mm, strict=False, ) - ): + ) + + def _get_selected_position(self, y: float, z: float) -> InOut: + current_pos = [y, z] + if self._check_position(current_pos, self._scintillator_out_yz_mm): return InOut.OUT + + elif self._check_position(current_pos, self._scintillator_in_yz_mm): + return InOut.IN + else: return InOut.UNKNOWN + async def _check_aperture_parked(self): + if ( + await self._aperture_scatterguard().selected_aperture.get_value() + != ApertureValue.PARKED + ): + raise ValueError( + f"Cannot move scintillator if aperture/scatterguard is not parked. Position is currently {await self._aperture_scatterguard().selected_aperture.get_value()}" + ) + async def _set_selected_position(self, position: InOut) -> None: match position: case InOut.OUT: @@ -73,14 +93,16 @@ async def _set_selected_position(self, position: InOut) -> None: current_z = await self.z_mm.user_readback.get_value() if self._get_selected_position(current_y, current_z) == InOut.OUT: return - if ( - self._aperture_scatterguard().selected_aperture.get_value() - != ApertureValue.PARKED - ): - raise ValueError( - "Cannot move scintillator out if aperture/scatterguard is not parked" - ) + await self._check_aperture_parked() await self.y_mm.set(self._scintillator_out_yz_mm[0]) await self.z_mm.set(self._scintillator_out_yz_mm[1]) + case InOut.IN: + current_y = await self.y_mm.user_readback.get_value() + current_z = await self.z_mm.user_readback.get_value() + if self._get_selected_position(current_y, current_z) == InOut.IN: + return + await self._check_aperture_parked() + await self.z_mm.set(self._scintillator_in_yz_mm[1]) + await self.y_mm.set(self._scintillator_in_yz_mm[0]) case _: raise ValueError(f"Cannot set scintillator to position {position}") diff --git a/tests/devices/test_scintillator.py b/tests/devices/test_scintillator.py index 976dbbf997b..8264b03d182 100644 --- a/tests/devices/test_scintillator.py +++ b/tests/devices/test_scintillator.py @@ -1,10 +1,10 @@ from collections.abc import AsyncGenerator from contextlib import ExitStack -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from ophyd_async.core import init_devices -from ophyd_async.testing import get_mock_put +from ophyd_async.testing import assert_value, get_mock_put from dodal.common.beamlines.beamline_parameters import GDABeamlineParameters from dodal.devices.aperturescatterguard import ApertureScatterguard, ApertureValue @@ -32,6 +32,8 @@ async def scintillator_and_ap_sg( ) -> AsyncGenerator[tuple[Scintillator, MagicMock], None]: async with init_devices(mock=True): mock_ap_sg = MagicMock() + mock_ap_sg.return_value.selected_aperture.set = AsyncMock() + mock_ap_sg.return_value.selected_aperture.get_value = AsyncMock() scintillator = Scintillator( prefix="", name="test_scin", @@ -49,6 +51,7 @@ async def scintillator_and_ap_sg( @pytest.mark.parametrize( "y, z, expected_position", [ + (100.855, 101.5115, InOut.IN), (-0.02, 0.1, InOut.OUT), (0.1, 0.1, InOut.UNKNOWN), (10.2, 15.6, InOut.UNKNOWN), @@ -84,34 +87,56 @@ async def test_given_aperture_scatterguard_parked_when_set_to_out_position_then_ await scintillator.selected_pos.set(InOut.OUT) - assert await scintillator.y_mm.user_setpoint.get_value() == -0.02 - assert await scintillator.z_mm.user_setpoint.get_value() == 0.1 + await assert_value(scintillator.y_mm.user_setpoint, -0.02) + await assert_value(scintillator.z_mm.user_setpoint, 0.1) -async def test_given_aperture_scatterguard_not_parked_when_set_to_out_position_then_exception_raised( +async def test_given_aperture_scatterguard_parked_when_set_to_in_position_then_returns_expected( scintillator_and_ap_sg: tuple[Scintillator, ApertureScatterguard], +): + scintillator, ap_sg = scintillator_and_ap_sg + ap_sg.return_value.selected_aperture.get_value.return_value = ApertureValue.PARKED # type: ignore + + await scintillator.selected_pos.set(InOut.IN) + + await assert_value(scintillator.y_mm.user_setpoint, 100.855) + await assert_value(scintillator.z_mm.user_setpoint, 101.5115) + + +@pytest.mark.parametrize("scint_pos", [InOut.OUT, InOut.IN]) +async def test_given_aperture_scatterguard_not_parked_when_set_to_in_or_out_position_then_exception_raised( + scintillator_and_ap_sg: tuple[Scintillator, ApertureScatterguard], scint_pos ): for position in ApertureValue: if position != ApertureValue.PARKED: scintillator, ap_sg = scintillator_and_ap_sg ap_sg.return_value.selected_aperture.get_value.return_value = position # type: ignore - with pytest.raises(ValueError): - await scintillator.selected_pos.set(InOut.OUT) + await scintillator.selected_pos.set(scint_pos) -async def test_given_scintillator_already_out_when_moved_out_then_does_nothing( +@pytest.mark.parametrize( + "y, z, expected_position", + [ + (100.855, 101.5115, InOut.IN), + (-0.02, 0.1, InOut.OUT), + ], +) +async def test_given_scintillator_already_out_when_moved_in_or_out_then_does_nothing( scintillator_and_ap_sg: tuple[Scintillator, ApertureScatterguard], + expected_position, + y, + z, ): scintillator, ap_sg = scintillator_and_ap_sg - await scintillator.y_mm.set(0) - await scintillator.z_mm.set(0) + await scintillator.y_mm.set(y) + await scintillator.z_mm.set(z) get_mock_put(scintillator.y_mm.user_setpoint).reset_mock() get_mock_put(scintillator.z_mm.user_setpoint).reset_mock() ap_sg.return_value.selected_aperture.get_value.return_value = ApertureValue.LARGE # type: ignore - await scintillator.selected_pos.set(InOut.OUT) + await scintillator.selected_pos.set(expected_position) get_mock_put(scintillator.y_mm.user_setpoint).assert_not_called() get_mock_put(scintillator.z_mm.user_setpoint).assert_not_called()