Skip to content

Commit 0538676

Browse files
Mike ProsserMike Prosser
authored andcommitted
Enhance enum support in Streamlit panel and tests
- Added MyIntableFlags and MyIntableEnum classes to define new enum types. - Updated all_types_with_values to include new enum types. - Modified PanelValueAccessor to allow list values without type matching. - Improved test coverage for enum types in StreamlitPanel.
1 parent 3c84a99 commit 0538676

File tree

5 files changed

+181
-22
lines changed

5 files changed

+181
-22
lines changed

examples/all_types/all_types_panel.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""A Streamlit visualization panel for the all_types.py example script."""
22

3-
from enum import Enum
3+
from enum import Enum, Flag
44
from typing import cast
55

66
import streamlit as st
@@ -23,16 +23,20 @@
2323
st.write(name)
2424

2525
with col2:
26-
if isinstance(default_value, Enum):
27-
nipanel.enum_selectbox(panel, label=name, value=cast(Enum, default_value), key=name)
28-
elif isinstance(default_value, bool):
26+
if isinstance(default_value, bool):
2927
st.checkbox(label=name, value=cast(bool, default_value), key=name)
30-
elif isinstance(default_value, int):
28+
elif isinstance(default_value, Enum) and not isinstance(default_value, Flag):
29+
nipanel.enum_selectbox(panel, label=name, value=cast(Enum, default_value), key=name)
30+
elif isinstance(default_value, int) and not isinstance(default_value, Flag):
3131
st.number_input(label=name, value=cast(int, default_value), key=name)
3232
elif isinstance(default_value, float):
3333
st.number_input(label=name, value=cast(float, default_value), key=name, format="%.2f")
3434
elif isinstance(default_value, str):
3535
st.text_input(label=name, value=cast(str, default_value), key=name)
3636

3737
with col3:
38-
st.write(panel.get_value(name))
38+
value = panel.get_value(name)
39+
value_with_default = panel.get_value(name, default_value=default_value)
40+
st.write(value_with_default)
41+
if str(value) != str(value_with_default):
42+
st.write("(", value, ")")

examples/all_types/define_types.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ class MyIntFlags(enum.IntFlag):
1515
VALUE4 = 4
1616

1717

18+
class MyIntableFlags(enum.Flag):
19+
"""Example of an Flag enum with int values."""
20+
21+
VALUE1 = 1
22+
VALUE2 = 2
23+
VALUE4 = 4
24+
25+
1826
class MyIntEnum(enum.IntEnum):
1927
"""Example of an IntEnum enum."""
2028

@@ -23,6 +31,14 @@ class MyIntEnum(enum.IntEnum):
2331
VALUE30 = 30
2432

2533

34+
class MyIntableEnum(enum.Enum):
35+
"""Example of an enum with int values."""
36+
37+
VALUE100 = 100
38+
VALUE200 = 200
39+
VALUE300 = 300
40+
41+
2642
class MyStrEnum(str, enum.Enum):
2743
"""Example of a mixin string enum."""
2844

@@ -31,6 +47,22 @@ class MyStrEnum(str, enum.Enum):
3147
VALUE3 = "value3"
3248

3349

50+
class MyStringableEnum(enum.Enum):
51+
"""Example of an enum with string values."""
52+
53+
VALUE1 = "value1"
54+
VALUE2 = "value2"
55+
VALUE3 = "value3"
56+
57+
58+
class MyMixedEnum(enum.Enum):
59+
"""Example of an enum with mixed values."""
60+
61+
VALUE1 = "value1"
62+
VALUE2 = 2
63+
VALUE3 = 3.0
64+
65+
3466
all_types_with_values = {
3567
# supported scalar types
3668
"bool": True,
@@ -42,6 +74,10 @@ class MyStrEnum(str, enum.Enum):
4274
"intflags": MyIntFlags.VALUE1 | MyIntFlags.VALUE4,
4375
"intenum": MyIntEnum.VALUE20,
4476
"strenum": MyStrEnum.VALUE3,
77+
"intableenum": MyIntableEnum.VALUE200,
78+
"intableflags": MyIntableFlags.VALUE1 | MyIntableFlags.VALUE2,
79+
"stringableenum": MyStringableEnum.VALUE2,
80+
"mixedenum": MyMixedEnum.VALUE2,
4581
# NI types
4682
"nitypes_Scalar": Scalar(42, "m"),
4783
"nitypes_AnalogWaveform": AnalogWaveform.from_array_1d(np.array([1.0, 2.0, 3.0])),
@@ -62,6 +98,6 @@ class MyStrEnum(str, enum.Enum):
6298
# supported 2D collections
6399
"list_list_float": [[1.0, 2.0], [3.0, 4.0]],
64100
"tuple_tuple_float": ((1.0, 2.0), (3.0, 4.0)),
65-
"set_list_float": set([(1.0, 2.0), (3.0, 4.0)]),
101+
"set_tuple_float": set([(1.0, 2.0), (3.0, 4.0)]),
66102
"frozenset_frozenset_float": frozenset([frozenset([1.0, 2.0]), frozenset([3.0, 4.0])]),
67103
}

src/nipanel/_panel_value_accessor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ def get_value(self, value_id: str, default_value: _T | None = None) -> _T | obje
7272
enum_type = type(default_value)
7373
return enum_type(value)
7474

75-
raise TypeError(
76-
f"Value type {type(value).__name__} does not match default value type {type(default_value).__name__}."
77-
)
75+
# lists are allowed to not match, since sets and tuples are converted to lists
76+
if not isinstance(value, list):
77+
raise TypeError(
78+
f"Value type {type(value).__name__} does not match default value type {type(default_value).__name__}."
79+
)
7880

7981
return value
8082

@@ -85,6 +87,9 @@ def set_value(self, value_id: str, value: object) -> None:
8587
value_id: The id of the value
8688
value: The value
8789
"""
90+
if isinstance(value, enum.Enum):
91+
value = value.value
92+
8893
self._panel_client.set_value(
8994
self._panel_id, value_id, value, notify=self._notify_on_set_value
9095
)

