Skip to content

Commit 957a996

Browse files
njzjzpre-commit-ci[bot]Copilot
authored
feat: add Node class for serialization and implement display functionality (#5158)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Human-readable number formatting for clearer size displays. * Enhanced model serialization tree visualization with compact, size-aware, and more readable pretty-print output. * **Tests** * Added tests validating the improved serialization tree display and number-formatting output. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Jinzhe Zeng <njzjz@qq.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 8b0c3b0 commit 957a996

File tree

2 files changed

+219
-0
lines changed

2 files changed

+219
-0
lines changed

deepmd/dpmodel/utils/serialization.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
from collections.abc import (
55
Callable,
66
)
7+
from copy import (
8+
deepcopy,
9+
)
10+
from functools import (
11+
cached_property,
12+
)
713
from pathlib import (
814
Path,
915
)
@@ -170,3 +176,203 @@ def convert_numpy_ndarray(x: Any) -> Any:
170176
else:
171177
raise ValueError(f"Unknown filename extension: {filename_extension}")
172178
return model_dict
179+
180+
181+
def format_big_number(x: int) -> str:
182+
"""Format a big number with suffixes.
183+
184+
Parameters
185+
----------
186+
x : int
187+
The number to format.
188+
189+
Returns
190+
-------
191+
str
192+
The formatted string.
193+
"""
194+
if x >= 1_000_000_000:
195+
return f"{x / 1_000_000_000:.1f}B"
196+
elif x >= 1_000_000:
197+
return f"{x / 1_000_000:.1f}M"
198+
elif x >= 1_000:
199+
return f"{x / 1_000:.1f}K"
200+
else:
201+
return str(x)
202+
203+
204+
class Node:
205+
"""A node in a serialization tree.
206+
207+
Examples
208+
--------
209+
>>> model_dict = load_dp_model("model.dp") # Example filename
210+
>>> root_node = Node.deserialize(model_dict["model"])
211+
>>> print(root_node)
212+
"""
213+
214+
def __init__(
215+
self,
216+
name: str,
217+
children: dict[str, "Node"],
218+
data: dict[str, Any],
219+
variables: dict[str, Any],
220+
) -> None:
221+
self.name = name
222+
self.children: dict[str, Node] = children
223+
self.data: dict[str, Any] = data
224+
self.variables: dict[str, Any] = variables
225+
226+
@cached_property
227+
def size(self) -> int:
228+
"""Get the size of the node.
229+
230+
Returns
231+
-------
232+
int
233+
The size of the node.
234+
"""
235+
total_size = 0
236+
237+
def count_variables(x: Any) -> Any:
238+
nonlocal total_size
239+
if isinstance(x, np.ndarray):
240+
total_size += x.size
241+
return x
242+
243+
traverse_model_dict(
244+
self.variables,
245+
count_variables,
246+
is_variable=True,
247+
)
248+
for child in self.children.values():
249+
total_size += child.size
250+
return total_size
251+
252+
@classmethod
253+
def deserialize(cls, data: Any) -> "Node":
254+
"""Deserialize a Node from a dictionary.
255+
256+
Parameters
257+
----------
258+
data : Any
259+
The data to deserialize from.
260+
261+
Returns
262+
-------
263+
Node
264+
The deserialized node.
265+
"""
266+
if isinstance(data, dict):
267+
return cls.from_dict(data)
268+
elif isinstance(data, list):
269+
return cls.from_list(data)
270+
else:
271+
raise ValueError("Cannot deserialize Node from non-dict/list data.")
272+
273+
@classmethod
274+
def from_dict(cls, data_dict: dict) -> "Node":
275+
"""Create a Node from a dictionary.
276+
277+
Parameters
278+
----------
279+
data_dict : dict
280+
The dictionary to create the node from.
281+
282+
Returns
283+
-------
284+
Node
285+
The created node.
286+
"""
287+
class_name = data_dict.get("@class")
288+
type_name = data_dict.get("type")
289+
if class_name is not None:
290+
if type_name is not None:
291+
name = f"{class_name} {type_name}"
292+
else:
293+
name = class_name
294+
else:
295+
name = "Node"
296+
variables = {}
297+
children = {}
298+
data = {}
299+
for kk, vv in data_dict.items():
300+
if kk == "@variables":
301+
variables = deepcopy(vv)
302+
elif isinstance(vv, dict):
303+
children[kk] = cls.from_dict(vv)
304+
elif isinstance(vv, list):
305+
# drop if no children inside a list
306+
list_node = cls.from_list(vv)
307+
if len(list_node.children) > 0:
308+
children[kk] = list_node
309+
else:
310+
data[kk] = vv
311+
return cls(name, children, data, variables)
312+
313+
@classmethod
314+
def from_list(cls, data_list: list[Any]) -> "Node":
315+
"""Create a Node from a list.
316+
317+
Parameters
318+
----------
319+
data_list : list
320+
The list to create the node from.
321+
322+
Returns
323+
-------
324+
Node
325+
The created node.
326+
"""
327+
variables = {}
328+
children = {}
329+
data = {}
330+
for ii, vv in enumerate(data_list):
331+
if isinstance(vv, dict):
332+
children[f"{ii:d}"] = cls.from_dict(vv)
333+
elif isinstance(vv, list):
334+
children[f"{ii:d}"] = cls.from_list(vv)
335+
else:
336+
data[f"{ii:d}"] = vv
337+
return cls("ListNode", children, data, variables)
338+
339+
def __str__(self) -> str:
340+
elbow = "└──"
341+
pipe = "│ "
342+
tee = "├──"
343+
blank = " "
344+
linebreak = "\n"
345+
buff = []
346+
buff.append(f"{self.name} (size={format_big_number(self.size)})")
347+
children_buff = []
348+
for ii, (kk, vv) in enumerate(self.children.items()):
349+
# add indentation
350+
child_repr = str(vv)
351+
if len(children_buff) > 0:
352+
# check if it is the same as the last one
353+
last_repr = children_buff[-1][1]
354+
if child_repr == last_repr:
355+
# merge
356+
last_kk, _ = children_buff[-1]
357+
children_buff[-1] = (f"{last_kk}, {kk}", last_repr)
358+
continue
359+
children_buff.append((kk, child_repr))
360+
361+
def format_list_keys(kk: str) -> str:
362+
if self.name == "ListNode":
363+
keys = kk.split(", ")
364+
if len(keys) > 2:
365+
return f"[{keys[0]}...{keys[-1]}]"
366+
return kk
367+
368+
def format_value(vv: str, current_index: int) -> str:
369+
return vv.replace(
370+
linebreak,
371+
linebreak + (pipe if current_index < len(children_buff) - 1 else blank),
372+
)
373+
374+
buff.extend(
375+
f"{tee if ii < len(children_buff) - 1 else elbow}{format_list_keys(kk)} -> {format_value(vv, ii)}"
376+
for ii, (kk, vv) in enumerate(children_buff)
377+
)
378+
return "\n".join(buff)

source/tests/common/dpmodel/test_network.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-License-Identifier: LGPL-3.0-or-later
22
import itertools
33
import os
4+
import textwrap
45
import unittest
56
from copy import (
67
deepcopy,
@@ -20,6 +21,9 @@
2021
load_dp_model,
2122
save_dp_model,
2223
)
24+
from deepmd.dpmodel.utils.serialization import (
25+
Node,
26+
)
2327

2428

2529
class TestNativeLayer(unittest.TestCase):
@@ -273,6 +277,7 @@ def setUp(self) -> None:
273277
self.w = np.full((3, 2), 3.0)
274278
self.b = np.full((3,), 4.0)
275279
self.model_dict = {
280+
"@class": "some_class",
276281
"type": "some_type",
277282
"layers": [
278283
{
@@ -304,6 +309,14 @@ def test_save_load_model_yaml(self) -> None:
304309
assert "software" in model
305310
assert "version" in model
306311

312+
def test_node_display(self):
313+
disp_expected = textwrap.dedent("""\
314+
some_class some_type (size=18)
315+
└──layers -> ListNode (size=18)
316+
└──0, 1 -> Node (size=9)""")
317+
disp = str(Node.deserialize(self.model_dict))
318+
self.assertEqual(disp, disp_expected)
319+
307320
def tearDown(self) -> None:
308321
if os.path.exists(self.filename):
309322
os.remove(self.filename)

0 commit comments

Comments
 (0)