Skip to content

Commit 585948c

Browse files
Mike ProsserMike Prosser
authored andcommitted
get_value returns the enum type when a default is provided
1 parent 8be90f7 commit 585948c

File tree

3 files changed

+69
-2
lines changed

3 files changed

+69
-2
lines changed

examples/simple_graph/simple_graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
# Generate and update the sine wave data periodically
2323
while True:
24-
amplitude_enum = AmplitudeEnum(panel.get_value("amplitude_enum", AmplitudeEnum.SMALL.value))
24+
amplitude_enum = panel.get_value("amplitude_enum", AmplitudeEnum.SMALL)
2525
base_frequency = panel.get_value("base_frequency", 1.0)
2626

2727
# Slowly vary the total frequency for a more dynamic visualization

src/nipanel/_panel_value_accessor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import enum
34
from abc import ABC
45
from typing import TypeVar, overload
56

@@ -62,8 +63,16 @@ def get_value(self, value_id: str, default_value: _T | None = None) -> _T | obje
6263
"""
6364
try:
6465
value = self._panel_client.get_value(self._panel_id, value_id)
66+
6567
if default_value is not None and not isinstance(value, type(default_value)):
66-
raise TypeError("Value type does not match default value type.")
68+
if isinstance(default_value, enum.Enum):
69+
enum_type = type(default_value)
70+
return enum_type(value)
71+
72+
raise TypeError(
73+
f"Value type {type(value).__name__} does not match default value type {type(default_value).__name__}."
74+
)
75+
6776
return value
6877

6978
except grpc.RpcError as e:

tests/unit/test_streamlit_panel.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ def test___set_int_type___get_value_with_bool_default___raises_exception(
242242
panel.get_value(value_id, False)
243243

244244

245+
def test___set_string_enum_type___get_value_with_int_enum_default___raises_exception(
246+
fake_panel_channel: grpc.Channel,
247+
) -> None:
248+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
249+
value_id = "test_id"
250+
panel.set_value(value_id, test_types.MyStrEnum.VALUE3)
251+
252+
with pytest.raises(ValueError):
253+
panel.get_value(value_id, test_types.MyIntEnum.VALUE10)
254+
255+
245256
@pytest.mark.parametrize(
246257
"value_payload",
247258
[
@@ -341,6 +352,53 @@ def test___sequence_of_builtin_type___set_value___gets_same_value(
341352
assert list(received_value) == list(value_payload) # type: ignore [call-overload]
342353

343354

355+
def test___set_int_enum_value___get_value___returns_int_enum(
356+
fake_panel_channel: grpc.Channel,
357+
) -> None:
358+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
359+
value_id = "test_id"
360+
enum_value = test_types.MyIntEnum.VALUE20
361+
panel.set_value(value_id, enum_value)
362+
363+
retrieved_value = panel.get_value(value_id, test_types.MyIntEnum.VALUE10)
364+
365+
assert_type(retrieved_value, test_types.MyIntEnum)
366+
assert retrieved_value is test_types.MyIntEnum.VALUE20
367+
assert retrieved_value.value == enum_value.value
368+
assert retrieved_value.name == enum_value.name
369+
370+
371+
def test___set_string_enum_value___get_value___returns_string_enum(
372+
fake_panel_channel: grpc.Channel,
373+
) -> None:
374+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
375+
value_id = "test_id"
376+
enum_value = test_types.MyStrEnum.VALUE3
377+
panel.set_value(value_id, enum_value)
378+
379+
retrieved_value = panel.get_value(value_id, test_types.MyStrEnum.VALUE1)
380+
381+
assert_type(retrieved_value, test_types.MyStrEnum)
382+
assert retrieved_value is test_types.MyStrEnum.VALUE3
383+
assert retrieved_value.value == enum_value.value
384+
assert retrieved_value.name == enum_value.name
385+
386+
387+
def test___set_int_flags_value___get_value___returns_int_flags(
388+
fake_panel_channel: grpc.Channel,
389+
) -> None:
390+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
391+
value_id = "test_id"
392+
flags_value = test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4
393+
panel.set_value(value_id, flags_value)
394+
395+
retrieved_value = panel.get_value(value_id, test_types.MyIntFlags.VALUE2)
396+
397+
assert_type(retrieved_value, test_types.MyIntFlags)
398+
assert retrieved_value == (test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4)
399+
assert retrieved_value.value == flags_value.value
400+
401+
344402
def test___panel___panel_is_running_and_in_memory(
345403
fake_panel_channel: grpc.Channel,
346404
) -> None:

0 commit comments

Comments
 (0)