tests/types.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@ class MyIntFlags(enum.IntFlag):
2121
VALUE4 = 4
2222

2323

24+
class MyIntableFlags(enum.Flag):
25+
"""Example of a simple flag with int values."""
26+
27+
VALUE8 = 8
28+
VALUE16 = 16
29+
VALUE32 = 32
30+
31+
2432
class MyIntEnum(enum.IntEnum):
2533
"""Example of an IntEnum enum."""
2634

@@ -53,17 +61,25 @@ class MixinStrEnum(str, enum.Enum):
5361
VALUE33 = "value33"
5462

5563

56-
class MyEnum(enum.Enum):
57-
"""Example of a simple enum."""
64+
class MyIntableEnum(enum.Enum):
65+
"""Example of a simple enum with int values."""
5866

5967
VALUE100 = 100
6068
VALUE200 = 200
6169
VALUE300 = 300
6270

6371

64-
class MyFlags(enum.Flag):
65-
"""Example of a simple flag."""
72+
class MyStringableEnum(StrEnum):
73+
"""Example of a simple enum with str values."""
6674

67-
VALUE8 = 8
68-
VALUE16 = 16
69-
VALUE32 = 32
75+
VALUE1 = "value10"
76+
VALUE2 = "value20"
77+
VALUE3 = "value30"
78+
79+
80+
class MyMixedEnum(enum.Enum):
81+
"""Example of an enum with mixed values."""
82+
83+
VALUE1 = "value1"
84+
VALUE2 = 2
85+
VALUE3 = 3.0

tests/unit/test_streamlit_panel.py

Lines changed: 103 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1+
import enum
12
import grpc
23
import pytest
4+
from datetime import datetime
35
from typing_extensions import assert_type
46

