from copy import (
    deepcopy,
)
from functools import (
    cached_property,
)


def format_big_number(x: int) -> str:
    """Format a big number with suffixes.

    Values below one thousand are returned via ``str``. Larger values are
    scaled to thousands (``K``), millions (``M``) or billions (``B``) and
    printed with one decimal place. If rounding to one decimal would yield a
    mantissa of ``1000.0`` the next larger suffix is used instead, so
    ``999_999`` is rendered as ``1.0M`` rather than ``1000.0K``.

    Parameters
    ----------
    x : int
        The number to format. Negative numbers are formatted by magnitude
        and prefixed with ``-``.

    Returns
    -------
    str
        The formatted string.

    Examples
    --------
    >>> format_big_number(1_500)
    '1.5K'
    >>> format_big_number(999_999)
    '1.0M'
    """
    if x < 0:
        return "-" + format_big_number(-x)
    if x < 1_000:
        return str(x)
    units = ((1_000, "K"), (1_000_000, "M"), (1_000_000_000, "B"))
    for scale, suffix in units[:-1]:
        scaled = f"{x / scale:.1f}"
        # Compare the *rounded* mantissa: values just below the next unit
        # round up to "1000.0" and must be promoted to that unit.
        if float(scaled) < 1_000:
            return f"{scaled}{suffix}"
    # "B" is the largest suffix available; it absorbs everything above.
    scale, suffix = units[-1]
    return f"{x / scale:.1f}{suffix}"


class Node:
    """A node in a serialization tree.

    Each node carries a display name, its child nodes, the plain (non-dict,
    non-list) entries found at this level, and the content of the
    ``@variables`` entry, if any.

    Examples
    --------
    >>> model_dict = load_dp_model("model.dp")  # Example filename
    >>> root_node = Node.deserialize(model_dict["model"])
    >>> print(root_node)
    """

    def __init__(
        self,
        name: str,
        children: dict[str, "Node"],
        data: dict[str, Any],
        variables: dict[str, Any],
    ) -> None:
        self.name = name
        self.children: dict[str, Node] = children
        self.data: dict[str, Any] = data
        self.variables: dict[str, Any] = variables

    @cached_property
    def size(self) -> int:
        """Get the size of the node.

        The size is the summed ``.size`` of every ``numpy.ndarray`` visited
        when traversing ``self.variables``, plus the size of every child.
        The value is cached after the first access.

        Returns
        -------
        int
            The size of the node.
        """
        total_size = 0

        def count_variables(x: Any) -> Any:
            nonlocal total_size
            if isinstance(x, np.ndarray):
                total_size += x.size
            # hand the value back unchanged; only the count is of interest
            return x

        traverse_model_dict(
            self.variables,
            count_variables,
            is_variable=True,
        )
        return total_size + sum(child.size for child in self.children.values())

    @classmethod
    def deserialize(cls, data: Any) -> "Node":
        """Deserialize a Node from a dictionary or a list.

        Parameters
        ----------
        data : Any
            The data to deserialize from.

        Returns
        -------
        Node
            The deserialized node.

        Raises
        ------
        ValueError
            If `data` is neither a dict nor a list.
        """
        if isinstance(data, dict):
            return cls.from_dict(data)
        if isinstance(data, list):
            return cls.from_list(data)
        raise ValueError("Cannot deserialize Node from non-dict/list data.")

    @classmethod
    def from_dict(cls, data_dict: dict) -> "Node":
        """Create a Node from a dictionary.

        The node is named ``"<@class> <type>"`` when both keys are present,
        ``"<@class>"`` when only ``@class`` is present, and ``"Node"``
        otherwise. Dict values become children; list values become children
        only if they contain at least one dict or list (other lists are
        dropped); the ``@variables`` entry is deep-copied into ``variables``;
        everything else is stored in ``data``.

        Parameters
        ----------
        data_dict : dict
            The dictionary to create the node from.

        Returns
        -------
        Node
            The created node.
        """
        class_name = data_dict.get("@class")
        type_name = data_dict.get("type")
        if class_name is None:
            name = "Node"
        elif type_name is None:
            name = class_name
        else:
            name = f"{class_name} {type_name}"

        variables: dict[str, Any] = {}
        children: dict[str, Node] = {}
        data: dict[str, Any] = {}
        for kk, vv in data_dict.items():
            if kk == "@variables":
                variables = deepcopy(vv)
            elif isinstance(vv, dict):
                children[kk] = cls.from_dict(vv)
            elif isinstance(vv, list):
                # drop if no children inside a list
                list_node = cls.from_list(vv)
                if list_node.children:
                    children[kk] = list_node
            else:
                data[kk] = vv
        return cls(name, children, data, variables)

    @classmethod
    def from_list(cls, data_list: list[Any]) -> "Node":
        """Create a Node from a list.

        The resulting node is named ``"ListNode"``. Items are keyed by their
        decimal index: dict and list items become children, all other items
        are stored in ``data``.

        Parameters
        ----------
        data_list : list
            The list to create the node from.

        Returns
        -------
        Node
            The created node.
        """
        variables: dict[str, Any] = {}
        children: dict[str, Node] = {}
        data: dict[str, Any] = {}
        for ii, vv in enumerate(data_list):
            key = f"{ii:d}"
            if isinstance(vv, dict):
                children[key] = cls.from_dict(vv)
            elif isinstance(vv, list):
                children[key] = cls.from_list(vv)
            else:
                data[key] = vv
        return cls("ListNode", children, data, variables)

    def __str__(self) -> str:
        elbow = "└──"
        tee = "├──"
        # Continuation prefixes are as wide as the connectors so that nested
        # levels line up underneath their parent entry.
        pipe = "│" + " " * (len(tee) - 1)
        blank = " " * len(elbow)
        linebreak = "\n"

        # Collapse runs of consecutive children that render identically into
        # one entry that remembers all of their keys.
        merged: list[tuple[list[str], str]] = []
        for kk, vv in self.children.items():
            child_repr = str(vv)
            if merged and merged[-1][1] == child_repr:
                merged[-1][0].append(kk)
            else:
                merged.append(([kk], child_repr))

        buff = [f"{self.name} (size={format_big_number(self.size)})"]
        last_index = len(merged) - 1
        for ii, (keys, child_repr) in enumerate(merged):
            if ii < last_index:
                connector, continuation = tee, pipe
            else:
                connector, continuation = elbow, blank
            if self.name == "ListNode" and len(keys) > 2:
                # a run of list indices is abbreviated to its end points
                label = f"[{keys[0]}...{keys[-1]}]"
            else:
                label = ", ".join(keys)
            body = child_repr.replace(linebreak, linebreak + continuation)
            buff.append(f"{connector}{label} -> {body}")
        return linebreak.join(buff)
def setUp(self) -> None: self.w = np.full((3, 2), 3.0) self.b = np.full((3,), 4.0) self.model_dict = { + "@class": "some_class", "type": "some_type", "layers": [ { @@ -304,6 +309,14 @@ def test_save_load_model_yaml(self) -> None: assert "software" in model assert "version" in model + def test_node_display(self): + disp_expected = textwrap.dedent("""\ + some_class some_type (size=18) + └──layers -> ListNode (size=18) + └──0, 1 -> Node (size=9)""") + disp = str(Node.deserialize(self.model_dict)) + self.assertEqual(disp, disp_expected) + def tearDown(self) -> None: if os.path.exists(self.filename): os.remove(self.filename)