Skip to content
Open
206 changes: 206 additions & 0 deletions deepmd/dpmodel/utils/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from collections.abc import (
Callable,
)
from copy import (
deepcopy,
)
from functools import (
cached_property,
)
from pathlib import (
Path,
)
Expand Down Expand Up @@ -170,3 +176,203 @@ def convert_numpy_ndarray(x: Any) -> Any:
else:
raise ValueError(f"Unknown filename extension: {filename_extension}")
return model_dict


def format_big_number(x: int) -> str:
    """Format a big number with suffixes.

    Parameters
    ----------
    x : int
        The number to format.

    Returns
    -------
    str
        ``str(x)`` when ``x`` is below 1000; otherwise ``x`` scaled to
        thousands (``K``), millions (``M``) or billions (``B``) and shown
        with one decimal place.
    """
    if x < 1_000:
        return str(x)
    for scale, suffix in ((1_000, "K"), (1_000_000, "M")):
        text = f"{x / scale:.1f}"
        # Rounding to one decimal can carry a value that sits just below the
        # next unit up to "1000.0" (e.g. 999_999 -> "1000.0K"). Fall through
        # to the next unit in that case so the result reads "1.0M" instead.
        if float(text) < 1_000:
            return f"{text}{suffix}"
    return f"{x / 1_000_000_000:.1f}B"


class Node:
    """A node in a serialization tree.

    Each node mirrors one dict or list of a serialized model: nested dicts
    and lists become child nodes, the ``@variables`` entry is kept as the
    node's variables, and every remaining value is kept as plain data.

    Parameters
    ----------
    name : str
        Display name of the node.
    children : dict[str, Node]
        Child nodes, keyed by their dict key or by their list index.
    data : dict[str, Any]
        Entries of the node that are neither dicts nor lists.
    variables : dict[str, Any]
        Content of the ``@variables`` entry; empty when there is none.

    Examples
    --------
    >>> model_dict = load_dp_model("model.dp")  # Example filename
    >>> root_node = Node.deserialize(model_dict["model"])
    >>> print(root_node)
    """

    def __init__(
        self,
        name: str,
        children: dict[str, "Node"],
        data: dict[str, Any],
        variables: dict[str, Any],
    ) -> None:
        self.name = name
        self.children: dict[str, Node] = children
        self.data: dict[str, Any] = data
        self.variables: dict[str, Any] = variables

    @cached_property
    def size(self) -> int:
        """Get the size of the node.

        The value is computed on first access and cached afterwards.

        Returns
        -------
        int
            The summed ``size`` of every ``numpy.ndarray`` that
            ``traverse_model_dict`` visits in this node's variables, plus
            the sizes of all children.
        """
        total_size = 0

        def count_variables(x: Any) -> Any:
            # Used as a traversal callback: accumulate the element count and
            # hand the value back unchanged.
            nonlocal total_size
            if isinstance(x, np.ndarray):
                total_size += x.size
            return x

        traverse_model_dict(
            self.variables,
            count_variables,
            is_variable=True,
        )
        for child in self.children.values():
            total_size += child.size
        return total_size

    @classmethod
    def deserialize(cls, data: Any) -> "Node":
        """Deserialize a Node from a dictionary or a list.

        Parameters
        ----------
        data : Any
            The data to deserialize from.

        Returns
        -------
        Node
            The deserialized node.

        Raises
        ------
        ValueError
            If ``data`` is neither a dict nor a list.
        """
        if isinstance(data, dict):
            return cls.from_dict(data)
        if isinstance(data, list):
            return cls.from_list(data)
        raise ValueError("Cannot deserialize Node from non-dict/list data.")

    @classmethod
    def from_dict(cls, data_dict: dict) -> "Node":
        """Create a Node from a dictionary.

        The node is named ``"<@class> <type>"`` when both keys are present,
        ``"<@class>"`` when only ``@class`` is present, and ``"Node"``
        otherwise.

        Parameters
        ----------
        data_dict : dict
            The dictionary to create the node from.

        Returns
        -------
        Node
            The created node.
        """
        class_name = data_dict.get("@class")
        type_name = data_dict.get("type")
        if class_name is None:
            name = "Node"
        elif type_name is None:
            name = class_name
        else:
            name = f"{class_name} {type_name}"

        variables: dict[str, Any] = {}
        children: dict[str, Node] = {}
        data: dict[str, Any] = {}
        for kk, vv in data_dict.items():
            if kk == "@variables":
                # Keep a private copy so that nothing done to this node's
                # variables can reach the caller's dictionary.
                # NOTE(review): presumably needed because traverse_model_dict
                # may rewrite entries in place - confirm.
                variables = deepcopy(vv)
            elif isinstance(vv, dict):
                children[kk] = cls.from_dict(vv)
            elif isinstance(vv, list):
                # drop if no children inside a list
                list_node = cls.from_list(vv)
                if list_node.children:
                    children[kk] = list_node
            else:
                data[kk] = vv
        return cls(name, children, data, variables)

    @classmethod
    def from_list(cls, data_list: list[Any]) -> "Node":
        """Create a Node from a list.

        Items are keyed by their decimal index. Dict and list items become
        children; everything else is stored as data. The node is always
        named ``"ListNode"`` and has no variables.

        Parameters
        ----------
        data_list : list
            The list to create the node from.

        Returns
        -------
        Node
            The created node.
        """
        children: dict[str, Node] = {}
        data: dict[str, Any] = {}
        for ii, vv in enumerate(data_list):
            key = f"{ii:d}"
            if isinstance(vv, dict):
                children[key] = cls.from_dict(vv)
            elif isinstance(vv, list):
                children[key] = cls.from_list(vv)
            else:
                data[key] = vv
        return cls("ListNode", children, data, {})

    def __str__(self) -> str:
        """Render the node and its descendants as a text tree.

        Consecutive children whose rendering is identical are merged into
        one entry. The merged entry is labelled with the comma-separated
        keys, or with ``[first...last]`` when a ``ListNode`` merges more
        than two of them.
        """
        tee = "├──"
        elbow = "└──"
        # Continuation prefixes are derived from the connector width so that
        # the lines of a nested subtree start exactly below the label that
        # follows the connector, whichever connector the parent entry used.
        pipe = "│" + " " * (len(tee) - 1)
        blank = " " * len(elbow)

        # Group consecutive children with the same rendering; the keys are
        # kept as a list so they never have to be re-parsed from a string.
        groups: list[tuple[list[str], str]] = []
        for key, child in self.children.items():
            rendered = str(child)
            if groups and groups[-1][1] == rendered:
                groups[-1][0].append(key)
            else:
                groups.append(([key], rendered))

        lines = [f"{self.name} (size={format_big_number(self.size)})"]
        last_index = len(groups) - 1
        for index, (keys, rendered) in enumerate(groups):
            if self.name == "ListNode" and len(keys) > 2:
                label = f"[{keys[0]}...{keys[-1]}]"
            else:
                label = ", ".join(keys)
            if index == last_index:
                connector, prefix = elbow, blank
            else:
                connector, prefix = tee, pipe
            # Indent every continuation line of the child's own rendering.
            body = rendered.replace("\n", "\n" + prefix)
            lines.append(f"{connector}{label} -> {body}")
        return "\n".join(lines)
