Skip to content

Commit 6bc730f

Browse files
Enable Hybrid Descriptor to be compressed (#4297)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Introduced a new method `enable_compression` in the `DescrptHybrid` class, allowing users to configure compression settings related to neighbor distance and table parameters. - **Documentation** - Enhanced documentation for the `call` method, providing clearer descriptions of parameters and return values. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f129cff commit 6bc730f

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

deepmd/dpmodel/descriptor/hybrid.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,38 @@ def get_stat_mean_and_stddev(
210210
stddev_list.append(stddev_item)
211211
return mean_list, stddev_list
212212

213+
def enable_compression(
214+
self,
215+
min_nbor_dist: float,
216+
table_extrapolate: float = 5,
217+
table_stride_1: float = 0.01,
218+
table_stride_2: float = 0.1,
219+
check_frequency: int = -1,
220+
) -> None:
221+
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
222+
223+
Parameters
224+
----------
225+
min_nbor_dist
226+
The nearest distance between atoms
227+
table_extrapolate
228+
The scale of model extrapolation
229+
table_stride_1
230+
The uniform stride of the first table
231+
table_stride_2
232+
The uniform stride of the second table
233+
check_frequency
234+
The overflow check frequency
235+
"""
236+
for descrpt in self.descrpt_list:
237+
descrpt.enable_compression(
238+
min_nbor_dist,
239+
table_extrapolate,
240+
table_stride_1,
241+
table_stride_2,
242+
check_frequency,
243+
)
244+
213245
def call(
214246
self,
215247
coord_ext,

deepmd/pt/model/descriptor/hybrid.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,38 @@ def get_stat_mean_and_stddev(
224224
stddev_list.append(stddev_item)
225225
return mean_list, stddev_list
226226

227+
def enable_compression(
228+
self,
229+
min_nbor_dist: float,
230+
table_extrapolate: float = 5,
231+
table_stride_1: float = 0.01,
232+
table_stride_2: float = 0.1,
233+
check_frequency: int = -1,
234+
) -> None:
235+
"""Receive the statisitcs (distance, max_nbor_size and env_mat_range) of the training data.
236+
237+
Parameters
238+
----------
239+
min_nbor_dist
240+
The nearest distance between atoms
241+
table_extrapolate
242+
The scale of model extrapolation
243+
table_stride_1
244+
The uniform stride of the first table
245+
table_stride_2
246+
The uniform stride of the second table
247+
check_frequency
248+
The overflow check frequency
249+
"""
250+
for descrpt in self.descrpt_list:
251+
descrpt.enable_compression(
252+
min_nbor_dist,
253+
table_extrapolate,
254+
table_stride_1,
255+
table_stride_2,
256+
check_frequency,
257+
)
258+
227259
def forward(
228260
self,
229261
coord_ext: torch.Tensor,

0 commit comments

Comments
 (0)