diff --git a/examples/all_types/all_types_panel.py b/examples/all_types/all_types_panel.py
index be5b0a0..b2e046b 100644
--- a/examples/all_types/all_types_panel.py
+++ b/examples/all_types/all_types_panel.py
@@ -1,9 +1,12 @@
"""A Streamlit visualization panel for the all_types.py example script."""
+from enum import Enum, Flag
+
import streamlit as st
from define_types import all_types_with_values
import nipanel
+from nipanel.controls import enum_selectbox, flag_checkboxes
st.set_page_config(page_title="All Types Example", page_icon="đ", layout="wide")
@@ -11,10 +14,27 @@
panel = nipanel.get_panel_accessor()
for name in all_types_with_values.keys():
- col1, col2 = st.columns([0.4, 0.6])
+ st.markdown("---")
+
+ default_value = all_types_with_values[name]
+ col1, col2, col3 = st.columns([0.2, 0.2, 0.6])
with col1:
st.write(name)
with col2:
- st.write(panel.get_value(name))
+ if isinstance(default_value, bool):
+ st.checkbox(label=name, value=default_value, key=name)
+ elif isinstance(default_value, Flag):
+ flag_checkboxes(panel, label=name, value=default_value, key=name)
+ elif isinstance(default_value, Enum) and not isinstance(default_value, Flag):
+ enum_selectbox(panel, label=name, value=default_value, key=name)
+ elif isinstance(default_value, int) and not isinstance(default_value, Flag):
+ st.number_input(label=name, value=default_value, key=name)
+ elif isinstance(default_value, float):
+ st.number_input(label=name, value=default_value, key=name, format="%.2f")
+ elif isinstance(default_value, str):
+ st.text_input(label=name, value=default_value, key=name)
+
+ with col3:
+ st.write(panel.get_value(name, default_value=default_value))
diff --git a/examples/all_types/define_types.py b/examples/all_types/define_types.py
index 5d5906e..33c8c47 100644
--- a/examples/all_types/define_types.py
+++ b/examples/all_types/define_types.py
@@ -15,6 +15,14 @@ class MyIntFlags(enum.IntFlag):
VALUE4 = 4
+class MyIntableFlags(enum.Flag):
+ """Example of an Flag enum with int values."""
+
+ VALUE8 = 8
+ VALUE16 = 16
+ VALUE32 = 32
+
+
class MyIntEnum(enum.IntEnum):
"""Example of an IntEnum enum."""
@@ -23,6 +31,14 @@ class MyIntEnum(enum.IntEnum):
VALUE30 = 30
+class MyIntableEnum(enum.Enum):
+ """Example of an enum with int values."""
+
+ VALUE100 = 100
+ VALUE200 = 200
+ VALUE300 = 300
+
+
class MyStrEnum(str, enum.Enum):
"""Example of a mixin string enum."""
@@ -31,6 +47,22 @@ class MyStrEnum(str, enum.Enum):
VALUE3 = "value3"
+class MyStringableEnum(enum.Enum):
+ """Example of an enum with string values."""
+
+ VALUE1 = "value1"
+ VALUE2 = "value2"
+ VALUE3 = "value3"
+
+
+class MyMixedEnum(enum.Enum):
+ """Example of an enum with mixed values."""
+
+ VALUE1 = "value1"
+ VALUE2 = 2
+ VALUE3 = 3.0
+
+
all_types_with_values = {
# supported scalar types
"bool": True,
@@ -38,16 +70,23 @@ class MyStrEnum(str, enum.Enum):
"float": 13.12,
"int": 42,
"str": "sample string",
+ # supported enum and flag types
+ "intflags": MyIntFlags.VALUE1 | MyIntFlags.VALUE4,
+ "intenum": MyIntEnum.VALUE20,
+ "strenum": MyStrEnum.VALUE3,
+ "intableenum": MyIntableEnum.VALUE200,
+ "intableflags": MyIntableFlags.VALUE8 | MyIntableFlags.VALUE32,
+ "stringableenum": MyStringableEnum.VALUE2,
+ "mixedenum": MyMixedEnum.VALUE2,
+ # NI types
+ "nitypes_Scalar": Scalar(42, "m"),
+ "nitypes_AnalogWaveform": AnalogWaveform.from_array_1d(np.array([1.0, 2.0, 3.0])),
# supported collection types
"bool_collection": [True, False, True],
"bytes_collection": [b"one", b"two", b"three"],
"float_collection": [1.1, 2.2, 3.3],
"int_collection": [1, 2, 3],
"str_collection": ["one", "two", "three"],
- # supported enum and flag types
- "intflags": MyIntFlags.VALUE1 | MyIntFlags.VALUE4,
- "intenum": MyIntEnum.VALUE20,
- "strenum": MyStrEnum.VALUE3,
"intflags_collection": [MyIntFlags.VALUE1, MyIntFlags.VALUE2, MyIntFlags.VALUE4],
"intenum_collection": [MyIntEnum.VALUE10, MyIntEnum.VALUE20, MyIntEnum.VALUE30],
"strenum_collection": [MyStrEnum.VALUE1, MyStrEnum.VALUE2, MyStrEnum.VALUE3],
@@ -56,9 +95,6 @@ class MyStrEnum(str, enum.Enum):
"tuple": (4, 5, 6),
"set": {7, 8, 9},
"frozenset": frozenset([10, 11, 12]),
- # NI types
- "nitypes_Scalar": Scalar(42, "m"),
- "nitypes_AnalogWaveform": AnalogWaveform.from_array_1d(np.array([1.0, 2.0, 3.0])),
# supported 2D collections
"list_list_float": [[1.0, 2.0], [3.0, 4.0]],
"tuple_tuple_float": ((1.0, 2.0), (3.0, 4.0)),
diff --git a/examples/nidaqmx/nidaqmx_continuous_analog_input.py b/examples/nidaqmx/nidaqmx_continuous_analog_input.py
index a00b768..bed43eb 100644
--- a/examples/nidaqmx/nidaqmx_continuous_analog_input.py
+++ b/examples/nidaqmx/nidaqmx_continuous_analog_input.py
@@ -1,34 +1,83 @@
"""Data acquisition script that continuously acquires analog input data."""
+import time
from pathlib import Path
import nidaqmx
-from nidaqmx.constants import AcquisitionType
+from nidaqmx.constants import (
+ AcquisitionType,
+ TerminalConfiguration,
+ CJCSource,
+ TemperatureUnits,
+ ThermocoupleType,
+ LoggingMode,
+ LoggingOperation,
+)
import nipanel
panel_script_path = Path(__file__).with_name("nidaqmx_continuous_analog_input_panel.py")
panel = nipanel.create_panel(panel_script_path)
-# How to use nidaqmx: https://nidaqmx-python.readthedocs.io/en/stable/
-with nidaqmx.Task() as task:
- task.ai_channels.add_ai_voltage_chan("Dev1/ai0")
- task.ai_channels.add_ai_thrmcpl_chan("Dev1/ai1")
- task.timing.cfg_samp_clk_timing(
- rate=1000.0, sample_mode=AcquisitionType.CONTINUOUS, samps_per_chan=3000
- )
- panel.set_value("sample_rate", task._timing.samp_clk_rate)
- task.start()
+try:
print(f"Panel URL: {panel.panel_url}")
- try:
- print(f"Press Ctrl + C to stop")
- while True:
- data = task.read(
- number_of_samples_per_channel=1000 # pyright: ignore[reportArgumentType]
+ print(f"Waiting for the 'Run' button to be pressed...")
+ print(f"(Press Ctrl + C to quit)")
+ while True:
+ panel.set_value("run_button", False)
+ while not panel.get_value("run_button", False):
+ time.sleep(0.1)
+
+ # How to use nidaqmx: https://nidaqmx-python.readthedocs.io/en/stable/
+ with nidaqmx.Task() as task:
+ task.ai_channels.add_ai_voltage_chan(
+ physical_channel="Dev1/ai0",
+ min_val=panel.get_value("voltage_min_value", -5.0),
+ max_val=panel.get_value("voltage_max_value", 5.0),
+ terminal_config=panel.get_value(
+ "terminal_configuration", TerminalConfiguration.DEFAULT
+ ),
+ )
+ task.ai_channels.add_ai_thrmcpl_chan(
+ "Dev1/ai1",
+ min_val=panel.get_value("thermocouple_min_value", 0.0),
+ max_val=panel.get_value("thermocouple_max_value", 100.0),
+ units=panel.get_value("thermocouple_units", TemperatureUnits.DEG_C),
+ thermocouple_type=panel.get_value("thermocouple_type", ThermocoupleType.K),
+ cjc_source=panel.get_value(
+ "thermocouple_cjc_source", CJCSource.CONSTANT_USER_VALUE
+ ),
+ cjc_val=panel.get_value("thermocouple_cjc_val", 25.0),
+ )
+ task.timing.cfg_samp_clk_timing(
+ rate=panel.get_value("sample_rate_input", 1000.0),
+ sample_mode=AcquisitionType.CONTINUOUS,
+ samps_per_chan=panel.get_value("samples_per_channel", 3000),
)
- panel.set_value("voltage_data", data[0])
- panel.set_value("thermocouple_data", data[1])
- except KeyboardInterrupt:
- pass
- finally:
- task.stop()
+ task.in_stream.configure_logging(
+ file_path=panel.get_value("tdms_file_path", "data.tdms"),
+ logging_mode=panel.get_value("logging_mode", LoggingMode.OFF),
+ operation=LoggingOperation.OPEN_OR_CREATE,
+ )
+ panel.set_value("sample_rate", task._timing.samp_clk_rate)
+ try:
+ print(f"Starting data acquisition...")
+ task.start()
+ panel.set_value("is_running", True)
+
+ panel.set_value("stop_button", False)
+ while not panel.get_value("stop_button", False):
+ data = task.read(
+ number_of_samples_per_channel=1000 # pyright: ignore[reportArgumentType]
+ )
+ panel.set_value("voltage_data", data[0])
+ panel.set_value("thermocouple_data", data[1])
+ except KeyboardInterrupt:
+ raise
+ finally:
+ print(f"Stopping data acquisition...")
+ task.stop()
+ panel.set_value("is_running", False)
+
+except KeyboardInterrupt:
+ pass
diff --git a/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py b/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py
index e9163b0..59e108e 100644
--- a/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py
+++ b/examples/nidaqmx/nidaqmx_continuous_analog_input_panel.py
@@ -1,20 +1,33 @@
"""Streamlit visualization script to display data acquired by nidaqmx_continuous_analog_input.py."""
import streamlit as st
+from nidaqmx.constants import (
+ TerminalConfiguration,
+ CJCSource,
+ TemperatureUnits,
+ ThermocoupleType,
+ LoggingMode,
+)
from streamlit_echarts import st_echarts
import nipanel
+from nipanel.controls import enum_selectbox
st.set_page_config(page_title="NI-DAQmx Example", page_icon="đ", layout="wide")
st.title("Analog Input - Voltage and Thermocouple in a Single Task")
-voltage_tab, thermocouple_tab = st.tabs(["Voltage", "Thermocouple"])
st.markdown(
"""
""",
@@ -22,82 +35,203 @@
)
panel = nipanel.get_panel_accessor()
+is_running = panel.get_value("is_running", False)
+
+if is_running:
+ st.button(r"âšī¸ Stop", key="stop_button")
+else:
+ st.button(r"âļī¸ Run", key="run_button")
+
thermocouple_data = panel.get_value("thermocouple_data", [0.0])
voltage_data = panel.get_value("voltage_data", [0.0])
-
sample_rate = panel.get_value("sample_rate", 0.0)
-st.header("Voltage & Thermocouple")
-voltage_therm_graph = {
- "animation": False,
- "tooltip": {"trigger": "axis"},
- "legend": {"data": ["Voltage (V)", "Temperature (C)"]},
- "xAxis": {
- "type": "category",
- "data": [x / sample_rate for x in range(len(voltage_data))],
- "name": "Time",
- "nameLocation": "center",
- "nameGap": 40,
- },
- "yAxis": {
- "type": "value",
- "name": "Measurement",
- "nameRotate": 90,
- "nameLocation": "center",
- "nameGap": 40,
- },
- "series": [
- {
- "name": "voltage_amplitude",
- "type": "line",
- "data": voltage_data,
- "emphasis": {"focus": "series"},
- "smooth": True,
- "seriesLayoutBy": "row",
- },
- {
- "name": "thermocouple_amplitude",
- "type": "line",
- "data": thermocouple_data,
- "color": "red",
- "emphasis": {"focus": "series"},
- "smooth": True,
- "seriesLayoutBy": "row",
- },
- ],
-}
-st_echarts(options=voltage_therm_graph, height="400px", key="voltage_therm_graph")
+# Create two-column layout for the entire interface
+left_column, right_column = st.columns([1, 1])
+
+# Left column - Channel tabs and Timing Settings
+with left_column:
+ # Channel Settings tabs
+ with st.container(border=True):
+ st.header("Channel Settings")
+ voltage_tab, thermocouple_tab = st.tabs(["Voltage", "Thermocouple"])
+
+ voltage_tab.header("Voltage")
+ with voltage_tab:
+ channel_left_column, channel_right_column = st.columns(2)
+ with channel_left_column:
+ st.selectbox(options=["Dev1/ai0"], label="Physical Channels", disabled=True)
+ st.number_input(
+ "Min Value",
+ value=-5.0,
+ step=0.1,
+ disabled=panel.get_value("is_running", False),
+ key="voltage_min_value",
+ )
+ st.number_input(
+ "Max Value",
+ value=5.0,
+ step=0.1,
+ disabled=panel.get_value("is_running", False),
+ key="voltage_max_value",
+ )
+ with channel_right_column:
+ enum_selectbox(
+ panel,
+ label="Terminal Configuration",
+ value=TerminalConfiguration.DEFAULT,
+ disabled=panel.get_value("is_running", False),
+ key="terminal_configuration",
+ )
-voltage_tab.header("Voltage")
-with voltage_tab:
- left_volt_tab, center_volt_tab, right_volt_tab = st.columns(3)
- with left_volt_tab:
- st.selectbox(options=["Dev1/ai0"], label="Physical Channels", disabled=True)
- st.selectbox(options=["Off"], label="Logging Modes", disabled=False)
- with center_volt_tab:
- st.selectbox(options=["-5"], label="Min Value")
- st.selectbox(options=["5"], label="Max Value")
- st.selectbox(options=["1000"], label="Samples per Loops", disabled=False)
- with right_volt_tab:
- st.selectbox(options=["default"], label="Terminal Configurations")
- st.selectbox(options=["OnboardClock"], label="Sample Clock Sources", disabled=False)
+ thermocouple_tab.header("Thermocouple")
+ with thermocouple_tab:
+ channel_left_column, channel_middle_column, channel_right_column = st.columns(3)
+ with channel_left_column:
+ st.selectbox(options=["Dev1/ai1"], label="Physical Channel", disabled=True)
+ st.number_input(
+ "Min Value",
+ value=0.0,
+ step=1.0,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_min_value",
+ )
+ st.number_input(
+ "Max Value",
+ value=100.0,
+ step=1.0,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_max_value",
+ )
+ with channel_middle_column:
+ enum_selectbox(
+ panel,
+ label="Units",
+ value=TemperatureUnits.DEG_C,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_units",
+ )
+ enum_selectbox(
+ panel,
+ label="Thermocouple Type",
+ value=ThermocoupleType.K,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_type",
+ )
+ with channel_right_column:
+ enum_selectbox(
+ panel,
+ label="CJC Source",
+ value=CJCSource.CONSTANT_USER_VALUE,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_cjc_source",
+ )
+ st.number_input(
+ "CJC Value",
+ value=25.0,
+ step=1.0,
+ disabled=panel.get_value("is_running", False),
+ key="thermocouple_cjc_val",
+ )
+ # Timing Settings section in left column
+ with st.container(border=True):
+ st.header("Timing Settings")
+ timing_left_column, timing_right_column = st.columns(2)
+ with timing_left_column:
+ st.selectbox(
+ options=["OnboardClock"],
+ label="Sample Clock Source",
+ disabled=True,
+ )
+ st.number_input(
+ "Sample Rate",
+ value=1000.0,
+ step=100.0,
+ min_value=1.0,
+ disabled=panel.get_value("is_running", False),
+ key="sample_rate_input",
+ )
+ with timing_right_column:
+ st.number_input(
+ "Samples per Loop",
+ value=3000,
+ step=100,
+ min_value=10,
+ disabled=panel.get_value("is_running", False),
+ key="samples_per_channel",
+ )
+ st.text_input(
+ label="Actual Sample Rate",
+ value=str(sample_rate) if sample_rate else "",
+ key="actual_sample_rate_display",
+ )
-thermocouple_tab.header("Thermocouple")
-with thermocouple_tab:
- left, middle, right = st.columns(3)
- with left:
- st.selectbox(options=["Dev1/ai1"], label="Physical Channel", disabled=True)
- st.selectbox(options=["0"], label="Min", disabled=False)
- st.selectbox(options=["100"], label="Max", disabled=False)
- st.selectbox(options=["Off"], label="Logging Mode", disabled=False)
+# Right column - Graph and Logging Settings
+with right_column:
+ with st.container(border=True):
+ # Graph section
+ st.header("Voltage & Thermocouple")
+ voltage_therm_graph = {
+ "animation": False,
+ "tooltip": {"trigger": "axis"},
+ "legend": {"data": ["Voltage (V)", "Temperature (C)"]},
+ "xAxis": {
+ "type": "category",
+ "data": [
+ x / sample_rate if sample_rate > 0.001 else x for x in range(len(voltage_data))
+ ],
+ "name": "Time",
+ "nameLocation": "center",
+ "nameGap": 40,
+ },
+ "yAxis": {
+ "type": "value",
+ "name": "Measurement",
+ "nameRotate": 90,
+ "nameLocation": "center",
+ "nameGap": 40,
+ },
+ "series": [
+ {
+ "name": "voltage_amplitude",
+ "type": "line",
+ "data": voltage_data,
+ "emphasis": {"focus": "series"},
+ "smooth": True,
+ "seriesLayoutBy": "row",
+ },
+ {
+ "name": "thermocouple_amplitude",
+ "type": "line",
+ "data": thermocouple_data,
+ "color": "red",
+ "emphasis": {"focus": "series"},
+ "smooth": True,
+ "seriesLayoutBy": "row",
+ },
+ ],
+ }
+ st_echarts(options=voltage_therm_graph, height="446px", key="voltage_therm_graph")
- with middle:
- st.selectbox(options=["Deg C"], label="Units", disabled=False)
- st.selectbox(options=["J"], label="Thermocouple Type", disabled=False)
- st.selectbox(options=["Constant Value"], label="CJC Source", disabled=False)
- st.selectbox(options=["1000"], label="Samples per Loop", disabled=False)
- with right:
- st.selectbox(options=["25"], label="CJC Value", disabled=False)
- st.selectbox(options=["OnboardClock"], label="Sample Clock Source", disabled=False)
- st.selectbox(options=[" "], label="Actual Sample Rate", disabled=True)
+ # Logging Settings section in right column
+ with st.container(border=True):
+ st.header("Logging Settings")
+ logging_left_column, logging_right_column = st.columns(2)
+ with logging_left_column:
+ enum_selectbox(
+ panel,
+ label="Logging Mode",
+ value=LoggingMode.OFF,
+ disabled=panel.get_value("is_running", False),
+ key="logging_mode",
+ )
+ with logging_right_column:
+ left_sub_column, right_sub_column = st.columns([3, 1])
+ with left_sub_column:
+ tdms_file_path = st.text_input(
+ label="TDMS File Path",
+ disabled=panel.get_value("is_running", False),
+ value="data.tdms",
+ key="tdms_file_path",
+ )
diff --git a/src/nipanel/_convert.py b/src/nipanel/_convert.py
index 5bbc6d0..61a4858 100644
--- a/src/nipanel/_convert.py
+++ b/src/nipanel/_convert.py
@@ -129,3 +129,12 @@ def from_any(protobuf_any: any_pb2.Any) -> object:
converter = _CONVERTER_FOR_GRPC_TYPE[underlying_typename]
return converter.to_python(protobuf_any)
+
+
+def is_supported_type(value: object) -> bool:
+ """Check if a given Python value can be converted to protobuf Any."""
+ try:
+ _get_best_matching_type(value)
+ return True
+ except TypeError:
+ return False
diff --git a/src/nipanel/_panel_value_accessor.py b/src/nipanel/_panel_value_accessor.py
index d60f7d7..cb53d56 100644
--- a/src/nipanel/_panel_value_accessor.py
+++ b/src/nipanel/_panel_value_accessor.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+import collections
+import enum
from abc import ABC
from typing import TypeVar, overload
@@ -15,7 +17,13 @@
class PanelValueAccessor(ABC):
"""This class allows you to access values for a panel's controls."""
- __slots__ = ["_panel_client", "_panel_id", "_notify_on_set_value", "__weakref__"]
+ __slots__ = [
+ "_panel_client",
+ "_panel_id",
+ "_notify_on_set_value",
+ "_last_values",
+ "__weakref__",
+ ]
def __init__(
self,
@@ -38,6 +46,9 @@ def __init__(
)
self._panel_id = panel_id
self._notify_on_set_value = notify_on_set_value
+ self._last_values: collections.defaultdict[str, object] = collections.defaultdict(
+ lambda: object()
+ )
@property
def panel_id(self) -> str:
@@ -67,7 +78,16 @@ def get_value(self, value_id: str, default_value: _T | None = None) -> _T | obje
raise KeyError(f"Value with id '{value_id}' not found on panel '{self._panel_id}'.")
if default_value is not None and not isinstance(value, type(default_value)):
- raise TypeError("Value type does not match default value type.")
+ if isinstance(default_value, enum.Enum):
+ enum_type = type(default_value)
+ return enum_type(value)
+
+ # lists are allowed to not match, since sets and tuples are converted to lists
+ if not isinstance(value, list):
+ raise TypeError(
+ f"Value type {type(value).__name__} does not match default value type {type(default_value).__name__}."
+ )
+
return value
def set_value(self, value_id: str, value: object) -> None:
@@ -77,6 +97,22 @@ def set_value(self, value_id: str, value: object) -> None:
value_id: The id of the value
value: The value
"""
+ if isinstance(value, enum.Enum):
+ value = value.value
+
self._panel_client.set_value(
self._panel_id, value_id, value, notify=self._notify_on_set_value
)
+ self._last_values[value_id] = value
+
+ def set_value_if_changed(self, value_id: str, value: object) -> None:
+ """Set the value for a control on the panel only if it has changed since the last call.
+
+ This method helps reduce unnecessary updates when the value hasn't changed.
+
+ Args:
+ value_id: The id of the value
+ value: The value to set
+ """
+ if value != self._last_values[value_id]:
+ self.set_value(value_id, value)
diff --git a/src/nipanel/_streamlit_panel_initializer.py b/src/nipanel/_streamlit_panel_initializer.py
index 7dee3fd..a05822d 100644
--- a/src/nipanel/_streamlit_panel_initializer.py
+++ b/src/nipanel/_streamlit_panel_initializer.py
@@ -3,6 +3,7 @@
import streamlit as st
+from nipanel._convert import is_supported_type
from nipanel._streamlit_panel import StreamlitPanel
from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor
from nipanel.streamlit_refresh import initialize_refresh_component
@@ -61,6 +62,7 @@ def get_panel_accessor() -> StreamlitPanelValueAccessor:
st.session_state[PANEL_ACCESSOR_KEY] = _initialize_panel_from_base_path()
panel = cast(StreamlitPanelValueAccessor, st.session_state[PANEL_ACCESSOR_KEY])
+ _sync_session_state(panel)
refresh_component = initialize_refresh_component(panel.panel_id)
refresh_component()
return panel
@@ -75,3 +77,11 @@ def _initialize_panel_from_base_path() -> StreamlitPanelValueAccessor:
if not panel_id:
raise ValueError(f"Panel ID is empty in baseUrlPath: '{base_url_path}'")
return StreamlitPanelValueAccessor(panel_id)
+
+
+def _sync_session_state(panel: StreamlitPanelValueAccessor) -> None:
+ """Automatically read keyed control values from the session state."""
+ for key in st.session_state.keys():
+ value = st.session_state[key]
+ if is_supported_type(value):
+ panel.set_value_if_changed(str(key), value)
diff --git a/src/nipanel/controls/__init__.py b/src/nipanel/controls/__init__.py
new file mode 100644
index 0000000..6882af0
--- /dev/null
+++ b/src/nipanel/controls/__init__.py
@@ -0,0 +1,6 @@
+"""Controls for nipanel."""
+
+from nipanel.controls._enum_selectbox import enum_selectbox
+from nipanel.controls._flag_checkboxes import flag_checkboxes
+
+__all__ = ["enum_selectbox", "flag_checkboxes"]
diff --git a/src/nipanel/controls/_enum_selectbox.py b/src/nipanel/controls/_enum_selectbox.py
new file mode 100644
index 0000000..a75277c
--- /dev/null
+++ b/src/nipanel/controls/_enum_selectbox.py
@@ -0,0 +1,53 @@
+"""A selectbox that allows selecting an Enum value."""
+
+from enum import Enum
+from typing import Any, Callable, TypeVar
+
+import streamlit as st
+
+from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor
+
+TEnumType = TypeVar("TEnumType", bound=Enum)
+
+
+def enum_selectbox(
+ panel: StreamlitPanelValueAccessor,
+ label: str,
+ value: TEnumType,
+ key: str,
+ disabled: bool = False,
+ format_func: Callable[[Any], str] = lambda x: x[0],
+) -> TEnumType:
+ """Create a selectbox for an Enum.
+
+ The selectbox will display the names of all the enum values, and when a value is selected,
+ that value will be stored in the panel under the specified key.
+
+ Args:
+ panel: The panel
+ label: Label to display for the selectbox
+ value: The default enum value to select (also determines the specific enum type)
+ key: Key to use for storing the enum value in the panel
+
+ Returns:
+ The selected enum value of the same specific enum subclass as the input value
+ """
+ enum_class = type(value)
+ if not issubclass(enum_class, Enum):
+ raise TypeError(f"Expected an Enum type, got {type(value)}")
+
+ options = [(e.name, e.value) for e in enum_class]
+
+ default_index = 0
+ if value is not None:
+ for i, (name, _) in enumerate(options):
+ if name == value.name:
+ default_index = i
+ break
+
+ box_tuple = st.selectbox(
+ label, options=options, format_func=format_func, index=default_index, disabled=disabled
+ )
+ enum_value = enum_class[box_tuple[0]]
+ panel.set_value_if_changed(key, enum_value)
+ return enum_value
diff --git a/src/nipanel/controls/_flag_checkboxes.py b/src/nipanel/controls/_flag_checkboxes.py
new file mode 100644
index 0000000..7887a53
--- /dev/null
+++ b/src/nipanel/controls/_flag_checkboxes.py
@@ -0,0 +1,77 @@
+"""A set of checkboxes for selecting Flag enum values."""
+
+from enum import Flag
+from typing import TypeVar, Callable, Optional
+
+import streamlit as st
+
+from nipanel._streamlit_panel_value_accessor import StreamlitPanelValueAccessor
+
+TFlagType = TypeVar("TFlagType", bound=Flag)
+
+
+def flag_checkboxes(
+ panel: StreamlitPanelValueAccessor,
+ label: str,
+ value: TFlagType,
+ key: str,
+ disabled_values: Optional[TFlagType] = None,
+ label_formatter: Callable[[Flag], str] = lambda x: str(x.name),
+) -> TFlagType:
+ """Create a set of checkboxes for a Flag enum.
+
+ This will display a checkbox for each individual flag value in the enum. When checkboxes
+ are selected or deselected, the combined Flag value will be stored in the panel under
+ the specified key.
+
+ Args:
+ panel: The panel
+ label: Label to display above the checkboxes
+ value: The default Flag enum value (also determines the specific Flag enum type)
+ key: Key to use for storing the Flag value in the panel
+ disabled_values: A Flag enum value indicating which flags should be disabled.
+ If None or flag_type(0), no checkboxes are disabled.
+ label_formatter: Function that formats the flag to a string for display. Default
+ uses flag.name.
+
+ Returns:
+ The selected Flag enum value with all selected flags combined
+ """
+ flag_type = type(value)
+ if not issubclass(flag_type, Flag):
+ raise TypeError(f"Expected a Flag enum type, got {type(value)}")
+
+ st.markdown(f"{label}:", unsafe_allow_html=True)
+
+ # Get all individual flag values (skip composite values and zero value)
+ flag_values = [
+ flag for flag in flag_type if flag.value & (flag.value - 1) == 0 and flag.value != 0
+ ]
+
+ # Create a container for flag checkboxes
+ flag_container = st.container(border=True)
+
+ # Use the provided value as the initial state for selected flags
+ selected_flags = value
+
+ # Create a checkbox for each flag
+ for flag in flag_values:
+ is_selected = bool(selected_flags & flag)
+
+ is_disabled = False
+ if disabled_values is not None:
+ is_disabled = bool(disabled_values & flag)
+
+ if flag_container.checkbox(
+ label=label_formatter(flag),
+ value=is_selected,
+ key=f"{key}_{flag.name}",
+ disabled=is_disabled,
+ ):
+ selected_flags |= flag
+ else:
+ selected_flags &= ~flag
+
+ # Store the selected flags in the panel
+ panel.set_value_if_changed(key, selected_flags)
+ return selected_flags
diff --git a/tests/types.py b/tests/types.py
index 0d17f9b..fa5cb0e 100644
--- a/tests/types.py
+++ b/tests/types.py
@@ -21,6 +21,14 @@ class MyIntFlags(enum.IntFlag):
VALUE4 = 4
+class MyIntableFlags(enum.Flag):
+ """Example of a simple flag with int values."""
+
+ VALUE8 = 8
+ VALUE16 = 16
+ VALUE32 = 32
+
+
class MyIntEnum(enum.IntEnum):
"""Example of an IntEnum enum."""
@@ -53,17 +61,25 @@ class MixinStrEnum(str, enum.Enum):
VALUE33 = "value33"
-class MyEnum(enum.Enum):
- """Example of a simple enum."""
+class MyIntableEnum(enum.Enum):
+ """Example of a simple enum with int values."""
VALUE100 = 100
VALUE200 = 200
VALUE300 = 300
-class MyFlags(enum.Flag):
- """Example of a simple flag."""
+class MyStringableEnum(StrEnum):
+ """Example of a simple enum with str values."""
- VALUE8 = 8
- VALUE16 = 16
- VALUE32 = 32
+ VALUE1 = "value10"
+ VALUE2 = "value20"
+ VALUE3 = "value30"
+
+
+class MyMixedEnum(enum.Enum):
+ """Example of an enum with mixed values."""
+
+ VALUE1 = "value1"
+ VALUE2 = 2
+ VALUE3 = 3.0
diff --git a/tests/unit/test_streamlit_panel.py b/tests/unit/test_streamlit_panel.py
index c1c3382..9b67bf6 100644
--- a/tests/unit/test_streamlit_panel.py
+++ b/tests/unit/test_streamlit_panel.py
@@ -1,3 +1,6 @@
+import enum
+from datetime import datetime
+
import grpc
import pytest
from typing_extensions import assert_type
@@ -242,6 +245,17 @@ def test___set_int_type___get_value_with_bool_default___raises_exception(
panel.get_value(value_id, False)
+def test___set_string_enum_type___get_value_with_int_enum_default___raises_exception(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ panel.set_value(value_id, test_types.MyStrEnum.VALUE3)
+
+ with pytest.raises(ValueError):
+ panel.get_value(value_id, test_types.MyIntEnum.VALUE10)
+
+
@pytest.mark.parametrize(
"value_payload",
[
@@ -250,31 +264,63 @@ def test___set_int_type___get_value_with_bool_default___raises_exception(
3.14,
True,
b"robotext",
+ ],
+)
+def test___builtin_scalar_type___set_value___gets_same_value(
+ fake_panel_channel: grpc.Channel,
+ value_payload: object,
+) -> None:
+ """Test that set_value() and get_value() work for builtin scalar types."""
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+
+ value_id = "test_id"
+ panel.set_value(value_id, value_payload)
+
+ assert panel.get_value(value_id) == value_payload
+
+
+@pytest.mark.parametrize(
+ "value_payload",
+ [
test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4,
+ test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32,
test_types.MyIntEnum.VALUE20,
+ test_types.MyIntableEnum.VALUE200,
test_types.MyStrEnum.VALUE3,
+ test_types.MyStringableEnum.VALUE2,
test_types.MixinIntEnum.VALUE33,
test_types.MixinStrEnum.VALUE11,
+ test_types.MyMixedEnum.VALUE2,
],
)
-def test___builtin_scalar_type___set_value___gets_same_value(
+def test___enum_type___set_value___gets_same_value(
fake_panel_channel: grpc.Channel,
- value_payload: object,
+ value_payload: enum.Enum,
) -> None:
- """Test that set_value() and get_value() work for builtin scalar types."""
+ """Test that set_value() and get_value() work for enum types."""
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
value_id = "test_id"
panel.set_value(value_id, value_payload)
- assert panel.get_value(value_id) == value_payload
+ # without providing a default value, get_value will return the raw value, not the enum
+ assert panel.get_value(value_id) == value_payload.value
@pytest.mark.parametrize(
"value_payload",
[
- test_types.MyEnum.VALUE300,
- test_types.MyFlags.VALUE8 | test_types.MyFlags.VALUE16,
+ datetime.now(),
+ lambda x: x + 1,
+ [1, "string"],
+ ["string", []],
+ (42, "hello", 3.14, b"bytes"),
+ set([1, "mixed", True]),
+ (i for i in range(5)),
+ {
+ "key1": [1, 2, 3],
+ "key2": {"nested": True, "values": [4.5, 6.7]},
+ },
],
)
def test___unsupported_type___set_value___raises(
@@ -341,6 +387,117 @@ def test___sequence_of_builtin_type___set_value___gets_same_value(
assert list(received_value) == list(value_payload) # type: ignore [call-overload]
+def test___set_int_enum_value___get_value___returns_int_enum(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ enum_value = test_types.MyIntEnum.VALUE20
+ panel.set_value(value_id, enum_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyIntEnum.VALUE10)
+
+ assert_type(retrieved_value, test_types.MyIntEnum)
+ assert retrieved_value is test_types.MyIntEnum.VALUE20
+ assert retrieved_value.value == enum_value.value
+ assert retrieved_value.name == enum_value.name
+
+
+def test___set_intable_enum_value___get_value___returns_intable_enum(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ enum_value = test_types.MyIntableEnum.VALUE200
+ panel.set_value(value_id, enum_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyIntableEnum.VALUE100)
+
+ assert_type(retrieved_value, test_types.MyIntableEnum)
+ assert retrieved_value is test_types.MyIntableEnum.VALUE200
+ assert retrieved_value.value == enum_value.value
+ assert retrieved_value.name == enum_value.name
+
+
+def test___set_string_enum_value___get_value___returns_string_enum(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ enum_value = test_types.MyStrEnum.VALUE3
+ panel.set_value(value_id, enum_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyStrEnum.VALUE1)
+
+ assert_type(retrieved_value, test_types.MyStrEnum)
+ assert retrieved_value is test_types.MyStrEnum.VALUE3
+ assert retrieved_value.value == enum_value.value
+ assert retrieved_value.name == enum_value.name
+
+
+def test___set_stringable_enum_value___get_value___returns_stringable_enum(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ enum_value = test_types.MyStringableEnum.VALUE3
+ panel.set_value(value_id, enum_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyStringableEnum.VALUE1)
+
+ assert_type(retrieved_value, test_types.MyStringableEnum)
+ assert retrieved_value is test_types.MyStringableEnum.VALUE3
+ assert retrieved_value.value == enum_value.value
+ assert retrieved_value.name == enum_value.name
+
+
+def test___set_mixed_enum_value___get_value___returns_mixed_enum(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ enum_value = test_types.MyMixedEnum.VALUE2
+ panel.set_value(value_id, enum_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyMixedEnum.VALUE1)
+
+ assert_type(retrieved_value, test_types.MyMixedEnum)
+ assert retrieved_value is test_types.MyMixedEnum.VALUE2
+ assert retrieved_value.value == enum_value.value
+ assert retrieved_value.name == enum_value.name
+
+
+def test___set_int_flags_value___get_value___returns_int_flags(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ flags_value = test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4
+ panel.set_value(value_id, flags_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyIntFlags.VALUE2)
+
+ assert_type(retrieved_value, test_types.MyIntFlags)
+ assert retrieved_value == (test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4)
+ assert retrieved_value.value == flags_value.value
+
+
+def test___set_intable_flags_value___get_value___returns_intable_flags(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
+ value_id = "test_id"
+ flags_value = test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32
+ panel.set_value(value_id, flags_value)
+
+ retrieved_value = panel.get_value(value_id, test_types.MyIntableFlags.VALUE8)
+
+ assert_type(retrieved_value, test_types.MyIntableFlags)
+ assert retrieved_value is test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32
+ assert retrieved_value.value == flags_value.value
+ assert retrieved_value.name == flags_value.name
+
+
def test___panel___panel_is_running_and_in_memory(
fake_panel_channel: grpc.Channel,
) -> None:
diff --git a/tests/unit/test_streamlit_panel_value_accessor.py b/tests/unit/test_streamlit_panel_value_accessor.py
new file mode 100644
index 0000000..4a845b6
--- /dev/null
+++ b/tests/unit/test_streamlit_panel_value_accessor.py
@@ -0,0 +1,110 @@
+import grpc
+
+from nipanel import StreamlitPanelValueAccessor
+from tests.types import MyIntEnum
+from tests.utils._fake_python_panel_service import FakePythonPanelService
+
+
+def test___no_previous_value___set_value_if_changed___sets_value(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+
+ accessor.set_value_if_changed("test_id", "test_value")
+
+ assert accessor.get_value("test_id") == "test_value"
+
+
+def test___set_value_if_changed___set_same_value___does_not_set_value_again(
+ fake_panel_channel: grpc.Channel,
+ fake_python_panel_service: FakePythonPanelService,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", "test_value")
+ initial_set_count = fake_python_panel_service.servicer.set_count
+
+ accessor.set_value_if_changed("test_id", "test_value")
+
+ assert fake_python_panel_service.servicer.set_count == initial_set_count
+ assert accessor.get_value("test_id") == "test_value"
+
+
+def test___set_value_if_changed___set_different_value___sets_new_value(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", "test_value")
+
+ accessor.set_value_if_changed("test_id", "new_value")
+
+ assert accessor.get_value("test_id") == "new_value"
+
+
+def test___set_value_if_changed___different_value_ids___tracks_separately(
+ fake_panel_channel: grpc.Channel,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("id1", "value1")
+ accessor.set_value_if_changed("id2", "value2")
+
+ accessor.set_value_if_changed("id1", "value1")
+ accessor.set_value_if_changed("id2", "new_value2")
+
+ assert accessor.get_value("id1") == "value1"
+ assert accessor.get_value("id2") == "new_value2"
+
+
+def test___set_value_if_changed_with_list_value___set_same_value___does_not_set_value_again(
+ fake_panel_channel: grpc.Channel,
+ fake_python_panel_service: FakePythonPanelService,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", [1, 2, 3])
+ initial_set_count = fake_python_panel_service.servicer.set_count
+
+ accessor.set_value_if_changed("test_id", [1, 2, 3])
+
+ assert fake_python_panel_service.servicer.set_count == initial_set_count
+ assert accessor.get_value("test_id") == [1, 2, 3]
+
+
+def test___set_value_if_changed_with_list_value___set_different_value___sets_new_value(
+ fake_panel_channel: grpc.Channel,
+ fake_python_panel_service: FakePythonPanelService,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", [1, 2, 3])
+ initial_set_count = fake_python_panel_service.servicer.set_count
+
+ accessor.set_value_if_changed("test_id", [1, 2, 4])
+
+ assert fake_python_panel_service.servicer.set_count > initial_set_count
+ assert accessor.get_value("test_id") == [1, 2, 4]
+
+
+def test___set_value_if_changed_with_enum_value___set_same_value___does_not_set_value_again(
+ fake_panel_channel: grpc.Channel,
+ fake_python_panel_service: FakePythonPanelService,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20)
+ initial_set_count = fake_python_panel_service.servicer.set_count
+
+ accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20)
+
+ assert fake_python_panel_service.servicer.set_count == initial_set_count
+ assert accessor.get_value("test_id") == 20 # Enums are stored as their values
+
+
+def test___set_value_if_changed_with_enum_value___set_different_value___sets_new_value(
+ fake_panel_channel: grpc.Channel,
+ fake_python_panel_service: FakePythonPanelService,
+) -> None:
+ accessor = StreamlitPanelValueAccessor("panel_id", grpc_channel=fake_panel_channel)
+ accessor.set_value_if_changed("test_id", MyIntEnum.VALUE20)
+ initial_set_count = fake_python_panel_service.servicer.set_count
+
+ accessor.set_value_if_changed("test_id", MyIntEnum.VALUE30)
+
+ assert fake_python_panel_service.servicer.set_count > initial_set_count
+ assert accessor.get_value("test_id") == 30 # New enum value should be set
diff --git a/tests/utils/_fake_python_panel_servicer.py b/tests/utils/_fake_python_panel_servicer.py
index e6d59ee..ac74fe6 100644
--- a/tests/utils/_fake_python_panel_servicer.py
+++ b/tests/utils/_fake_python_panel_servicer.py
@@ -26,6 +26,7 @@ def __init__(self) -> None:
self._panel_is_running: dict[str, bool] = {}
self._panel_value_ids: dict[str, dict[str, Any]] = {}
self._fail_next_start_panel = False
+ self._set_count: int = 0
self._notification_count: int = 0
self._python_path: str = ""
@@ -70,6 +71,7 @@ def SetValue(self, request: SetValueRequest, context: Any) -> SetValueResponse:
"""Trivial implementation for testing."""
self._init_panel(request.panel_id)
self._panel_value_ids[request.panel_id][request.value_id] = request.value
+ self._set_count += 1
if request.notify:
self._notification_count += 1
return SetValueResponse()
@@ -78,6 +80,11 @@ def fail_next_start_panel(self) -> None:
"""Set whether the StartPanel method should fail the next time it is called."""
self._fail_next_start_panel = True
+ @property
+ def set_count(self) -> int:
+ """Get the total number of times SetValue was called."""
+ return self._set_count
+
@property
def notification_count(self) -> int:
"""Get the number of notifications sent from SetValue."""