
Commit 8355947

Authored by cherryWangY, pre-commit-ci[bot], and njzjz
Add 4 pt descriptor compression (#4227)
se_a, se_atten (DPA1), se_t, se_r

## Summary by CodeRabbit

## Release Notes

- **New Features**
  - Introduced a model compression feature across multiple descriptor classes, enhancing performance and efficiency.
  - Added `enable_compression` methods to various classes, allowing users to enable and configure compression settings.
- **Bug Fixes**
  - Improved error handling for unsupported compression scenarios and parameter validation.
- **Tests**
  - Added comprehensive unit tests for new compression functionalities across multiple descriptor classes to ensure accuracy and reliability.
- **Documentation**
  - Enhanced documentation for new methods and classes to clarify usage and parameters related to compression.

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Signed-off-by: Yan Wang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jinzhe Zeng <[email protected]>
1 parent eb2832b commit 8355947

15 files changed: +2377 additions, -388 deletions

deepmd/dpmodel/descriptor/make_base_descriptor.py

Lines changed: 25 additions & 0 deletions
@@ -147,6 +147,31 @@ def compute_input_stats(
             """Update mean and stddev for descriptor elements."""
             raise NotImplementedError
 
+        def enable_compression(
+            self,
+            min_nbor_dist: float,
+            table_extrapolate: float = 5,
+            table_stride_1: float = 0.01,
+            table_stride_2: float = 0.1,
+            check_frequency: int = -1,
+        ) -> None:
+            """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
+
+            Parameters
+            ----------
+            min_nbor_dist
+                The nearest distance between atoms
+            table_extrapolate
+                The scale of model extrapolation
+            table_stride_1
+                The uniform stride of the first table
+            table_stride_2
+                The uniform stride of the second table
+            check_frequency
+                The overflow check frequency
+            """
+            raise NotImplementedError("This descriptor doesn't support compression!")
+
         @abstractmethod
         def fwd(
             self,
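For orientation, here is a minimal usage sketch of the new base-class hook (not part of the commit): the concrete descriptor, its constructor arguments, and the statistics values below are illustrative assumptions.

```python
# Hypothetical usage sketch: the descriptor class, constructor arguments, and the
# min_nbor_dist value are placeholders for illustration, not taken from this commit.
from deepmd.pt.model.descriptor.se_a import DescrptSeA

descriptor = DescrptSeA(rcut=6.0, rcut_smth=0.5, sel=[46, 92])
descriptor.enable_compression(
    min_nbor_dist=0.9,    # nearest atom-atom distance observed in the training data
    table_extrapolate=5,  # extend the tabulated range to 5x beyond the training statistics
    table_stride_1=0.01,  # fine grid stride of the first lookup table
    table_stride_2=0.1,   # coarse grid stride of the second lookup table
    check_frequency=-1,   # -1 disables the periodic overflow check
)
# Descriptors that do not override the hook keep the base-class behavior and raise
# NotImplementedError("This descriptor doesn't support compression!").
```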

deepmd/pt/model/descriptor/dpa1.py

Lines changed: 87 additions & 0 deletions
@@ -24,9 +24,15 @@
 from deepmd.pt.utils.env import (
     RESERVED_PRECISON_DICT,
 )
+from deepmd.pt.utils.tabulate import (
+    DPTabulate,
+)
 from deepmd.pt.utils.update_sel import (
     UpdateSel,
 )
+from deepmd.pt.utils.utils import (
+    ActivationFn,
+)
 from deepmd.utils.data_system import (
     DeepmdDataSystem,
 )
@@ -261,6 +267,8 @@ def __init__(
         if ln_eps is None:
             ln_eps = 1e-5
 
+        self.tebd_input_mode = tebd_input_mode
+
         del type, spin, attn_mask
         self.se_atten = DescrptBlockSeAtten(
             rcut,
@@ -293,6 +301,7 @@ def __init__(
         self.use_econf_tebd = use_econf_tebd
         self.use_tebd_bias = use_tebd_bias
         self.type_map = type_map
+        self.compress = False
         self.type_embedding = TypeEmbedNet(
             ntypes,
             tebd_dim,
@@ -551,6 +560,84 @@ def t_cvt(xx):
         )
         return obj
 
+    def enable_compression(
+        self,
+        min_nbor_dist: float,
+        table_extrapolate: float = 5,
+        table_stride_1: float = 0.01,
+        table_stride_2: float = 0.1,
+        check_frequency: int = -1,
+    ) -> None:
+        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
+
+        Parameters
+        ----------
+        min_nbor_dist
+            The nearest distance between atoms
+        table_extrapolate
+            The scale of model extrapolation
+        table_stride_1
+            The uniform stride of the first table
+        table_stride_2
+            The uniform stride of the second table
+        check_frequency
+            The overflow check frequency
+        """
+        # do some checks before the model compression process
+        if self.compress:
+            raise ValueError("Compression is already enabled.")
+        assert (
+            not self.se_atten.resnet_dt
+        ), "Model compression error: descriptor resnet_dt must be false!"
+        for tt in self.se_atten.exclude_types:
+            if (tt[0] not in range(self.se_atten.ntypes)) or (
+                tt[1] not in range(self.se_atten.ntypes)
+            ):
+                raise RuntimeError(
+                    "exclude types "
+                    + str(tt)
+                    + " must be within the number of atomic types "
+                    + str(self.se_atten.ntypes)
+                    + "!"
+                )
+        if (
+            self.se_atten.ntypes * self.se_atten.ntypes
+            - len(self.se_atten.exclude_types)
+            == 0
+        ):
+            raise RuntimeError(
+                "Empty embedding-nets are not supported in model compression!"
+            )
+
+        if self.se_atten.attn_layer != 0:
+            raise RuntimeError("Cannot compress model when attention layer is not 0.")
+
+        if self.tebd_input_mode != "strip":
+            raise RuntimeError("Cannot compress model when tebd_input_mode == 'concat'")
+
+        data = self.serialize()
+        self.table = DPTabulate(
+            self,
+            data["neuron"],
+            data["type_one_side"],
+            data["exclude_types"],
+            ActivationFn(data["activation_function"]),
+        )
+        self.table_config = [
+            table_extrapolate,
+            table_stride_1,
+            table_stride_2,
+            check_frequency,
+        ]
+        self.lower, self.upper = self.table.build(
+            min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
+        )
+
+        self.se_atten.enable_compression(
+            self.table.data, self.table_config, self.lower, self.upper
+        )
+        self.compress = True
+
     def forward(
         self,
         extended_coord: torch.Tensor,
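The DPA-1 path above delegates the actual table construction to DPTabulate, which is not shown in this diff. As a rough illustration of the underlying idea only (a tabulated stand-in for a 1-D embedding net, assuming simple linear interpolation rather than the piecewise-polynomial tables and fused kernels DeePMD-kit actually uses; `build_table`, `lookup`, and `net` are hypothetical names), a sketch might look like this:

```python
# Conceptual sketch only (assumed simplification, not DPTabulate): tabulate a 1-D
# embedding net on a uniform grid and replace its evaluation by interpolation.
import torch


def build_table(net, lower: float, upper: float, stride: float) -> torch.Tensor:
    """Sample net(x) on a uniform grid over [lower, upper]; one row per grid point."""
    xs = torch.arange(lower, upper, stride).unsqueeze(-1)  # (ngrid, 1)
    with torch.no_grad():
        return net(xs)  # (ngrid, ng)


def lookup(table: torch.Tensor, ss: torch.Tensor, lower: float, stride: float) -> torch.Tensor:
    """Piecewise-linear interpolation of the tabulated net at the inputs ss of shape (n, 1)."""
    pos = (ss.squeeze(-1) - lower) / stride  # fractional grid index
    idx = pos.floor().long().clamp(0, table.shape[0] - 2)
    frac = (pos - idx).unsqueeze(-1)
    return (1.0 - frac) * table[idx] + frac * table[idx + 1]  # (n, ng)


# The fused custom OP additionally folds in the matmul with the environment matrix,
# i.e. roughly gr = rr^T @ lookup(table, ss), performed per neighbor block in C++/CUDA.
```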

deepmd/pt/model/descriptor/se_a.py

Lines changed: 130 additions & 4 deletions
@@ -58,11 +58,34 @@
 from deepmd.pt.utils.exclude_mask import (
     PairExcludeMask,
 )
+from deepmd.pt.utils.tabulate import (
+    DPTabulate,
+)
+from deepmd.pt.utils.utils import (
+    ActivationFn,
+)
 
 from .base_descriptor import (
     BaseDescriptor,
 )
 
+if not hasattr(torch.ops.deepmd, "tabulate_fusion_se_a"):
+
+    def tabulate_fusion_se_a(
+        argument0,
+        argument1,
+        argument2,
+        argument3,
+        argument4,
+    ) -> list[torch.Tensor]:
+        raise NotImplementedError(
+            "tabulate_fusion_se_a is not available since customized PyTorch OP library is not built when freezing the model. "
+            "See documentation for model compression for details."
+        )
+
+    # Note: this hack cannot actually save a model that can be run using LAMMPS.
+    torch.ops.deepmd.tabulate_fusion_se_a = tabulate_fusion_se_a
+
 
 @BaseDescriptor.register("se_e2_a")
 @BaseDescriptor.register("se_a")
@@ -93,6 +116,7 @@ def __init__(
             raise NotImplementedError("old implementation of spin is not supported.")
         super().__init__()
         self.type_map = type_map
+        self.compress = False
         self.sea = DescrptBlockSeA(
             rcut,
             rcut_smth,
@@ -225,6 +249,53 @@ def reinit_exclude(
         """Update the type exclusions."""
         self.sea.reinit_exclude(exclude_types)
 
+    def enable_compression(
+        self,
+        min_nbor_dist: float,
+        table_extrapolate: float = 5,
+        table_stride_1: float = 0.01,
+        table_stride_2: float = 0.1,
+        check_frequency: int = -1,
+    ) -> None:
+        """Receive the statistics (distance, max_nbor_size and env_mat_range) of the training data.
+
+        Parameters
+        ----------
+        min_nbor_dist
+            The nearest distance between atoms
+        table_extrapolate
+            The scale of model extrapolation
+        table_stride_1
+            The uniform stride of the first table
+        table_stride_2
+            The uniform stride of the second table
+        check_frequency
+            The overflow check frequency
+        """
+        if self.compress:
+            raise ValueError("Compression is already enabled.")
+        data = self.serialize()
+        self.table = DPTabulate(
+            self,
+            data["neuron"],
+            data["type_one_side"],
+            data["exclude_types"],
+            ActivationFn(data["activation_function"]),
+        )
+        self.table_config = [
+            table_extrapolate,
+            table_stride_1,
+            table_stride_2,
+            check_frequency,
+        ]
+        self.lower, self.upper = self.table.build(
+            min_nbor_dist, table_extrapolate, table_stride_1, table_stride_2
+        )
+        self.sea.enable_compression(
+            self.table.data, self.table_config, self.lower, self.upper
+        )
+        self.compress = True
+
     def forward(
         self,
         coord_ext: torch.Tensor,
@@ -366,6 +437,10 @@ def update_sel(
 class DescrptBlockSeA(DescriptorBlock):
     ndescrpt: Final[int]
     __constants__: ClassVar[list] = ["ndescrpt"]
+    lower: dict[str, int]
+    upper: dict[str, int]
+    table_data: dict[str, torch.Tensor]
+    table_config: list[Union[int, float]]
 
     def __init__(
         self,
@@ -425,6 +500,13 @@ def __init__(
         self.register_buffer("mean", mean)
         self.register_buffer("stddev", stddev)
 
+        # add for compression
+        self.compress = False
+        self.lower = {}
+        self.upper = {}
+        self.table_data = {}
+        self.table_config = []
+
         ndim = 1 if self.type_one_side else 2
         filter_layers = NetworkCollection(
             ndim=ndim, ntypes=len(sel), network_type="embedding_network"
@@ -443,6 +525,7 @@
         self.filter_layers = filter_layers
         self.stats = None
         # set trainable
+        self.trainable = trainable
         for param in self.parameters():
             param.requires_grad = trainable
 
@@ -470,6 +553,10 @@ def get_dim_out(self) -> int:
         """Returns the output dimension."""
         return self.dim_out
 
+    def get_dim_rot_mat_1(self) -> int:
+        """Returns the first dimension of the rotation matrix. The rotation is of shape dim_1 x 3."""
+        return self.filter_neuron[-1]
+
     def get_dim_emb(self) -> int:
         """Returns the output dimension."""
         return self.neuron[-1]
@@ -578,6 +665,19 @@ def reinit_exclude(
         self.exclude_types = exclude_types
         self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)
 
+    def enable_compression(
+        self,
+        table_data,
+        table_config,
+        lower,
+        upper,
+    ) -> None:
+        self.compress = True
+        self.table_data = table_data
+        self.table_config = table_config
+        self.lower = lower
+        self.upper = upper
+
     def forward(
         self,
         nlist: torch.Tensor,
@@ -627,6 +727,7 @@ def forward(
         for embedding_idx, ll in enumerate(self.filter_layers.networks):
             if self.type_one_side:
                 ii = embedding_idx
+                ti = -1
                 # torch.jit is not happy with slice(None)
                 # ti_mask = torch.ones(nfnl, dtype=torch.bool, device=dmatrix.device)
                 # applying a mask seems to cause performance degradation
@@ -648,10 +749,35 @@ def forward(
             rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
             rr = rr * mm[:, :, None]
             ss = rr[:, :, :1]
-            # nfnl x nt x ng
-            gg = ll.forward(ss)
-            # nfnl x 4 x ng
-            gr = torch.matmul(rr.permute(0, 2, 1), gg)
+
+            if self.compress:
+                if self.type_one_side:
+                    net = "filter_-1_net_" + str(ii)
+                else:
+                    net = "filter_" + str(ti) + "_net_" + str(ii)
+                info = [
+                    self.lower[net],
+                    self.upper[net],
+                    self.upper[net] * self.table_config[0],
+                    self.table_config[1],
+                    self.table_config[2],
+                    self.table_config[3],
+                ]
+                ss = ss.reshape(-1, 1)  # xyz_scatter_tensor in tf
+                tensor_data = self.table_data[net].to(ss.device).to(dtype=self.prec)
+                gr = torch.ops.deepmd.tabulate_fusion_se_a(
+                    tensor_data.contiguous(),
+                    torch.tensor(info, dtype=self.prec, device="cpu").contiguous(),
+                    ss.contiguous(),
+                    rr.contiguous(),
+                    self.filter_neuron[-1],
+                )[0]
+            else:
+                # nfnl x nt x ng
+                gg = ll.forward(ss)
+                # nfnl x 4 x ng
+                gr = torch.matmul(rr.permute(0, 2, 1), gg)
+
             if ti_mask is not None:
                 xyz_scatter[ti_mask] += gr
             else:
0 commit comments

Comments
 (0)