Skip to content

Commit fbf29b7

Browse files
committed
fix(perf): prevent sending data back to client
1 parent 3eaefca commit fbf29b7

File tree

3 files changed

+85
-18
lines changed

3 files changed

+85
-18
lines changed

examples/perf_v2.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import time
2+
3+
from trame.app import TrameApp
4+
from trame.ui.html import DivLayout
5+
from trame.widgets import html
6+
from trame_dataclass.v2 import StateDataModel, Sync, watch
7+
8+
9+
class Data(StateDataModel):
10+
value = Sync(int, 1)
11+
12+
@watch("value")
13+
def _fake_busy(self, _):
14+
time.sleep(0.1)
15+
16+
17+
class Test(TrameApp):
18+
def __init__(self, server=None):
19+
super().__init__(server)
20+
self.data_1 = Data(self.server, enable_collaboration=True)
21+
self.data_2 = Data(self.server, enable_collaboration=False)
22+
self._build_ui()
23+
24+
def _build_ui(self):
25+
with DivLayout(self.server) as self.ui:
26+
with self.data_1.provide_as("data_slow"):
27+
with self.data_2.provide_as("data_fast"):
28+
html.Div("Collaboration ON: {{ data_slow.value }}")
29+
html.Input(
30+
type="range",
31+
v_model_number="data_slow.value",
32+
min=0,
33+
max=500,
34+
step=1,
35+
style="width: 100%",
36+
)
37+
html.Div("Collaboration OFF (default): {{ data_fast.value }}")
38+
html.Input(
39+
type="range",
40+
v_model_number="data_fast.value",
41+
min=0,
42+
max=200,
43+
step=1,
44+
style="width: 100%",
45+
)
46+
47+
48+
def main():
49+
app = Test()
50+
app.server.start()
51+
52+
53+
if __name__ == "__main__":
54+
main()

src/trame_dataclass/module/protocol_v2.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ def compute_definition(trame_dataclass_class):
1616
class TrameDataclassProtocol(LinkProtocol):
1717
def __init__(self, *args, **kwargs):
1818
super().__init__(*args, **kwargs)
19-
2019
self.class_definitions = {}
2120
self.next_class_definition_id = 1
2221

@@ -77,14 +76,8 @@ def get_state(self, instance_id):
7776
def update_state(self, msg):
7877
for dc_id, state in msg.items():
7978
obj = get_instance(dc_id)
80-
encoders = obj.ENCODERS
8179
if obj is not None:
82-
for k, v in state.items():
83-
convert = encoders.get(k)
84-
if convert:
85-
setattr(obj, k, convert.decoder(v))
86-
else:
87-
setattr(obj, k, v)
80+
obj.update_from_client_state(state)
8881

8982
def push_delta(self, msg):
9083
self.publish("trame.dataclass.publish", msg)

src/trame_dataclass/v2.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,15 @@ def can_be_decorated(x):
112112

113113
def _save_field(name, src, dst, encoder=None):
114114
value = getattr(src, name)
115+
115116
if encoder:
116-
dst[name] = encoder(value)
117-
else:
117+
value = encoder(value)
118+
119+
if name not in dst or dst[name] != value:
118120
dst[name] = value
121+
return True
122+
123+
return False
119124

120125

121126
def _setup_class_fields(owner):
@@ -145,13 +150,14 @@ class TypeValidation(Enum):
145150

146151

147152
class StateDataModel:
148-
def __init__(self, trame_server=None, **kwargs):
153+
def __init__(self, trame_server=None, enable_collaboration=False, **kwargs):
149154
self.__id = _next_id()
150155
self.__trame_server = trame_server
151156

152157
# Register all instances
153158
INSTANCES[self.__id] = self
154159

160+
self._enable_collaboration = enable_collaboration
155161
self._server_state = {}
156162
self._client_state = {}
157163
self._dirty_set = set()
@@ -326,6 +332,9 @@ def client_state(self):
326332
def update_from_client_state(self, partial_state):
327333
encoders = self.ENCODERS
328334
for k, v in partial_state.items():
335+
if not self._enable_collaboration:
336+
self._client_state[k] = v
337+
329338
convert = encoders.get(k)
330339
if convert:
331340
setattr(self, k, convert.decoder(v))
@@ -336,7 +345,12 @@ def update_from_client_state(self, partial_state):
336345
def _id(self):
337346
return self.__id
338347

339-
def flush(self, dirty_set: set[str] | None = None):
348+
def dirty(self, *keys):
349+
"""Mark variable dirty and trigger watchers"""
350+
self._dirty_set.update(keys)
351+
self._on_dirty()
352+
353+
def flush(self, dirty_set: set[str] | None = None, force_push=False):
340354
"""Flush the data to the client."""
341355
if self._flush_impl is None:
342356
return
@@ -349,19 +363,25 @@ def flush(self, dirty_set: set[str] | None = None):
349363
self._dirty_set.discard(name)
350364

351365
key_to_send = list(dirty_set & self.CLIENT_NAMES)
366+
modified_keys = []
352367
for name in key_to_send:
353368
encoder = None
354369
if name in self.ENCODERS:
355370
encoder = self.ENCODERS[name].encoder
356371

357-
_save_field(name, self, self._client_state, encoder)
372+
if _save_field(name, self, self._client_state, encoder):
373+
modified_keys.append(name)
374+
375+
if force_push:
376+
modified_keys = list(key_to_send)
358377

359378
# Send data over the network
360-
msg = {
361-
"id": self._id,
362-
"state": {k: self._client_state[k] for k in key_to_send},
363-
}
364-
self._flush_impl(msg)
379+
if modified_keys:
380+
msg = {
381+
"id": self._id,
382+
"state": {k: self._client_state[k] for k in modified_keys},
383+
}
384+
self._flush_impl(msg)
365385

366386

367387
# -----------------------------------------------------------------------------

0 commit comments

Comments
 (0)