13 changes: 13 additions & 0 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import itertools
import os
import textwrap
import unittest
from copy import (
deepcopy,
Expand All @@ -20,6 +21,9 @@
load_dp_model,
save_dp_model,
)
from deepmd.dpmodel.utils.serialization import (
Node,
)


class TestNativeLayer(unittest.TestCase):
Expand Down Expand Up @@ -273,6 +277,7 @@ def setUp(self) -> None:
self.w = np.full((3, 2), 3.0)
self.b = np.full((3,), 4.0)
self.model_dict = {
"@class": "some_class",
"type": "some_type",
"layers": [
{
Expand Down Expand Up @@ -304,6 +309,14 @@ def test_save_load_model_yaml(self) -> None:
assert "software" in model
assert "version" in model

def test_node_display(self) -> None:
    """Check that ``str()`` of the deserialized ``self.model_dict`` matches a fixed tree."""
    # The backslash after the opening quotes suppresses the first newline, and
    # textwrap.dedent removes whitespace that is common to every line.
    # NOTE(review): the comparison is whitespace-sensitive - confirm that the
    # indentation of the nested line matches what Node.__str__ emits.
    disp_expected = textwrap.dedent("""\
some_class some_type (size=18)
└──layers -> ListNode (size=18)
└──0, 1 -> Node (size=9)""")
    disp = str(Node.deserialize(self.model_dict))
    self.assertEqual(disp, disp_expected)

def tearDown(self) -> None:
    """Delete the file at ``self.filename`` if it is still present."""
    leftover = self.filename
    if not os.path.exists(leftover):
        return
    os.remove(leftover)
Expand Down