Skip to content

Commit 862ffbf

Browse files
Support overridden dataclass defaults in control encoding (#5861)
Added logic to retrieve default values from base dataclasses when encoding controls, ensuring overridden defaults in subclasses are correctly emitted. Includes new tests and example usage with custom button controls.
1 parent f37c1f4 commit 862ffbf

File tree

3 files changed

+87
-5
lines changed

3 files changed

+87
-5
lines changed
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from dataclasses import dataclass, field
2+
from typing import Any
3+
4+
import flet as ft
5+
6+
7+
def main(page: ft.Page):
8+
@ft.control
9+
class MyButton(ft.Button):
10+
expand: int = field(default_factory=lambda: 1)
11+
style: ft.ButtonStyle = field(
12+
default_factory=lambda: ft.ButtonStyle(
13+
shape=ft.RoundedRectangleBorder(radius=10)
14+
)
15+
)
16+
bgcolor: ft.Colors = ft.Colors.BLUE_ACCENT
17+
icon: Any = ft.Icons.HEADPHONES
18+
19+
@dataclass
20+
class MyButton2(ft.Button):
21+
expand: Any = 1
22+
bgcolor: ft.Colors = ft.Colors.GREEN_ACCENT
23+
style: ft.ButtonStyle = field(
24+
default_factory=lambda: ft.ButtonStyle(
25+
shape=ft.RoundedRectangleBorder(radius=20)
26+
)
27+
)
28+
icon: ft.IconDataOrControl = ft.Icons.HEADPHONES
29+
30+
@ft.control
31+
class MyButton3(ft.Button):
32+
def init(self):
33+
self.expand = 1
34+
self.bgcolor = ft.Colors.RED_ACCENT
35+
self.style = ft.ButtonStyle(shape=ft.RoundedRectangleBorder(radius=30))
36+
self.icon = ft.Icons.HEADPHONES
37+
38+
page.add(
39+
ft.Row([MyButton(content="1")]),
40+
ft.Row([MyButton2(content="2")]),
41+
ft.Row([MyButton3(content="3")]),
42+
)
43+
44+
45+
ft.run(main)

sdk/python/packages/flet/src/flet/messaging/protocol.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,20 @@
88
from flet.controls.duration import Duration
99

1010

11+
def _get_root_dataclass_field(cls, field_name):
12+
"""
13+
Returns the field definition from the earliest dataclass in the MRO
14+
that declares `field_name`. This lets us recover defaults configured
15+
on base controls before subclasses override them.
16+
"""
17+
18+
for base in reversed(cls.__mro__):
19+
dataclass_fields = getattr(base, "__dataclass_fields__", None)
20+
if dataclass_fields and field_name in dataclass_fields:
21+
return dataclass_fields[field_name]
22+
return None
23+
24+
1125
def configure_encode_object_for_msgpack(control_cls):
1226
def encode_object_for_msgpack(obj):
1327
if is_dataclass(obj):
@@ -36,10 +50,18 @@ def encode_object_for_msgpack(obj):
3650
elif is_dataclass(v):
3751
r[field.name] = v
3852
prev_classes[field.name] = v
39-
elif v is not None and (
40-
v != field.default or not isinstance(obj, control_cls)
41-
):
42-
r[field.name] = v
53+
else:
54+
default_value = field.default
55+
if isinstance(obj, control_cls):
56+
root_field = _get_root_dataclass_field(
57+
obj.__class__, field.name
58+
)
59+
if root_field is not None:
60+
default_value = root_field.default
61+
if v is not None and (
62+
v != default_value or not isinstance(obj, control_cls)
63+
):
64+
r[field.name] = v
4365

4466
if not hasattr(obj, "_frozen"):
4567
setattr(obj, "__prev_lists", prev_lists)

sdk/python/packages/flet/tests/test_patch_dataclass.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import msgpack
44

5-
from flet.controls.base_control import BaseControl
5+
from flet.controls.base_control import BaseControl, control
66
from flet.controls.base_page import PageMediaData
77
from flet.controls.object_patch import ObjectPatch
88
from flet.controls.padding import Padding
@@ -37,6 +37,21 @@ class AppSettings:
3737
assert settings.config.timeout == 2.5
3838

3939

40+
def test_encode_emits_overridden_defaults():
41+
@control("BaseTestControl")
42+
class BaseTestControl(BaseControl):
43+
foo: int = 0
44+
45+
@control("ChildTestControl")
46+
class ChildTestControl(BaseTestControl):
47+
foo: int = 5
48+
49+
encoder = configure_encode_object_for_msgpack(BaseControl)
50+
encoded = encoder(ChildTestControl())
51+
52+
assert encoded["foo"] == 5
53+
54+
4055
def test_page_patch_dataclass():
4156
conn = Connection()
4257
conn.pubsubhub = PubSubHub()

0 commit comments

Comments
 (0)