diff --git a/examples/all_types/all_types_panel.py b/examples/all_types/all_types_panel.py index be5b0a0..b2e046b 100644 --- a/examples/all_types/all_types_panel.py +++ b/examples/all_types/all_types_panel.py @@ -1,9 +1,12 @@ """A Streamlit visualization panel for the all_types.py example script.""" +from enum import Enum, Flag + import streamlit as st from define_types import all_types_with_values import nipanel +from nipanel.controls import enum_selectbox, flag_checkboxes st.set_page_config(page_title="All Types Example", page_icon="📊", layout="wide") @@ -11,10 +14,27 @@ panel = nipanel.get_panel_accessor() for name in all_types_with_values.keys(): - col1, col2 = st.columns([0.4, 0.6]) + st.markdown("---") + + default_value = all_types_with_values[name] + col1, col2, col3 = st.columns([0.2, 0.2, 0.6]) with col1: st.write(name) with col2: - st.write(panel.get_value(name)) + if isinstance(default_value, bool): + st.checkbox(label=name, value=default_value, key=name) + elif isinstance(default_value, Flag): + flag_checkboxes(panel, label=name, value=default_value, key=name) + elif isinstance(default_value, Enum) and not isinstance(default_value, Flag): + enum_selectbox(panel, label=name, value=default_value, key=name) + elif isinstance(default_value, int) and not isinstance(default_value, Flag): + st.number_input(label=name, value=default_value, key=name) + elif isinstance(default_value, float): + st.number_input(label=name, value=default_value, key=name, format="%.2f") + elif isinstance(default_value, str): + st.text_input(label=name, value=default_value, key=name) + + with col3: + st.write(panel.get_value(name, default_value=default_value)) diff --git a/examples/all_types/define_types.py b/examples/all_types/define_types.py index 5d5906e..33c8c47 100644 --- a/examples/all_types/define_types.py +++ b/examples/all_types/define_types.py @@ -15,6 +15,14 @@ class MyIntFlags(enum.IntFlag): VALUE4 = 4 +class MyIntableFlags(enum.Flag): + """Example of an Flag enum with int values.""" + + VALUE8 = 8 + VALUE16 = 16 + VALUE32 = 32 + + class MyIntEnum(enum.IntEnum): """Example of an IntEnum enum.""" @@ -23,6 +31,14 @@ class MyIntEnum(enum.IntEnum): VALUE30 = 30 +class MyIntableEnum(enum.Enum): + """Example of an enum with int values.""" + + VALUE100 = 100 + VALUE200 = 200 + VALUE300 = 300 + + class MyStrEnum(str, enum.Enum): """Example of a mixin string enum.""" @@ -31,6 +47,22 @@ class MyStrEnum(str, enum.Enum): VALUE3 = "value3" +class MyStringableEnum(enum.Enum): + """Example of an enum with string values.""" + + VALUE1 = "value1" + VALUE2 = "value2" + VALUE3 = "value3" + + +class MyMixedEnum(enum.Enum): + """Example of an enum with mixed values.""" + + VALUE1 = "value1" + VALUE2 = 2 + VALUE3 = 3.0 + + all_types_with_values = { # supported scalar types "bool": True, @@ -38,16 +70,23 @@ class MyStrEnum(str, enum.Enum): "float": 13.12, "int": 42, "str": "sample string", + # supported enum and flag types + "intflags": MyIntFlags.VALUE1 | MyIntFlags.VALUE4, + "intenum": MyIntEnum.VALUE20, + "strenum": MyStrEnum.VALUE3, + "intableenum": MyIntableEnum.VALUE200, + "intableflags": MyIntableFlags.VALUE8 | MyIntableFlags.VALUE32, + "stringableenum": MyStringableEnum.VALUE2, + "mixedenum": MyMixedEnum.VALUE2, + # NI types + "nitypes_Scalar": Scalar(42, "m"), + "nitypes_AnalogWaveform": AnalogWaveform.from_array_1d(np.array([1.0, 2.0, 3.0])), # supported collection types "bool_collection": [True, False, True], "bytes_collection": [b"one", b"two", b"three"], "float_collection": [1.1, 2.2, 3.3], "int_collection": [1, 2, 3], "str_collection": ["one", "two", "three"], - # supported enum and flag types - "intflags": MyIntFlags.VALUE1 | MyIntFlags.VALUE4, - "intenum": MyIntEnum.VALUE20, - "strenum": MyStrEnum.VALUE3, "intflags_collection": [MyIntFlags.VALUE1, MyIntFlags.VALUE2, MyIntFlags.VALUE4], "intenum_collection": [MyIntEnum.VALUE10, MyIntEnum.VALUE20, MyIntEnum.VALUE30], "strenum_collection": [MyStrEnum.VALUE1, MyStrEnum.VALUE2, MyStrEnum.VALUE3], @@ -56,9 +95,6 @@ class MyStrEnum(str, enum.Enum): "tuple": (4, 5, 6), "set": {7, 8, 9}, "frozenset": frozenset([10, 11, 12]), - # NI types - "nitypes_Scalar": Scalar(42, "m"), - "nitypes_AnalogWaveform": AnalogWaveform.from_array_1d(np.array([1.0, 2.0, 3.0])), # supported 2D collections "list_list_float": [[1.0, 2.0], [3.0, 4.0]], "tuple_tuple_float": ((1.0, 2.0), (3.0, 4.0)), diff --git a/examples/nidaqmx/nidaqmx_continuous_analog_input.py b/examples/nidaqmx/nidaqmx_continuous_analog_input.py index a00b768..bed43eb 100644 --- a/examples/nidaqmx/nidaqmx_continuous_analog_input.py +++ b/examples/nidaqmx/nidaqmx_continuous_analog_input.py @@ -1,34 +1,83 @@ """Data acquisition script that continuously acquires analog input data.""" +import time from pathlib import Path import nidaqmx -from nidaqmx.constants import AcquisitionType +from nidaqmx.constants import ( + AcquisitionType, + TerminalConfiguration, + CJCSource, + TemperatureUnits, + ThermocoupleType, + LoggingMode, + LoggingOperation, +) import nipanel panel_script_path = Path(__file__).with_name("nidaqmx_continuous_analog_input_panel.py") panel = nipanel.create_panel(panel_script_path) -# How to use nidaqmx: https://nidaqmx-python.readthedocs.io/en/stable/ -with nidaqmx.Task() as task: - task.ai_channels.add_ai_voltage_chan("Dev1/ai0") - task.ai_channels.add_ai_thrmcpl_chan("Dev1/ai1") - task.timing.cfg_samp_clk_timing( - rate=1000.0, sample_mode=AcquisitionType.CONTINUOUS, samps_per_chan=3000 - ) - panel.set_value("sample_rate", task._timing.samp_clk_rate) - task.start() +try: print(f"Panel URL: {panel.panel_url}") - try: - print(f"Press Ctrl + C to stop") - while True: - data = task.read( - number_of_samples_per_channel=1000 # pyright: ignore[reportArgumentType] + print(f"Waiting for the 'Run' button to be pressed...") + print(f"(Press Ctrl + C to quit)") + while True: + panel.set_value("run_button", False) + while not panel.get_value("run_button", False): + time.sleep(0.1) + + # How to use nidaqmx: https://nidaqmx-python.readthedocs.io/en/stable/ + with nidaqmx.Task() as task: + task.ai_channels.add_ai_voltage_chan( + physical_channel="Dev1/ai0", + min_val=panel.get_value("voltage_min_value", -5.0), + max_val=panel.get_value("voltage_max_value", 5.0), + terminal_config=panel.get_value( + "terminal_configuration", TerminalConfiguration.DEFAULT + ), + ) + task.ai_channels.add_ai_thrmcpl_chan( + "Dev1/ai1", + min_val=panel.get_value("thermocouple_min_value", 0.0), + max_val=panel.get_value("thermocouple_max_value", 100.0), + units=panel.get_value("thermocouple_units", TemperatureUnits.DEG_C), + thermocouple_type=panel.get_value("thermocouple_type", ThermocoupleType.K), + cjc_source=panel.get_value( + "thermocouple_cjc_source", CJCSource.CONSTANT_USER_VALUE + ), + cjc_val=panel.get_value("thermocouple_cjc_val", 25.0), + ) + task.timing.cfg_samp_clk_timing( + rate=panel.get_value("sample_rate_input", 1000.0), + sample_mode=AcquisitionType.CONTINUOUS, + samps_per_chan=panel.get_value("samples_per_channel", 3000), ) - panel.set_value("voltage_data", data[0]) - panel.set_value("thermocouple_data", data[1]) - except KeyboardInterrupt: - pass - finally: - task.stop() + task.in_stream.configure_logging( + file_path=panel.get_value("tdms_file_path", "data.tdms"), + logging_mode=panel.get_value("logging_mode", LoggingMode.OFF), + operation=LoggingOperation.OPEN_OR_CREATE, + ) + panel.set_value("sample_rate", task._timing.samp_clk_rate) + try: + print(f"Starting data acquisition...") + task.start() + panel.set_value("is_running", True) + + panel.set_value("stop_button", False) + while not panel.get_value("stop_button", False): + data = task.read( + number_of_samples_per_channel=1000 # pyright: ignore[reportArgumentType] + ) + panel.set_value("voltage_data", data[0]) + panel.set_value("thermocouple_data", data[1]) + except KeyboardInterrupt: + raise + finally: + print(f"Stopping data acquisition...") + task.stop() + panel.set_value("is_running", False) + +except KeyboardInterrupt: + pass diff --git a/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py b/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py index e9163b0..59e108e 100644 --- a/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py +++ b/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py @@ -1,20 +1,33 @@ """Streamlit visualization script to display data acquired by nidaqmx_continuous_analog_input.py.""" import streamlit as st +from nidaqmx.constants import ( + TerminalConfiguration, + CJCSource, + TemperatureUnits, + ThermocoupleType, + LoggingMode, +) from streamlit_echarts import st_echarts import nipanel +from nipanel.controls import enum_selectbox st.set_page_config(page_title="NI-DAQmx Example", page_icon="📈", layout="wide") st.title("Analog Input - Voltage and Thermocouple in a Single Task") -voltage_tab, thermocouple_tab = st.tabs(["Voltage", "Thermocouple"]) st.markdown( """ """, @@ -22,82 +35,203 @@ ) panel = nipanel.get_panel_accessor() +is_running = panel.get_value("is_running", False) + +if is_running: + st.button(r"âšī¸ Stop", key="stop_button") +else: + st.button(r"â–ļī¸ Run", key="run_button") + thermocouple_data = panel.get_value("thermocouple_data", [0.0]) voltage_data = panel.get_value("voltage_data", [0.0]) - sample_rate = panel.get_value("sample_rate", 0.0) -st.header("Voltage & Thermocouple") -voltage_therm_graph = { - "animation": False, - "tooltip": {"trigger": "axis"}, - "legend": {"data": ["Voltage (V)", "Temperature (C)"]}, - "xAxis": { - "type": "category", - "data": [x / sample_rate for x in range(len(voltage_data))], - "name": "Time", - "nameLocation": "center", - "nameGap": 40, - }, - "yAxis": { - "type": "value", - "name": "Measurement", - "nameRotate": 90, - "nameLocation": "center", - "nameGap": 40, - }, - "series": [ - { - "name": "voltage_amplitude", - "type": "line", - "data": voltage_data, - "emphasis": {"focus": "series"}, - "smooth": True, - "seriesLayoutBy": "row", - }, - { - "name": "thermocouple_amplitude", - "type": "line", - "data": thermocouple_data, - "color": "red", - "emphasis": {"focus": "series"}, - "smooth": True, - "seriesLayoutBy": "row", - }, - ], -} -st_echarts(options=voltage_therm_graph, height="400px", key="voltage_therm_graph") +# Create two-column layout for the entire interface +left_column, right_column = st.columns([1, 1]) + +# Left column - Channel tabs and Timing Settings +with left_column: + # Channel Settings tabs + with st.container(border=True): + st.header("Channel Settings") + voltage_tab, thermocouple_tab = st.tabs(["Voltage", "Thermocouple"]) + + voltage_tab.header("Voltage") + with voltage_tab: + channel_left_column, channel_right_column = st.columns(2) + with channel_left_column: + st.selectbox(options=["Dev1/ai0"], label="Physical Channels", disabled=True) + st.number_input( + "Min Value", + value=-5.0, + step=0.1, + disabled=panel.get_value("is_running", False), + key="voltage_min_value", + ) + st.number_input( + "Max Value", + value=5.0, + step=0.1, + disabled=panel.get_value("is_running", False), + key="voltage_max_value", + ) + with channel_right_column: + enum_selectbox( + panel, + label="Terminal Configuration", + value=TerminalConfiguration.DEFAULT, + disabled=panel.get_value("is_running", False), + key="terminal_configuration", + ) -voltage_tab.header("Voltage") -with voltage_tab: - left_volt_tab, center_volt_tab, right_volt_tab = st.columns(3) - with left_volt_tab: - st.selectbox(options=["Dev1/ai0"], label="Physical Channels", disabled=True) - st.selectbox(options=["Off"], label="Logging Modes", disabled=False) - with center_volt_tab: - st.selectbox(options=["-5"], label="Min Value") - st.selectbox(options=["5"], label="Max Value") - st.selectbox(options=["1000"], label="Samples per Loops", disabled=False) - with right_volt_tab: - st.selectbox(options=["default"], label="Terminal Configurations") - st.selectbox(options=["OnboardClock"], label="Sample Clock Sources", disabled=False) + thermocouple_tab.header("Thermocouple") + with thermocouple_tab: + channel_left_column, channel_middle_column, channel_right_column = st.columns(3) + with channel_left_column: + st.selectbox(options=["Dev1/ai1"], label="Physical Channel", disabled=True) + st.number_input( + "Min Value", + value=0.0, + step=1.0, + disabled=panel.get_value("is_running", False), + key="thermocouple_min_value", + ) + st.number_input( + "Max Value", + value=100.0, + step=1.0, + disabled=panel.get_value("is_running", False), + key="thermocouple_max_value", + ) + with channel_middle_column: + enum_selectbox( + panel, + label="Units", + value=TemperatureUnits.DEG_C, + disabled=panel.get_value("is_running", False), + key="thermocouple_units", + ) + enum_selectbox( + panel, + label="Thermocouple Type", + value=ThermocoupleType.K, + disabled=panel.get_value("is_running", False), + key="thermocouple_type", + ) + with channel_right_column: + enum_selectbox( + panel, + label="CJC Source", + value=CJCSource.CONSTANT_USER_VALUE, + disabled=panel.get_value("is_running", False), + key="thermocouple_cjc_source", + ) + st.number_input( + "CJC Value", + value=25.0, + step=1.0, + disabled=panel.get_value("is_running", False), + key="thermocouple_cjc_val", + ) + # Timing Settings section in left column + with st.container(border=True): + st.header("Timing Settings") + timing_left_column, timing_right_column = st.columns(2) + with timing_left_column: + st.selectbox( + options=["OnboardClock"], + label="Sample Clock Source", + disabled=True, + ) + st.number_input( + "Sample Rate", + value=1000.0, + step=100.0, + min_value=1.0, + disabled=panel.get_value("is_running", False), + key="sample_rate_input", + ) + with timing_right_column: + st.number_input( + "Samples per Loop", + value=3000, + step=100, + min_value=10, + disabled=panel.get_value("is_running", False), + key="samples_per_channel", + ) + st.text_input( + label="Actual Sample Rate", + value=str(sample_rate) if sample_rate else "", + key="actual_sample_rate_display", + ) -thermocouple_tab.header("Thermocouple") -with thermocouple_tab: - left, middle, right = st.columns(3) - with left: - st.selectbox(options=["Dev1/ai1"], label="Physical Channel", disabled=True) - st.selectbox(options=["0"], label="Min", disabled=False) - st.selectbox(options=["100"], label="Max", disabled=False) - st.selectbox(options=["Off"], label="Logging Mode", disabled=False) +# Right column - Graph and Logging Settings +with right_column: + with st.container(border=True): + # Graph section + st.header("Voltage & Thermocouple") + voltage_therm_graph = { + "animation": False, + "tooltip": {"trigger": "axis"}, + "legend": {"data": ["Voltage (V)", "Temperature (C)"]}, + "xAxis": { + "type": "category", + "data": [ + x / sample_rate if sample_rate > 0.001 else x for x in range(len(voltage_data)) + ], + "name": "Time", + "nameLocation": "center", + "nameGap": 40, + }, + "yAxis": { + "type": "value", + "name": "Measurement", + "nameRotate": 90, + "nameLocation": "center", + "nameGap": 40, + }, + "series": [ + { + "name": "voltage_amplitude", + "type": "line", + "data": voltage_data, + "emphasis": {"focus": "series"}, + "smooth": True, + "seriesLayoutBy": "row", + }, + { + "name": "thermocouple_amplitude", + "type": "line", + "data": thermocouple_data, + "color": "red", + "emphasis": {"focus": "series"}, + "smooth": True, + "seriesLayoutBy": "row", + }, + ], + } + st_echarts(options=voltage_therm_graph, height="446px", key="voltage_therm_graph") - with middle: - st.selectbox(options=["Deg C"], label="Units", disabled=False) - st.selectbox(options=["J"], label="Thermocouple Type", disabled=False) - st.selectbox(options=["Constant Value"], label="CJC Source", disabled=False) - st.selectbox(options=["1000"], label="Samples per Loop", disabled=False) - with right: - st.selectbox(options=["25"], label="CJC Value", disabled=False) - st.selectbox(options=["OnboardClock"], label="Sample Clock Source", disabled=False) - st.selectbox(options=[" "], label="Actual Sample Rate", disabled=True) + # Logging Settings section in right column + with st.container(border=True): + st.header("Logging Settings") + logging_left_column, logging_right_column = st.columns(2) + with logging_left_column: + enum_selectbox( + panel, + label="Logging Mode", + value=LoggingMode.OFF, + disabled=panel.get_value("is_running", False), + key="logging_mode", + ) + with logging_right_column: + left_sub_column, right_sub_column = st.columns([3, 1]) + with left_sub_column: + tdms_file_path = st.text_input( + label="TDMS File Path", + disabled=panel.get_value("is_running", False), + value="data.tdms", + key="tdms_file_path", + ) diff --git a/src/nipanel/_convert.py b/src/nipanel/_convert.py index 5bbc6d0..61a4858 100644 --- a/src/nipanel/_convert.py +++ b/src/nipanel/_convert.py @@ -129,3 +129,12 @@ def from_any(protobuf_any: any_pb2.Any) -> object: converter = _CONVERTER_FOR_GRPC_TYPE[underlying_typename] return converter.to_python(protobuf_any) + + +def is_supported_type(value: object) -> bool: + """Check if a given Python value can be converted to protobuf Any.""" + try: + _get_best_matching_type(value) + return True + except TypeError: + return False diff --git a/src/nipanel/_panel_value_accessor.py b/src/nipanel/_panel_value_accessor.py index d60f7d7..cb53d56 100644 --- a/src/nipanel/_panel_value_accessor.py +++ b/src/nipanel/_panel_value_accessor.py @@ -1,5 +1,7 @@ from __future__ import annotations +import collections +import enum from abc import ABC from typing import TypeVar, overload @@ -15,7 +17,13 @@ class PanelValueAccessor(ABC): """This class allows you to access values for a panel's controls.""" - __slots__ = ["_panel_client", "_panel_id", "_notify_on_set_value", "__weakref__"] + __slots__ = [ + "_panel_client", + "_panel_id", + "_notify_on_set_value", + "_last_values", + "__weakref__", + ] def __init__( self, @@ -38,6 +46,9 @@ def __init__( ) self._panel_id = panel_id self._notify_on_set_value = notify_on_set_value + self._last_values: collections.defaultdict[str, object] = collections.defaultdict( + lambda: object() + ) @property def panel_id(self) -> str: @@ -67,7 +78,16 @@ def get_value(self, value_id: str, default_value: _T | None = None) -> _T | obje raise KeyError(f"Value with id '{value_id}' not found on panel '{self._panel_id}'.") if default_value is not None and not isinstance(value, type(default_value)): - raise TypeError("Value type does not match default value type.") + if isinstance(default_value, enum.Enum): + enum_type = type(default_value) + return enum_type(value) + + # lists are allowed to not match, since sets and tuples are converted to lists + if not isinstance(value, list): + raise TypeError( + f"Value type {type(value).__name__} does not match default value type {type(default_value).__name__}." + ) + return value def set_value(self, value_id: str, value: object) -> None: @@ -77,6 +97,22 @@ def set_value(self, value_id: str, value: object) -> None: value_id: The id of the value value: The value """ + if isinstance(value, enum.Enum): + value = value.value + self._panel_client.set_value( self._panel_id, value_id, value, notify=self._notify_on_set_value ) + self._last_values[value_id] = value + + def set_value_if_changed(self, value_id: str, value: object) -> None: + """Set the value for a control on the panel only if it has changed since the last call. + + This method helps reduce unnecessary updates when the value hasn't changed. + + Args: + value_id: The id of the value + value: The value to set + """ + if value != self._last_values[value_id]: + self.set_value(value_id, value) diff --git a/src/nipanel/_streamlit_panel_initializer.py b/src/nipanel/_streamlit_panel_initializer.py index 7dee3fd..a05822d 100644 --- a/src/nipanel/_streamlit_panel_initializer.py +++ b/src/nipanel/_streamlit_panel_initializer.py @@ -3,6 +3,7 @@ import streamlit as st +from nipanel._convert import is_supported_type from nipanel._streamlit_panel import StreamlitPanel from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor from nipanel.streamlit_refresh import initialize_refresh_component @@ -61,6 +62,7 @@ def get_panel_accessor() -> StreamlitPanelValueAccessor: st.session_state[PANEL_ACCESSOR_KEY] = _initialize_panel_from_base_path() panel = cast(StreamlitPanelValueAccessor, st.session_state[PANEL_ACCESSOR_KEY]) + _sync_session_state(panel) refresh_component = initialize_refresh_component(panel.panel_id) refresh_component() return panel @@ -75,3 +77,11 @@ def _initialize_panel_from_base_path() -> StreamlitPanelValueAccessor: if not panel_id: raise ValueError(f"Panel ID is empty in baseUrlPath: '{base_url_path}'") return StreamlitPanelValueAccessor(panel_id) + + +def _sync_session_state(panel: StreamlitPanelValueAccessor) -> None: + """Automatically read keyed control values from the session state.""" + for key in st.session_state.keys(): + value = st.session_state[key] + if is_supported_type(value): + panel.set_value_if_changed(str(key), value) diff --git a/src/nipanel/controls/__init__.py b/src/nipanel/controls/__init__.py new file mode 100644 index 0000000..6882af0 --- /dev/null +++ b/src/nipanel/controls/__init__.py @@ -0,0 +1,6 @@ +"""Controls for nipanel.""" + +from nipanel.controls._enum_selectbox import enum_selectbox +from nipanel.controls._flag_checkboxes import flag_checkboxes + +__all__ = ["enum_selectbox", "flag_checkboxes"] diff --git a/src/nipanel/controls/_enum_selectbox.py b/src/nipanel/controls/_enum_selectbox.py new file mode 100644 index 0000000..a75277c --- /dev/null +++ b/src/nipanel/controls/_enum_selectbox.py @@ -0,0 +1,53 @@ +"""A selectbox that allows selecting an Enum value.""" + +from enum import Enum +from typing import Any, Callable, TypeVar + +import streamlit as st + +from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor + +TEnumType = TypeVar("TEnumType", bound=Enum) + + +def enum_selectbox( + panel: StreamlitPanelValueAccessor, + label: str, + value: TEnumType, + key: str, + disabled: bool = False, + format_func: Callable[[Any], str] = lambda x: x[0], +) -> TEnumType: + """Create a selectbox for an Enum. + + The selectbox will display the names of all the enum values, and when a value is selected, + that value will be stored in the panel under the specified key. + + Args: + panel: The panel + label: Label to display for the selectbox + value: The default enum value to select (also determines the specific enum type) + key: Key to use for storing the enum value in the panel + + Returns: + The selected enum value of the same specific enum subclass as the input value + """ + enum_class = type(value) + if not issubclass(enum_class, Enum): + raise TypeError(f"Expected an Enum type, got {type(value)}") + + options = [(e.name, e.value) for e in enum_class] + + default_index = 0 + if value is not None: + for i, (name, _) in enumerate(options): + if name == value.name: + default_index = i + break + + box_tuple = st.selectbox( + label, options=options, format_func=format_func, index=default_index, disabled=disabled + ) + enum_value = enum_class[box_tuple[0]] + panel.set_value_if_changed(key, enum_value) + return enum_value diff --git a/src/nipanel/controls/_flag_checkboxes.py b/src/nipanel/controls/_flag_checkboxes.py new file mode 100644 index 0000000..7887a53 --- /dev/null +++ b/src/nipanel/controls/_flag_checkboxes.py @@ -0,0 +1,77 @@ +"""A set of checkboxes for selecting Flag enum values.""" + +from enum import Flag +from typing import TypeVar, Callable, Optional + +import streamlit as st + +from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor + +TFlagType = TypeVar("TFlagType", bound=Flag) + + +def flag_checkboxes( + panel: StreamlitPanelValueAccessor, + label: str, + value: TFlagType, + key: str, + disabled_values: Optional[TFlagType] = None, + label_formatter: Callable[[Flag], str] = lambda x: str(x.name), +) -> TFlagType: + """Create a set of checkboxes for a Flag enum. + + This will display a checkbox for each individual flag value in the enum. When checkboxes + are selected or deselected, the combined Flag value will be stored in the panel under + the specified key. + + Args: + panel: The panel + label: Label to display above the checkboxes + value: The default Flag enum value (also determines the specific Flag enum type) + key: Key to use for storing the Flag value in the panel + disabled_values: A Flag enum value indicating which flags should be disabled. + If None or flag_type(0), no checkboxes are disabled. + label_formatter: Function that formats the flag to a string for display. Default + uses flag.name. + + Returns: + The selected Flag enum value with all selected flags combined + """ + flag_type = type(value) + if not issubclass(flag_type, Flag): + raise TypeError(f"Expected a Flag enum type, got {type(value)}") + + st.markdown(f"{label}:", unsafe_allow_html=True) + + # Get all individual flag values (skip composite values and zero value) + flag_values = [ + flag for flag in flag_type if flag.value & (flag.value - 1) == 0 and flag.value != 0 + ] + + # Create a container for flag checkboxes + flag_container = st.container(border=True) + + # Use the provided value as the initial state for selected flags + selected_flags = value + + # Create a checkbox for each flag + for flag in flag_values: + is_selected = bool(selected_flags & flag) + + is_disabled = False + if disabled_values is not None: + is_disabled = bool(disabled_values & flag) + + if flag_container.checkbox( + label=label_formatter(flag), + value=is_selected, + key=f"{key}_{flag.name}", + disabled=is_disabled, + ): + selected_flags |= flag + else: + selected_flags &= ~flag + + # Store the selected flags in the panel + panel.set_value_if_changed(key, selected_flags) + return selected_flags diff --git a/tests/types.py b/tests/types.py index 0d17f9b..fa5cb0e 100644 --- a/tests/types.py +++ b/tests/types.py @@ -21,6 +21,14 @@ class MyIntFlags(enum.IntFlag): VALUE4 = 4 +class MyIntableFlags(enum.Flag): + """Example of a simple flag with int values.""" + + VALUE8 = 8 + VALUE16 = 16 + VALUE32 = 32 + + class MyIntEnum(enum.IntEnum): """Example of an IntEnum enum.""" @@ -53,17 +61,25 @@ class MixinStrEnum(str, enum.Enum): VALUE33 = "value33" -class MyEnum(enum.Enum): - """Example of a simple enum.""" +class MyIntableEnum(enum.Enum): + """Example of a simple enum with int values.""" VALUE100 = 100 VALUE200 = 200 VALUE300 = 300 -class MyFlags(enum.Flag): - """Example of a simple flag.""" +class MyStringableEnum(StrEnum): + """Example of a simple enum with str values.""" - VALUE8 = 8 - VALUE16 = 16 - VALUE32 = 32 + VALUE1 = "value10" + VALUE2 = "value20" + VALUE3 = "value30" + + +class MyMixedEnum(enum.Enum): + """Example of an enum with mixed values.""" + + VALUE1 = "value1" + VALUE2 = 2 + VALUE3 = 3.0 diff --git a/tests/unit/test_streamlit_panel.py b/tests/unit/test_streamlit_panel.py index c1c3382..9b67bf6 100644 --- a/tests/unit/test_streamlit_panel.py +++ b/tests/unit/test_streamlit_panel.py @@ -1,3 +1,6 @@ +import enum +from datetime import datetime + import grpc import pytest from typing_extensions import assert_type @@ -242,6 +245,17 @@ def test___set_int_type___get_value_with_bool_default___raises_exception( panel.get_value(value_id, False) +def test___set_string_enum_type___get_value_with_int_enum_default___raises_exception( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + panel.set_value(value_id, test_types.MyStrEnum.VALUE3) + + with pytest.raises(ValueError): + panel.get_value(value_id, test_types.MyIntEnum.VALUE10) + + @pytest.mark.parametrize( "value_payload", [ @@ -250,31 +264,63 @@ def test___set_int_type___get_value_with_bool_default___raises_exception( 3.14, True, b"robotext", + ], +) +def test___builtin_scalar_type___set_value___gets_same_value( + fake_panel_channel: grpc.Channel, + value_payload: object, +) -> None: + """Test that set_value() and get_value() work for builtin scalar types.""" + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + + value_id = "test_id" + panel.set_value(value_id, value_payload) + + assert panel.get_value(value_id) == value_payload + + +@pytest.mark.parametrize( + "value_payload", + [ test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4, + test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32, test_types.MyIntEnum.VALUE20, + test_types.MyIntableEnum.VALUE200, test_types.MyStrEnum.VALUE3, + test_types.MyStringableEnum.VALUE2, test_types.MixinIntEnum.VALUE33, test_types.MixinStrEnum.VALUE11, + test_types.MyMixedEnum.VALUE2, ], ) -def test___builtin_scalar_type___set_value___gets_same_value( +def test___enum_type___set_value___gets_same_value( fake_panel_channel: grpc.Channel, - value_payload: object, + value_payload: enum.Enum, ) -> None: - """Test that set_value() and get_value() work for builtin scalar types.""" + """Test that set_value() and get_value() work for enum types.""" panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) value_id = "test_id" panel.set_value(value_id, value_payload) - assert panel.get_value(value_id) == value_payload + # without providing a default value, get_value will return the raw value, not the enum + assert panel.get_value(value_id) == value_payload.value @pytest.mark.parametrize( "value_payload", [ - test_types.MyEnum.VALUE300, - test_types.MyFlags.VALUE8 | test_types.MyFlags.VALUE16, + datetime.now(), + lambda x: x + 1, + [1, "string"], + ["string", []], + (42, "hello", 3.14, b"bytes"), + set([1, "mixed", True]), + (i for i in range(5)), + { + "key1": [1, 2, 3], + "key2": {"nested": True, "values": [4.5, 6.7]}, + }, ], ) def test___unsupported_type___set_value___raises( @@ -341,6 +387,117 @@ def test___sequence_of_builtin_type___set_value___gets_same_value( assert list(received_value) == list(value_payload) # type: ignore [call-overload] +def test___set_int_enum_value___get_value___returns_int_enum( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + enum_value = test_types.MyIntEnum.VALUE20 + panel.set_value(value_id, enum_value) + + retrieved_value = panel.get_value(value_id, test_types.MyIntEnum.VALUE10) + + assert_type(retrieved_value, test_types.MyIntEnum) + assert retrieved_value is test_types.MyIntEnum.VALUE20 + assert retrieved_value.value == enum_value.value + assert retrieved_value.name == enum_value.name + + +def test___set_intable_enum_value___get_value___returns_intable_enum( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + enum_value = test_types.MyIntableEnum.VALUE200 + panel.set_value(value_id, enum_value) + + retrieved_value = panel.get_value(value_id, test_types.MyIntableEnum.VALUE100) + + assert_type(retrieved_value, test_types.MyIntableEnum) + assert retrieved_value is test_types.MyIntableEnum.VALUE200 + assert retrieved_value.value == enum_value.value + assert retrieved_value.name == enum_value.name + + +def test___set_string_enum_value___get_value___returns_string_enum( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + enum_value = test_types.MyStrEnum.VALUE3 + panel.set_value(value_id, enum_value) + + retrieved_value = panel.get_value(value_id, test_types.MyStrEnum.VALUE1) + + assert_type(retrieved_value, test_types.MyStrEnum) + assert retrieved_value is test_types.MyStrEnum.VALUE3 + assert retrieved_value.value == enum_value.value + assert retrieved_value.name == enum_value.name + + +def test___set_stringable_enum_value___get_value___returns_stringable_enum( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + enum_value = test_types.MyStringableEnum.VALUE3 + panel.set_value(value_id, enum_value) + + retrieved_value = panel.get_value(value_id, test_types.MyStringableEnum.VALUE1) + + assert_type(retrieved_value, test_types.MyStringableEnum) + assert retrieved_value is test_types.MyStringableEnum.VALUE3 + assert retrieved_value.value == enum_value.value + assert retrieved_value.name == enum_value.name + + +def test___set_mixed_enum_value___get_value___returns_mixed_enum( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + enum_value = test_types.MyMixedEnum.VALUE2 + panel.set_value(value_id, enum_value) + + retrieved_value = panel.get_value(value_id, test_types.MyMixedEnum.VALUE1) + + assert_type(retrieved_value, test_types.MyMixedEnum) + assert retrieved_value is test_types.MyMixedEnum.VALUE2 + assert retrieved_value.value == enum_value.value + assert retrieved_value.name == enum_value.name + + +def test___set_int_flags_value___get_value___returns_int_flags( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + flags_value = test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4 + panel.set_value(value_id, flags_value) + + retrieved_value = panel.get_value(value_id, test_types.MyIntFlags.VALUE2) + + assert_type(retrieved_value, test_types.MyIntFlags) + assert retrieved_value == (test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4) + assert retrieved_value.value == flags_value.value + + +def test___set_intable_flags_value___get_value___returns_intable_flags( + fake_panel_channel: grpc.Channel, +) -> None: + panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel) + value_id = "test_id" + flags_value = test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32 + panel.set_value(value_id, flags_value) + + retrieved_value = panel.get_value(value_id, test_types.MyIntableFlags.VALUE8) + + assert_type(retrieved_value, test_types.MyIntableFlags) + assert retrieved_value is test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32 + assert retrieved_value.value == flags_value.value + assert retrieved_value.name == flags_value.name + + def test___panel___panel_is_running_and_in_memory( fake_panel_channel: grpc.Channel, ) -> None: diff --git a/tests/unit/test_streamlit_panel_value_accessor.py b/tests/unit/test_streamlit_panel_value_accessor.py new file mode 100644 index 0000000..4a845b6 --- /dev/null +++ b/tests/unit/test_streamlit_panel_value_accessor.py @@ -0,0 +1,110 @@ +import grpc + +from nipanel import StreamlitPanelValueAccessor +from tests.types import MyIntEnum +from tests.utils._fake_python_panel_service import FakePythonPanelService + + +def test___no_previous_value___set_value_if_changed___sets_value( + fake_panel_channel: grpc.Channel, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + + accessor.set_value_if_changed("test_id", "test_value") + + assert accessor.get_value("test_id") == "test_value" + + +def test___set_value_if_changed___set_same_value___does_not_set_value_again( + fake_panel_channel: grpc.Channel, + fake_python_panel_service: FakePythonPanelService, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", "test_value") + initial_set_count = fake_python_panel_service.servicer.set_count + + accessor.set_value_if_changed("test_id", "test_value") + + assert fake_python_panel_service.servicer.set_count == initial_set_count + assert accessor.get_value("test_id") == "test_value" + + +def test___set_value_if_changed___set_different_value___sets_new_value( + fake_panel_channel: grpc.Channel, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", "test_value") + + accessor.set_value_if_changed("test_id", "new_value") + + assert accessor.get_value("test_id") == "new_value" + + +def test___set_value_if_changed___different_value_ids___tracks_separately( + fake_panel_channel: grpc.Channel, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("id1", "value1") + accessor.set_value_if_changed("id2", "value2") + + accessor.set_value_if_changed("id1", "value1") + accessor.set_value_if_changed("id2", "new_value2") + + assert accessor.get_value("id1") == "value1" + assert accessor.get_value("id2") == "new_value2" + + +def test___set_value_if_changed_with_list_value___set_same_value___does_not_set_value_again( + fake_panel_channel: grpc.Channel, + fake_python_panel_service: FakePythonPanelService, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", [1, 2, 3]) + initial_set_count = fake_python_panel_service.servicer.set_count + + accessor.set_value_if_changed("test_id", [1, 2, 3]) + + assert fake_python_panel_service.servicer.set_count == initial_set_count + assert accessor.get_value("test_id") == [1, 2, 3] + + +def test___set_value_if_changed_with_list_value___set_different_value___sets_new_value( + fake_panel_channel: grpc.Channel, + fake_python_panel_service: FakePythonPanelService, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", [1, 2, 3]) + initial_set_count = fake_python_panel_service.servicer.set_count + + accessor.set_value_if_changed("test_id", [1, 2, 4]) + + assert fake_python_panel_service.servicer.set_count > initial_set_count + assert accessor.get_value("test_id") == [1, 2, 4] + + +def test___set_value_if_changed_with_enum_value___set_same_value___does_not_set_value_again( + fake_panel_channel: grpc.Channel, + fake_python_panel_service: FakePythonPanelService, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20) + initial_set_count = fake_python_panel_service.servicer.set_count + + accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20) + + assert fake_python_panel_service.servicer.set_count == initial_set_count + assert accessor.get_value("test_id") == 20 # Enums are stored as their values + + +def test___set_value_if_changed_with_enum_value___set_different_value___sets_new_value( + fake_panel_channel: grpc.Channel, + fake_python_panel_service: FakePythonPanelService, +) -> None: + accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel) + accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20) + initial_set_count = fake_python_panel_service.servicer.set_count + + accessor.set_value_if_changed("test_id", MyIntEnum.VALUE30) + + assert fake_python_panel_service.servicer.set_count > initial_set_count + assert accessor.get_value("test_id") == 30 # New enum value should be set diff --git a/tests/utils/_fake_python_panel_servicer.py b/tests/utils/_fake_python_panel_servicer.py index e6d59ee..ac74fe6 100644 --- a/tests/utils/_fake_python_panel_servicer.py +++ b/tests/utils/_fake_python_panel_servicer.py @@ -26,6 +26,7 @@ def __init__(self) -> None: self._panel_is_running: dict[str, bool] = {} self._panel_value_ids: dict[str, dict[str, Any]] = {} self._fail_next_start_panel = False + self._set_count: int = 0 self._notification_count: int = 0 self._python_path: str = "" @@ -70,6 +71,7 @@ def SetValue(self, request: SetValueRequest, context: Any) -> SetValueResponse: """Trivial implementation for testing.""" self._init_panel(request.panel_id) self._panel_value_ids[request.panel_id][request.value_id] = request.value + self._set_count += 1 if request.notify: self._notification_count += 1 return SetValueResponse() @@ -78,6 +80,11 @@ def fail_next_start_panel(self) -> None: """Set whether the StartPanel method should fail the next time it is called.""" self._fail_next_start_panel = True + @property + def set_count(self) -> int: + """Get the total number of times SetValue was called.""" + return self._set_count + @property def notification_count(self) -> int: """Get the number of notifications sent from SetValue."""