diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 47f8cc17..2f7b6163 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + python-version: ["3.12"] steps: - uses: actions/checkout@v3 diff --git a/mypy.ini b/mypy.ini index 765bff0e..05bf0a23 100644 --- a/mypy.ini +++ b/mypy.ini @@ -90,4 +90,10 @@ follow_imports = silent [mypy-fixate.drivers,fixate.drivers.dso.helper,fixate.drivers.funcgen.helper,fixate.drivers.funcgen.rigol_dg1022,fixate.drivers.pps,fixate.drivers.pps.helper,fixate.drivers.ftdi] follow_imports = silent [mypy-fixate.examples.function_generator,fixate.examples.programmable_power_supply,fixate.examples.test_script] -follow_imports = silent \ No newline at end of file +follow_imports = silent + +# this example demos the type checker warning you when using the wrong arguments +# we explicitly suppress the errors in the script, but make sure that mypy +# actually found an error +[mypy-fixate.examples.jig_driver] +enable_error_code = unused-ignore \ No newline at end of file diff --git a/src/fixate/_switching.py b/src/fixate/_switching.py index afade427..48edec3a 100644 --- a/src/fixate/_switching.py +++ b/src/fixate/_switching.py @@ -43,17 +43,24 @@ Dict, FrozenSet, Iterable, + Literal, + Annotated, + get_origin, + get_args, ) from dataclasses import dataclass from functools import reduce from operator import or_ + Signal = str +EmptySignal = Literal[""] Pin = str PinList = Sequence[Pin] PinSet = FrozenSet[Pin] +MapList = Sequence[Sequence[Union[Signal, Pin]]] SignalMap = Dict[Signal, PinSet] -TreeDef = Sequence[Union[Signal, "TreeDef"]] +TreeDef = Sequence[Union[Optional[Signal], "TreeDef"]] @dataclass(frozen=True) @@ -87,15 +94,23 @@ def __or__(self, other: PinUpdate) -> PinUpdate: PinUpdateCallback = Callable[[PinUpdate, bool], None] +S = TypeVar("S", bound=str) + +from types import get_original_bases, resolve_bases -class VirtualMux: + +class VirtualMux(Generic[S]): + map_tree: Optional[TreeDef] = None + map_list: Optional[Sequence[Sequence[str]]] = None pin_list: PinList = () clearing_time: float = 0.0 ########################################################################### # These methods are the public API for the class - def __init__(self, update_pins: Optional[PinUpdateCallback] = None): + # digest all the typing information if there is any to set pin_list and map_list + self._digest_type_hints() + self._last_update_time = time.monotonic() self._update_pins: PinUpdateCallback @@ -129,7 +144,9 @@ def __init__(self, update_pins: Optional[PinUpdateCallback] = None): if hasattr(self, "default_signal"): raise ValueError("'default_signal' should not be set on a VirtualMux") - def __call__(self, signal: Signal, trigger_update: bool = True) -> None: + def __call__( + self, signal: Union[S, EmptySignal], trigger_update: bool = True + ) -> None: """ Convenience to avoid having to type jig.mux..multiplex. @@ -138,7 +155,9 @@ def __call__(self, signal: Signal, trigger_update: bool = True) -> None: """ self.multiplex(signal, trigger_update) - def multiplex(self, signal: Signal, trigger_update: bool = True) -> None: + def multiplex( + self, signal: Union[S, EmptySignal], trigger_update: bool = True + ) -> None: """ Update the multiplexer state to signal. @@ -230,13 +249,13 @@ def _map_signals(self) -> SignalMap: Avoid subclassing. Consider creating helper functions to build map_tree or map_list. """ - if hasattr(self, "map_tree"): + if self.map_tree is not None: return self._map_tree(self.map_tree, self.pin_list, fixed_pins=frozenset()) - elif hasattr(self, "map_list"): + elif self.map_list is not None: return {sig: frozenset(pins) for sig, *pins in self.map_list} else: raise ValueError( - "VirtualMux subclass must define either map_tree or map_list" + "VirtualMux subclass must define either map_tree or map_list or provide a type to VirtualMux" ) def _map_tree(self, tree: TreeDef, pins: PinList, fixed_pins: PinSet) -> SignalMap: @@ -441,6 +460,49 @@ def _default_update_pins( """ print(pin_updates, trigger_update) + def _digest_type_hints(self) -> None: + # digest all the typing information if there is any + bases = get_original_bases(self.__class__) + resolved_bases = resolve_bases(bases) + first_resolved_base = resolved_bases[0] + assert issubclass( + first_resolved_base, VirtualMux + ), f"{first_resolved_base} should be VirtualMux subclass" + if bases != resolved_bases: + args = get_args(bases[0]) + plist, mlist = self._unpack_muxdef(args[0]) + self.pin_list = plist + self.map_list = mlist + + @staticmethod + def _unpack_muxdef(muxdef: type) -> tuple[PinList, MapList]: + # muxdef is the signal definition + if get_origin(muxdef) == Union: + signals = get_args(muxdef) + elif get_origin(muxdef) == Annotated: + # Union FORCES you to have two or more types, so this handles the case of only one pin + signals = (muxdef,) + else: + raise TypeError("Signal definition must be Union or Annotated") + + map_list: list[tuple[str]] = [] + pin_list: list[str] = [] + for s in signals: + assert get_origin(s) == Annotated, "Signal definition must be Annotated" + # get_args gives Literal + sigdef, *pins = get_args(s) + assert ( + get_origin(sigdef) == Literal + ), "Signal definition must be string literal" + # get_args gives members of Literal + (signame,) = get_args(sigdef) + assert isinstance(signame, Signal), "Signal name must be signal type" + assert all(isinstance(p, Pin) for p in pins), "Pins must be pin type" + pin_list.extend(pins) + map_list.append((signame, *pins)) + + return pin_list, map_list + class VirtualSwitch(VirtualMux): """ @@ -482,7 +544,7 @@ def __init__( super().__init__(update_pins) -class RelayMatrixMux(VirtualMux): +class RelayMatrixMux(VirtualMux[S]): clearing_time = 0.01 def _calculate_pins( diff --git a/src/fixate/examples/jig_driver.py b/src/fixate/examples/jig_driver.py index 6d56cdf6..774082f0 100644 --- a/src/fixate/examples/jig_driver.py +++ b/src/fixate/examples/jig_driver.py @@ -10,6 +10,7 @@ MuxGroup, PinValueAddressHandler, VirtualSwitch, + RelayMatrixMux, ) @@ -58,3 +59,87 @@ class JigMuxGroup(MuxGroup): jig.mux.mux_two("sig5") jig.mux.mux_three("On") jig.mux.mux_three(False) + + +# VirtualMuxes can be created with type annotations to provide the signal map +from typing import Literal, Annotated, Union + +# a signal is a typing Annotation +# the first Literal is the signal name, the rest are the pin names +# the signal name MUST be a Literal +# multiple signals can be combined with a Union +# assigning annotations to variables is possible +# fmt: off +MuxOneSigDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a2"], + Annotated[Literal["sig_a2"], "a1"], +] + +MuxTwoSigDef = Union[ + Annotated[Literal["sig_b1"], "b0", "b2"], + Annotated[Literal["sig_b2"], "b1"], +] + +# if defining only a single signal, the Union is omitted in the definition +SingleSingleDef = Annotated[Literal["sig_c1"], "c0", "c1"] +# fmt: on + +# VirtualMuxes can now be created with type annotations to provide the signal map +# this only works when subclassing +class MyMux(VirtualMux[MuxOneSigDef]): + """A helpful description for my mux that is used in this jig driver""" + + +muxa = MyMux() + +muxa("sig_a1") +muxa("sig_a2") + +# using the wrong signal name will be caught at runtime and by the type checker +try: + muxa("unknown signal name") # type: ignore[arg-type] +except ValueError as e: + print(f"An Exception would have occurred: {e}") +else: + raise ValueError("Should have raised an exception") + + +class MultiPinSwitch(VirtualMux[SingleSingleDef]): + """This acts like a switch, but has to coordinate two pins""" + + +ls = MultiPinSwitch() +ls("sig_c1") +ls("") + +# further generic types can be created by subclassing from VirtualMux using a TypeVar +# compared to the above way of subclassing, this way lets you reuse the class + +from typing import TypeVar + +S = TypeVar("S", bound=str) + + +class MyGenericMux(VirtualMux[S]): + ... + + def extra_method(self) -> None: + print("Extra method") + + +class MyConcreteMux(MyGenericMux[MuxTwoSigDef]): + pass + + +generic_mux = MyConcreteMux() +generic_mux("sig_b1") +generic_mux("sig_b2") + +# RelayMatrixMux is an example of a reusable generic class +class MyRelayMatrixMux(RelayMatrixMux[MuxOneSigDef]): + pass + + +rmm = MyRelayMatrixMux() +rmm("sig_a1") +rmm("sig_a2") diff --git a/test/test_switching.py b/test/test_switching.py index 527f1d2c..0c1e2c77 100644 --- a/test/test_switching.py +++ b/test/test_switching.py @@ -15,6 +15,8 @@ JigDriver, ) +from typing import Literal, Union, TypeVar, get_args, get_origin, Annotated + import pytest ################################################################ @@ -592,3 +594,157 @@ def test_pin_update_or(): 2.0, ) assert expected == a | b + + +# fmt: off +MuxASigDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1"], + Annotated[Literal["sig_a2"], "a1"] +] +# fmt: on + + +def test_typed_mux_using_subclass(): + class SubMux(VirtualMux[MuxASigDef]): + pass + + sm = SubMux(update_pins=print) + assert sm._signal_map == MuxA()._signal_map + assert sm._pin_set == MuxA()._pin_set + + +def test_typed_relaymux_using_subclass(): + class SubRelayMux(RelayMatrixMux[MuxASigDef]): + pass + + srm = SubRelayMux() + assert srm._signal_map == MuxA()._signal_map + assert srm._pin_set == MuxA()._pin_set + + +def test_typed_mux_generic_subclass(): + T = TypeVar("T", bound=str) + + class GenericSubMux(VirtualMux[T]): + pass + + class ConcreteMux(GenericSubMux[MuxASigDef]): + pass + + gsm = ConcreteMux() + assert gsm._signal_map == MuxA()._signal_map + assert gsm._pin_set == MuxA()._pin_set + + +def test_typed_mux_one_signal(): + + muxbdef = Annotated[Literal["sig1"], "a0"] + + class MuxB(VirtualMux[muxbdef]): + ... + + clear = PinSetState(off=frozenset({"a0"})) + a1 = PinSetState(on=frozenset({"a0"})) + + updates = [] + muxb = MuxB(update_pins=lambda x, y: updates.append((x, y))) + muxb("sig1") + assert updates.pop() == (PinUpdate(PinSetState(), a1), True) + + muxb("") + assert updates.pop() == (PinUpdate(PinSetState(), clear), True) + + +def test_annotated_preserve_pin_defs(): + annotated = Annotated[Literal["sig_a1"], "a0", "a1"] + sigdef, *pins = get_args(annotated) + + +def test_annotated_raises_on_missing_pin_def(): + with pytest.raises(TypeError): + annotated = Annotated[Literal["sig_a1"]] + + +def test_annotation_bad_pindefs(): + BadMuxDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1"], + Annotated[Literal["sig_a1"], "a0", 1], + ] + + class BadMux(VirtualMux[BadMuxDef]): + pass + + with pytest.raises(AssertionError): + mux = BadMux() + + +def test_annotation_bad_brackets(): + """ + We put the brackets in the wrong spot and accidentally defined + one of the signals as one of the pins of the previous signal + """ + BadMuxDef = Union[ + Annotated[Literal["sig_a1"], "a0", "a1", Annotated[Literal["sig_a2"], "a1"]], + Annotated[Literal["sig_a1"], "a0", "a1"], + ] + + class BadMux(VirtualMux[BadMuxDef]): + pass + + with pytest.raises(AssertionError): + mux = BadMux() + + +def test_annotated_get_origin(): + # Annotated behaviour is different between python versions + # fails 3.8, passes >=3.9 + assert get_origin(Annotated[Literal["sig_a1"], "a0", "a1"]) == Annotated + + +def test_annotated_get_args(): + assert get_args(Annotated[Literal["sig_a1"], "a0", "a1"]) == ( + Literal["sig_a1"], + "a0", + "a1", + ) + + +@pytest.mark.skip( + reason="Revisit this idea once we have a way to stop Generic breaking getattr" +) +def test_typed_mux_class_getitem(): + clear = PinSetState(off=frozenset({"a0", "a1"})) + a1 = PinSetState(on=frozenset({"a0", "a1"})) + a2 = PinSetState(on=frozenset({"a1"}), off=frozenset({"a0"})) + + updates_class_mux = [] + updates_mux_a = [] + + class_mux = VirtualMux[MuxASigDef](lambda x, y: updates_class_mux.append((x, y))) + mux_a = MuxA(lambda x, y: updates_mux_a.append((x, y))) + assert mux_a._signal_map == class_mux._signal_map + assert mux_a._pin_set == class_mux._pin_set + + class_mux("sig_a1") + mux_a("sig_a1") + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), a1), True) + ) + + class_mux.multiplex("sig_a2", trigger_update=False) + mux_a.multiplex("sig_a2", trigger_update=False) + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), a2), False) + ) + + class_mux("") + mux_a("") + assert ( + updates_class_mux.pop() + == updates_mux_a.pop() + == (PinUpdate(PinSetState(), clear), True) + )