Skip to content

Commit bfbe2ed

Browse files
Add compression API to BaseModel and AtomicModel (#4298)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced `enable_compression` method across multiple classes to allow configuration of compression settings for descriptors. - Enhanced robustness of output definitions and serialization processes in the `DPAtomicModel` class. - Added `enable_compression` method to `LinearEnergyAtomicModel` for improved model compression capabilities. - **Bug Fixes** - Improved error handling in the `fitting_output_def` method to ensure fallback functionality when the fitting network is unavailable. These updates enhance the functionality and reliability of the model management features. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 25bb821 commit bfbe2ed

File tree

8 files changed

+229
-0
lines changed

8 files changed

+229
-0
lines changed

deepmd/dpmodel/atomic_model/dp_atomic_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,37 @@ def need_sorted_nlist_for_lower(self) -> bool:
8686
"""Returns whether the atomic model needs sorted nlist when using `forward_lower`."""
8787
return self.descriptor.need_sorted_nlist_for_lower()
8888

89+
def enable_compression(
90+
self,
91+
min_nbor_dist: float,
92+
table_extrapolate: float = 5,
93+
table_stride_1: float = 0.01,
94+
table_stride_2: float = 0.1,
95+
check_frequency: int = -1,
96+
) -> None:
97+
"""Call descriptor enable_compression().
98+
99+
Parameters
100+
----------
101+
min_nbor_dist
102+
The nearest distance between atoms
103+
table_extrapolate
104+
The scale of model extrapolation
105+
table_stride_1
106+
The uniform stride of the first table
107+
table_stride_2
108+
The uniform stride of the second table
109+
check_frequency
110+
The overflow check frequency
111+
"""
112+
self.descriptor.enable_compression(
113+
min_nbor_dist,
114+
table_extrapolate,
115+
table_stride_1,
116+
table_stride_2,
117+
check_frequency,
118+
)
119+
89120
def forward_atomic(
90121
self,
91122
extended_coord: np.ndarray,

deepmd/dpmodel/atomic_model/linear_atomic_model.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,38 @@ def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
149149
)
150150
return [p[0] for p in zipped], [p[1] for p in zipped]
151151

152+
def enable_compression(
153+
self,
154+
min_nbor_dist: float,
155+
table_extrapolate: float = 5,
156+
table_stride_1: float = 0.01,
157+
table_stride_2: float = 0.1,
158+
check_frequency: int = -1,
159+
) -> None:
160+
"""Compress model.
161+
162+
Parameters
163+
----------
164+
min_nbor_dist
165+
The nearest distance between atoms
166+
table_extrapolate
167+
The scale of model extrapolation
168+
table_stride_1
169+
The uniform stride of the first table
170+
table_stride_2
171+
The uniform stride of the second table
172+
check_frequency
173+
The overflow check frequency
174+
"""
175+
for model in self.models:
176+
model.enable_compression(
177+
min_nbor_dist,
178+
table_extrapolate,
179+
table_stride_1,
180+
table_stride_2,
181+
check_frequency,
182+
)
183+
152184
def forward_atomic(
153185
self,
154186
extended_coord,

deepmd/dpmodel/atomic_model/make_base_atomic_model.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,31 @@ def change_type_map(
148148
) -> None:
149149
pass
150150

151+
def enable_compression(
152+
self,
153+
min_nbor_dist: float,
154+
table_extrapolate: float = 5,
155+
table_stride_1: float = 0.01,
156+
table_stride_2: float = 0.1,
157+
check_frequency: int = -1,
158+
) -> None:
159+
"""Call descriptor enable_compression().
160+
161+
Parameters
162+
----------
163+
min_nbor_dist
164+
The nearest distance between atoms
165+
table_extrapolate
166+
The scale of model extrapolation
167+
table_stride_1
168+
The uniform stride of the first table
169+
table_stride_2
170+
The uniform stride of the second table
171+
check_frequency
172+
The overflow check frequency
173+
"""
174+
raise NotImplementedError("This atomi model doesn't support compression!")
175+
151176
def make_atom_mask(
152177
self,
153178
atype: t_tensor,

deepmd/dpmodel/model/base_model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,28 @@ def update_sel(
191191
cls = cls.get_class_by_type(model_type)
192192
return cls.update_sel(train_data, type_map, local_jdata)
193193

194+
def enable_compression(
195+
self,
196+
table_extrapolate: float = 5,
197+
table_stride_1: float = 0.01,
198+
table_stride_2: float = 0.1,
199+
check_frequency: int = -1,
200+
) -> None:
201+
"""Enable model compression by tabulation.
202+
203+
Parameters
204+
----------
205+
table_extrapolate
206+
The scale of model extrapolation
207+
table_stride_1
208+
The uniform stride of the first table
209+
table_stride_2
210+
The uniform stride of the second table
211+
check_frequency
212+
The overflow check frequency
213+
"""
214+
raise NotImplementedError("This atomic model doesn't support compression!")
215+
194216
@classmethod
195217
def get_model(cls, model_params: dict) -> "BaseBaseModel":
196218
"""Get the model by the parameters.

deepmd/dpmodel/model/make_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,34 @@ def model_output_type(self) -> list[str]:
186186
]
187187
return vars
188188

189+
def enable_compression(
190+
self,
191+
table_extrapolate: float = 5,
192+
table_stride_1: float = 0.01,
193+
table_stride_2: float = 0.1,
194+
check_frequency: int = -1,
195+
) -> None:
196+
"""Call atomic_model enable_compression().
197+
198+
Parameters
199+
----------
200+
table_extrapolate
201+
The scale of model extrapolation
202+
table_stride_1
203+
The uniform stride of the first table
204+
table_stride_2
205+
The uniform stride of the second table
206+
check_frequency
207+
The overflow check frequency
208+
"""
209+
self.atomic_model.enable_compression(
210+
self.get_min_nbor_dist(),
211+
table_extrapolate,
212+
table_stride_1,
213+
table_stride_2,
214+
check_frequency,
215+
)
216+
189217
def call(
190218
self,
191219
coord,

deepmd/pt/model/atomic_model/dp_atomic_model.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,37 @@ def deserialize(cls, data) -> "DPAtomicModel":
160160
obj = super().deserialize(data)
161161
return obj
162162

163+
def enable_compression(
164+
self,
165+
min_nbor_dist: float,
166+
table_extrapolate: float = 5,
167+
table_stride_1: float = 0.01,
168+
table_stride_2: float = 0.1,
169+
check_frequency: int = -1,
170+
) -> None:
171+
"""Call descriptor enable_compression().
172+
173+
Parameters
174+
----------
175+
min_nbor_dist
176+
The nearest distance between atoms
177+
table_extrapolate
178+
The scale of model extrapolation
179+
table_stride_1
180+
The uniform stride of the first table
181+
table_stride_2
182+
The uniform stride of the second table
183+
check_frequency
184+
The overflow check frequency
185+
"""
186+
self.descriptor.enable_compression(
187+
min_nbor_dist,
188+
table_extrapolate,
189+
table_stride_1,
190+
table_stride_2,
191+
check_frequency,
192+
)
193+
163194
def forward_atomic(
164195
self,
165196
extended_coord,

deepmd/pt/model/atomic_model/linear_atomic_model.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,38 @@ def _sort_rcuts_sels(self) -> tuple[list[float], list[int]]:
184184
sorted_sels: list[int] = outer_sorted[:, 1].to(torch.int64).tolist()
185185
return sorted_rcuts, sorted_sels
186186

187+
def enable_compression(
188+
self,
189+
min_nbor_dist: float,
190+
table_extrapolate: float = 5,
191+
table_stride_1: float = 0.01,
192+
table_stride_2: float = 0.1,
193+
check_frequency: int = -1,
194+
) -> None:
195+
"""Compress model.
196+
197+
Parameters
198+
----------
199+
min_nbor_dist
200+
The nearest distance between atoms
201+
table_extrapolate
202+
The scale of model extrapolation
203+
table_stride_1
204+
The uniform stride of the first table
205+
table_stride_2
206+
The uniform stride of the second table
207+
check_frequency
208+
The overflow check frequency
209+
"""
210+
for model in self.models:
211+
model.enable_compression(
212+
min_nbor_dist,
213+
table_extrapolate,
214+
table_stride_1,
215+
table_stride_2,
216+
check_frequency,
217+
)
218+
187219
def forward_atomic(
188220
self,
189221
extended_coord: torch.Tensor,

deepmd/pt/model/model/make_model.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,34 @@ def model_output_type(self) -> list[str]:
9898
vars.append(kk)
9999
return vars
100100

101+
def enable_compression(
102+
self,
103+
table_extrapolate: float = 5,
104+
table_stride_1: float = 0.01,
105+
table_stride_2: float = 0.1,
106+
check_frequency: int = -1,
107+
) -> None:
108+
"""Call atomic_model enable_compression().
109+
110+
Parameters
111+
----------
112+
table_extrapolate
113+
The scale of model extrapolation
114+
table_stride_1
115+
The uniform stride of the first table
116+
table_stride_2
117+
The uniform stride of the second table
118+
check_frequency
119+
The overflow check frequency
120+
"""
121+
self.atomic_model.enable_compression(
122+
self.get_min_nbor_dist(),
123+
table_extrapolate,
124+
table_stride_1,
125+
table_stride_2,
126+
check_frequency,
127+
)
128+
101129
# cannot use the name forward. torch script does not work
102130
def forward_common(
103131
self,

0 commit comments

Comments
 (0)