diff --git a/src/nipanel/_panel_value_accessor.py b/src/nipanel/_panel_value_accessor.py index f05d2ab..c174843 100644 --- a/src/nipanel/_panel_value_accessor.py +++ b/src/nipanel/_panel_value_accessor.py @@ -1,6 +1,7 @@ from __future__ import annotations from abc import ABC +from typing import TypeVar, overload import grpc from ni_measurement_plugin_sdk_service.discovery import DiscoveryClient @@ -8,6 +9,8 @@ from nipanel._panel_client import PanelClient +_T = TypeVar("_T") + class PanelValueAccessor(ABC): """This class allows you to access values for a panel's controls.""" @@ -41,7 +44,13 @@ def panel_id(self) -> str: """Read-only accessor for the panel ID.""" return self._panel_id - def get_value(self, value_id: str, default_value: object = None) -> object: + @overload + def get_value(self, value_id: str) -> object: ... + + @overload + def get_value(self, value_id: str, default_value: _T) -> _T: ... + + def get_value(self, value_id: str, default_value: _T | None = None) -> _T | object: """Get the value for a control on the panel with an optional default value. Args: @@ -52,7 +61,11 @@ def get_value(self, value_id: str, default_value: object = None) -> object: The value, or the default value if not set """ try: - return self._panel_client.get_value(self._panel_id, value_id) + value = self._panel_client.get_value(self._panel_id, value_id) + if default_value is not None and not isinstance(value, type(default_value)): + raise TypeError("Value type does not match default value type.") + return value + except grpc.RpcError as e: if e.code() == grpc.StatusCode.NOT_FOUND and default_value is not None: return default_value diff --git a/tests/unit/test_streamlit_panel.py b/tests/unit/test_streamlit_panel.py index 78e77a8..0d65778 100644 --- a/tests/unit/test_streamlit_panel.py +++ b/tests/unit/test_streamlit_panel.py @@ -1,5 +1,6 @@ import grpc import pytest +from typing_extensions import assert_type import tests.types as test_types from nipanel import StreamlitPanel, StreamlitPanelValueAccessor @@ -146,7 +147,7 @@ def test___panel___set_value___gets_value( assert panel.get_value(value_id) == string_value -def test___set_value___get_value_ignores_default( +def test___panel___set_value___get_value_ignores_default( fake_panel_channel: grpc.Channel, ) -> None: panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) @@ -158,7 +159,7 @@ def test___set_value___get_value_ignores_default( assert panel.get_value(value_id, "default") == string_value -def test___get_value_returns_default_when_value_not_set( +def test___no_set_value___get_value_returns_default( fake_panel_channel: grpc.Channel, ) -> None: panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) @@ -168,6 +169,77 @@ def test___get_value_returns_default_when_value_not_set( assert panel.get_value("missing_float", 1.23) == 1.23 assert panel.get_value("missing_bool", True) is True assert panel.get_value("missing_list", [1, 2, 3]) == [1, 2, 3] + assert_type(panel.get_value("missing_string", "default"), str) + assert_type(panel.get_value("missing_int", 123), int) + assert_type(panel.get_value("missing_float", 1.23), float) + assert_type(panel.get_value("missing_bool", True), bool) + assert_type(panel.get_value("missing_list", [1, 2, 3]), list[int]) + + +def test___set_string_type___get_value_with_string_default___returns_string_type( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + string_value = "test_value" + panel.set_value(value_id, string_value) + + value = panel.get_value(value_id, "") + + assert_type(value, str) + assert value == string_value + + +def test___set_int_type___get_value_with_int_default___returns_int_type( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + int_value = 10 + panel.set_value(value_id, int_value) + + value = panel.get_value(value_id, 0) + + assert_type(value, int) + assert value == int_value + + +def test___set_bool_type___get_value_with_bool_default___returns_bool_type( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + bool_value = True + panel.set_value(value_id, bool_value) + + value = panel.get_value(value_id, False) + + assert_type(value, bool) + assert value is bool_value + + +def test___set_string_type___get_value_with_int_default___raises_exception( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + string_value = "test_value" + panel.set_value(value_id, string_value) + + with pytest.raises(TypeError): + panel.get_value(value_id, 0) + + +def test___set_int_type___get_value_with_bool_default___raises_exception( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + int_value = 10 + panel.set_value(value_id, int_value) + + with pytest.raises(TypeError): + panel.get_value(value_id, False) @pytest.mark.parametrize(