Skip to content

Commit df04bd7

Browse files
reformat func for further merging with pt version (#2946)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6d973ef commit df04bd7

File tree

2 files changed

+52
-58
lines changed

2 files changed

+52
-58
lines changed

deepmd/utils/data_system.py

Lines changed: 45 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -353,28 +353,15 @@ def set_sys_probs(self, sys_probs=None, auto_prob_style: str = "prob_sys_size"):
353353
elif auto_prob_style == "prob_sys_size":
354354
probs = self.prob_nbatches
355355
elif auto_prob_style[:14] == "prob_sys_size;":
356-
probs = self._prob_sys_size_ext(auto_prob_style)
356+
probs = prob_sys_size_ext(
357+
auto_prob_style, self.get_nsystems(), self.nbatches
358+
)
357359
else:
358360
raise RuntimeError("Unknown auto prob style: " + auto_prob_style)
359361
else:
360-
probs = self._process_sys_probs(sys_probs)
362+
probs = process_sys_probs(sys_probs, self.nbatches)
361363
self.sys_probs = probs
362364

363-
def _get_sys_probs(self, sys_probs, auto_prob_style): # depreciated
364-
if sys_probs is None:
365-
if auto_prob_style == "prob_uniform":
366-
prob_v = 1.0 / float(self.nsystems)
367-
prob = [prob_v for ii in range(self.nsystems)]
368-
elif auto_prob_style == "prob_sys_size":
369-
prob = self.prob_nbatches
370-
elif auto_prob_style[:14] == "prob_sys_size;":
371-
prob = self._prob_sys_size_ext(auto_prob_style)
372-
else:
373-
raise RuntimeError("unknown style " + auto_prob_style)
374-
else:
375-
prob = self._process_sys_probs(sys_probs)
376-
return prob
377-
378365
def get_batch(self, sys_idx: Optional[int] = None) -> dict:
379366
# batch generation style altered by Ziyao Li:
380367
# one should specify the "sys_prob" and "auto_prob_style" params
@@ -623,42 +610,44 @@ def _check_type_map_consistency(self, type_map_list):
623610
ret = ii
624611
return ret
625612

626-
def _process_sys_probs(self, sys_probs):
627-
sys_probs = np.array(sys_probs)
628-
type_filter = sys_probs >= 0
629-
assigned_sum_prob = np.sum(type_filter * sys_probs)
630-
# 1e-8 is to handle floating point error; See #1917
631-
assert (
632-
assigned_sum_prob <= 1.0 + 1e-8
633-
), "the sum of assigned probability should be less than 1"
634-
rest_sum_prob = 1.0 - assigned_sum_prob
635-
if not np.isclose(rest_sum_prob, 0):
636-
rest_nbatch = (1 - type_filter) * self.nbatches
637-
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
638-
ret_prob = rest_prob + type_filter * sys_probs
639-
else:
640-
ret_prob = sys_probs
641-
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
642-
return ret_prob
643-
644-
def _prob_sys_size_ext(self, keywords):
645-
block_str = keywords.split(";")[1:]
646-
block_stt = []
647-
block_end = []
648-
block_weights = []
649-
for ii in block_str:
650-
stt = int(ii.split(":")[0])
651-
end = int(ii.split(":")[1])
652-
weight = float(ii.split(":")[2])
653-
assert weight >= 0, "the weight of a block should be no less than 0"
654-
block_stt.append(stt)
655-
block_end.append(end)
656-
block_weights.append(weight)
657-
nblocks = len(block_str)
658-
block_probs = np.array(block_weights) / np.sum(block_weights)
659-
sys_probs = np.zeros([self.get_nsystems()])
660-
for ii in range(nblocks):
661-
nbatch_block = self.nbatches[block_stt[ii] : block_end[ii]]
662-
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
663-
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
664-
return sys_probs
613+
614+
def process_sys_probs(sys_probs, nbatch):
615+
sys_probs = np.array(sys_probs)
616+
type_filter = sys_probs >= 0
617+
assigned_sum_prob = np.sum(type_filter * sys_probs)
618+
# 1e-8 is to handle floating point error; See #1917
619+
assert (
620+
assigned_sum_prob <= 1.0 + 1e-8
621+
), "the sum of assigned probability should be less than 1"
622+
rest_sum_prob = 1.0 - assigned_sum_prob
623+
if not np.isclose(rest_sum_prob, 0):
624+
rest_nbatch = (1 - type_filter) * nbatch
625+
rest_prob = rest_sum_prob * rest_nbatch / np.sum(rest_nbatch)
626+
ret_prob = rest_prob + type_filter * sys_probs
627+
else:
628+
ret_prob = sys_probs
629+
assert np.isclose(np.sum(ret_prob), 1), "sum of probs should be 1"
630+
return ret_prob
631+
632+
633+
def prob_sys_size_ext(keywords, nsystems, nbatch):
634+
block_str = keywords.split(";")[1:]
635+
block_stt = []
636+
block_end = []
637+
block_weights = []
638+
for ii in block_str:
639+
stt = int(ii.split(":")[0])
640+
end = int(ii.split(":")[1])
641+
weight = float(ii.split(":")[2])
642+
assert weight >= 0, "the weight of a block should be no less than 0"
643+
block_stt.append(stt)
644+
block_end.append(end)
645+
block_weights.append(weight)
646+
nblocks = len(block_str)
647+
block_probs = np.array(block_weights) / np.sum(block_weights)
648+
sys_probs = np.zeros([nsystems])
649+
for ii in range(nblocks):
650+
nbatch_block = nbatch[block_stt[ii] : block_end[ii]]
651+
tmp_prob = [float(i) for i in nbatch_block] / np.sum(nbatch_block)
652+
sys_probs[block_stt[ii] : block_end[ii]] = tmp_prob * block_probs[ii]
653+
return sys_probs

source/tests/test_deepmd_data_sys.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from deepmd.utils.data_system import (
1515
DeepmdDataSystem,
16+
prob_sys_size_ext,
1617
)
1718

1819
if GLOBAL_NP_FLOAT_PRECISION == np.float32:
@@ -310,7 +311,9 @@ def test_prob_sys_size_1(self):
310311
batch_size = 1
311312
test_size = 1
312313
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
313-
prob = ds._prob_sys_size_ext("prob_sys_size; 0:2:2; 2:4:8")
314+
prob = prob_sys_size_ext(
315+
"prob_sys_size; 0:2:2; 2:4:8", ds.get_nsystems(), ds.get_nbatches()
316+
)
314317
self.assertAlmostEqual(np.sum(prob), 1)
315318
self.assertAlmostEqual(np.sum(prob[0:2]), 0.2)
316319
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)
@@ -332,7 +335,9 @@ def test_prob_sys_size_2(self):
332335
batch_size = 1
333336
test_size = 1
334337
ds = DeepmdDataSystem(self.sys_name, batch_size, test_size, 2.0)
335-
prob = ds._prob_sys_size_ext("prob_sys_size; 1:2:0.4; 2:4:1.6")
338+
prob = prob_sys_size_ext(
339+
"prob_sys_size; 1:2:0.4; 2:4:1.6", ds.get_nsystems(), ds.get_nbatches()
340+
)
336341
self.assertAlmostEqual(np.sum(prob), 1)
337342
self.assertAlmostEqual(np.sum(prob[1:2]), 0.2)
338343
self.assertAlmostEqual(np.sum(prob[2:4]), 0.8)

0 commit comments

Comments
 (0)