57
import tests.types as test_types
@@ -261,31 +263,63 @@ def test___set_string_enum_type___get_value_with_int_enum_default___raises_excep
261263
3.14,
262264
True,
263265
b"robotext",
266+
],
267+
)
268+
def test___builtin_scalar_type___set_value___gets_same_value(
269+
fake_panel_channel: grpc.Channel,
270+
value_payload: object,
271+
) -> None:
272+
"""Test that set_value() and get_value() work for builtin scalar types."""
273+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
274+
275+
value_id = "test_id"
276+
panel.set_value(value_id, value_payload)
277+
278+
assert panel.get_value(value_id) == value_payload
279+
280+
281+
@pytest.mark.parametrize(
282+
"value_payload",
283+
[
264284
test_types.MyIntFlags.VALUE1 | test_types.MyIntFlags.VALUE4,
285+
test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32,
265286
test_types.MyIntEnum.VALUE20,
287+
test_types.MyIntableEnum.VALUE200,
266288
test_types.MyStrEnum.VALUE3,
289+
test_types.MyStringableEnum.VALUE2,
267290
test_types.MixinIntEnum.VALUE33,
268291
test_types.MixinStrEnum.VALUE11,
292+
test_types.MyMixedEnum.VALUE2,
269293
],
270294
)
271-
def test___builtin_scalar_type___set_value___gets_same_value(
295+
def test___enum_type___set_value___gets_same_value(
272296
fake_panel_channel: grpc.Channel,
273-
value_payload: object,
297+
value_payload: enum.Enum,
274298
) -> None:
275299
"""Test that set_value() and get_value() work for builtin scalar types."""
276300
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
277301

278302
value_id = "test_id"
279303
panel.set_value(value_id, value_payload)
280304

281-
assert panel.get_value(value_id) == value_payload
305+
# without providing a default value, get_value will return the raw value, not the enum
306+
assert panel.get_value(value_id) == value_payload.value
282307

283308

284309
@pytest.mark.parametrize(
285310
"value_payload",
286311
[
287-
test_types.MyEnum.VALUE300,
288-
test_types.MyFlags.VALUE8 | test_types.MyFlags.VALUE16,
312+
datetime.now(),
313+
lambda x: x + 1,
314+
[1, "string"],
315+
["string", []],
316+
(42, "hello", 3.14, b"bytes"),
317+
set([1, "mixed", True]),
318+
(i for i in range(5)),
319+
{
320+
"key1": [1, 2, 3],
321+
"key2": {"nested": True, "values": [4.5, 6.7]},
322+
},
289323
],
290324
)
291325
def test___unsupported_type___set_value___raises(
@@ -368,6 +402,22 @@ def test___set_int_enum_value___get_value___returns_int_enum(
368402
assert retrieved_value.name == enum_value.name
369403

370404

405+
def test___set_intable_enum_value___get_value___returns_enum(
406+
fake_panel_channel: grpc.Channel,
407+
) -> None:
408+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
409+
value_id = "test_id"
410+
enum_value = test_types.MyIntableEnum.VALUE200
411+
panel.set_value(value_id, enum_value)
412+
413+
retrieved_value = panel.get_value(value_id, test_types.MyIntableEnum.VALUE100)
414+
415+
assert_type(retrieved_value, test_types.MyIntableEnum)
416+
assert retrieved_value is test_types.MyIntableEnum.VALUE200
417+
assert retrieved_value.value == enum_value.value
418+
assert retrieved_value.name == enum_value.name
419+
420+
371421
def test___set_string_enum_value___get_value___returns_string_enum(
372422
fake_panel_channel: grpc.Channel,
373423
) -> None:
@@ -384,6 +434,38 @@ def test___set_string_enum_value___get_value___returns_string_enum(
384434
assert retrieved_value.name == enum_value.name
385435

386436

437+
def test___set_stringable_enum_value___get_value___returns_enum(
438+
fake_panel_channel: grpc.Channel,
439+
) -> None:
440+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
441+
value_id = "test_id"
442+
enum_value = test_types.MyStringableEnum.VALUE3
443+
panel.set_value(value_id, enum_value)
444+
445+
retrieved_value = panel.get_value(value_id, test_types.MyStringableEnum.VALUE1)
446+
447+
assert_type(retrieved_value, test_types.MyStringableEnum)
448+
assert retrieved_value is test_types.MyStringableEnum.VALUE3
449+
assert retrieved_value.value == enum_value.value
450+
assert retrieved_value.name == enum_value.name
451+
452+
453+
def test___set_mixed_enum_value___get_value___returns_enum(
454+
fake_panel_channel: grpc.Channel,
455+
) -> None:
456+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
457+
value_id = "test_id"
458+
enum_value = test_types.MyMixedEnum.VALUE2
459+
panel.set_value(value_id, enum_value)
460+
461+
retrieved_value = panel.get_value(value_id, test_types.MyMixedEnum.VALUE1)
462+
463+
assert_type(retrieved_value, test_types.MyMixedEnum)
464+
assert retrieved_value is test_types.MyMixedEnum.VALUE2
465+
assert retrieved_value.value == enum_value.value
466+
assert retrieved_value.name == enum_value.name
467+
468+
387469
def test___set_int_flags_value___get_value___returns_int_flags(
388470
fake_panel_channel: grpc.Channel,
389471
) -> None:
@@ -399,6 +481,22 @@ def test___set_int_flags_value___get_value___returns_int_flags(
399481
assert retrieved_value.value == flags_value.value
400482

401483

484+
def test___set_intable_flags_value___get_value___returns_flags(
485+
fake_panel_channel: grpc.Channel,
486+
) -> None:
487+
panel = StreamlitPanel("my_panel", "path/to/script", grpc_channel=fake_panel_channel)
488+
value_id = "test_id"
489+
flags_value = test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32
490+
panel.set_value(value_id, flags_value)
491+
492+
retrieved_value = panel.get_value(value_id, test_types.MyIntableFlags.VALUE8)
493+
494+
assert_type(retrieved_value, test_types.MyIntableFlags)
495+
assert retrieved_value is test_types.MyIntableFlags.VALUE16 | test_types.MyIntableFlags.VALUE32
496+
assert retrieved_value.value == flags_value.value
497+
assert retrieved_value.name == flags_value.name
498+
499+
402500
def test___panel___panel_is_running_and_in_memory(
403501
fake_panel_channel: grpc.Channel,
404502
) -> None:

0 commit comments

Comments
 (0)