Commit ba087c4
automatic batch size for dp test (#1165)
* automatic batch size for `dp test`. Resolves #1149. We start nbatch * natoms at 1024 (or a user-specified value) and iteratively multiply it by 2 until an OOM error is caught. A minor issue is that catching the TF OOM error is slow; that is a TF problem and I don't know how to resolve it, but luckily we only need to catch it once.
* replace `execuate` with `execute`
* add unittest; bugfix
1 parent 53f1567 commit ba087c4
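As a rough, self-contained sketch of the strategy described in the commit message (not the committed implementation; `find_max_batch`, `evaluate_batch`, and the use of plain `MemoryError` are assumptions for illustration), the search doubles the batch until the first OOM and keeps the last size that worked:

```python
def find_max_batch(evaluate_batch, natoms, start=1024, factor=2):
    """Double nbatch * natoms until the first OOM, return the last working size."""
    batch = start          # batch size measured in total atoms
    last_working = 0
    while batch < 2**31:   # hard cap, mirroring the 2**31 sentinel used in the commit
        try:
            evaluate_batch(max(batch // natoms, 1))   # number of frames per call
        except MemoryError:
            if last_working == 0:
                raise      # even the starting size does not fit
            return last_working
        last_working = batch
        batch *= factor
    return last_working
```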

File tree

7 files changed: +188 -6 lines


deepmd/entrypoints/test.py

Lines changed: 9 additions & 2 deletions
@@ -9,10 +9,11 @@
 from deepmd.utils import random as dp_random
 from deepmd.utils.data import DeepmdData
 from deepmd.utils.weight_avg import weighted_average
+from deepmd.utils.batch_size import AutoBatchSize

 if TYPE_CHECKING:
     from deepmd.infer import DeepDipole, DeepPolar, DeepPot, DeepWFC
-    from deepmd.infer.deep_eval import DeepTensor
+    from deepmd.infer.deep_tensor import DeepTensor

 __all__ = ["test"]

@@ -69,6 +70,7 @@ def test(

     # init model
     dp = DeepPotential(model)
+    auto_batch_size = AutoBatchSize()

     for cc, system in enumerate(all_sys):
         log.info("# ---------------output of dp test--------------- ")
@@ -82,6 +84,7 @@
             err = test_ener(
                 dp,
                 data,
+                auto_batch_size,
                 system,
                 numb_test,
                 detail_file,
@@ -159,6 +162,7 @@ def save_txt_file(
 def test_ener(
     dp: "DeepPot",
     data: DeepmdData,
+    auto_batch_size: AutoBatchSize,
     system: str,
     numb_test: int,
     detail_file: Optional[str],
@@ -226,7 +230,10 @@ def test_ener(
     else:
         aparam = None

-    ret = dp.eval(
+    ret = auto_batch_size.execute_all(
+        dp.eval,
+        numb_test,
+        natoms,
         coord,
         box,
         atype,
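For context, the call pattern introduced above can be exercised without a trained model. In this illustrative snippet `fake_eval` is a made-up stand-in for `dp.eval` and the array shapes are invented: `execute_all` slices any multi-dimensional NumPy argument along the first (batch) axis and passes everything else through unchanged.

```python
import numpy as np

from deepmd.utils.batch_size import AutoBatchSize

def fake_eval(coord, box, atype):
    # pretend "model evaluation": one scalar per frame
    return coord.sum(axis=1, keepdims=True)

nframes, natoms = 1000, 64
coord = np.random.rand(nframes, natoms * 3)   # sliced along axis 0
box = np.random.rand(nframes, 9)              # sliced along axis 0
atype = [0] * natoms                          # not an ndarray: passed through whole

auto_batch_size = AutoBatchSize()
ret = auto_batch_size.execute_all(fake_eval, nframes, natoms, coord, box, atype)
print(ret.shape)  # (1000, 1): the per-batch results are concatenated back together
```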

deepmd/infer/deep_pot.py

Lines changed: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def _eval_inner(
             feed_dict_test[self.t_fparam] = np.reshape(fparam, [-1])
         if self.has_aparam:
             feed_dict_test[self.t_aparam] = np.reshape(aparam, [-1])
-        v_out = self.sess.run (t_out, feed_dict = feed_dict_test)
+        v_out = run_sess(self.sess, t_out, feed_dict = feed_dict_test)
         energy = v_out[0]
         force = v_out[1]
         virial = v_out[2]

deepmd/utils/batch_size.py

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
+import logging
+from typing import Callable, Tuple
+
+import numpy as np
+
+from deepmd.utils.errors import OutOfMemoryError
+
+class AutoBatchSize:
+    """This class allows DeePMD-kit to automatically decide the maximum
+    batch size that will not cause an OOM error.
+
+    Notes
+    -----
+    We assume all OOM errors raise :class:`OutOfMemoryError`.
+
+    Parameters
+    ----------
+    initial_batch_size : int, default: 1024
+        initial batch size (number of total atoms)
+    factor : float, default: 2.
+        growth factor
+
+    Attributes
+    ----------
+    current_batch_size : int
+        current batch size (number of total atoms)
+    maximum_working_batch_size : int
+        maximum working batch size
+    minimal_not_working_batch_size : int
+        minimal not-working batch size
+    """
+    def __init__(self, initial_batch_size: int = 1024, factor: float = 2.) -> None:
+        # See also PyTorchLightning/pytorch-lightning#1638
+        # TODO: discuss a proper initial batch size
+        self.current_batch_size = initial_batch_size
+        self.maximum_working_batch_size = 0
+        self.minimal_not_working_batch_size = 2**31
+        self.factor = factor
+
+    def execute(self, callable: Callable, start_index: int, natoms: int) -> Tuple[int, tuple]:
+        """Execute a method with the current batch size.
+
+        Parameters
+        ----------
+        callable : Callable
+            The method should accept the batch size and start_index as parameters,
+            and return the executed batch size and data.
+        start_index : int
+            start index
+        natoms : int
+            number of atoms per frame
+
+        Returns
+        -------
+        int
+            batch size actually executed, as reported by the callable; 0 if the call ran out of memory
+        tuple
+            result from the callable; None if the call ran out of memory
+
+        Raises
+        ------
+        OutOfMemoryError
+            OOM even when the batch size is 1
+        """
+        try:
+            n_batch, result = callable(max(self.current_batch_size // natoms, 1), start_index)
+        except OutOfMemoryError as e:
+            # TODO: catching an OOM error is very slow; I don't know what TF is doing here,
+            # but luckily we only need to catch it once
+            self.minimal_not_working_batch_size = min(self.minimal_not_working_batch_size, self.current_batch_size)
+            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
+                self.maximum_working_batch_size = int(self.minimal_not_working_batch_size / self.factor)
+            if self.minimal_not_working_batch_size <= natoms:
+                raise OutOfMemoryError("The callable still throws an out-of-memory (OOM) error even when batch size is 1!") from e
+            # adjust the next batch size
+            self._adjust_batch_size(1. / self.factor)
+            return 0, None
+        else:
+            n_tot = n_batch * natoms
+            self.maximum_working_batch_size = max(self.maximum_working_batch_size, n_tot)
+            # adjust the next batch size
+            if n_tot >= self.current_batch_size and self.current_batch_size * self.factor < self.minimal_not_working_batch_size:
+                self._adjust_batch_size(self.factor)
+            return n_batch, result
+
+    def _adjust_batch_size(self, factor: float):
+        old_batch_size = self.current_batch_size
+        self.current_batch_size = int(self.current_batch_size * factor)
+        logging.info("Adjust batch size from %d to %d" % (old_batch_size, self.current_batch_size))
+
+    def execute_all(self, callable: Callable, total_size: int, natoms: int, *args, **kwargs) -> Tuple[np.ndarray]:
+        """Execute a method over all given data, splitting it into batches.
+
+        Parameters
+        ----------
+        callable : Callable
+            The method should accept *args and **kwargs as input and return arrays of matching batch size.
+        total_size : int
+            total size
+        natoms : int
+            the number of atoms per frame
+        *args, **kwargs
+            arguments passed to the callable; np.ndarray arguments with more than one dimension
+            are sliced along the first (batch) axis, everything else is passed through unchanged
+        """
+        def execute_with_batch_size(batch_size: int, start_index: int) -> Tuple[int, Tuple[np.ndarray]]:
+            end_index = start_index + batch_size
+            end_index = min(end_index, total_size)
+            return (end_index - start_index), callable(
+                *[(vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for vv in args],
+                **{kk: (vv[start_index:end_index] if isinstance(vv, np.ndarray) and vv.ndim > 1 else vv) for kk, vv in kwargs.items()},
+            )
+
+        index = 0
+        results = []
+        while index < total_size:
+            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
+            if not isinstance(result, tuple):
+                result = (result,)
+            index += n_batch
+            if n_batch:
+                for rr in result:
+                    rr.reshape((n_batch, -1))
+                results.append(result)
+
+        r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
+        if len(r) == 1:
+            # avoid returning a tuple if the callable doesn't return one
+            r = r[0]
+        return r
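A short usage sketch of the class added above, separate from the unit test further down; the 512-atom ceiling in `limited_eval` is an invented stand-in for real device memory, chosen only to make the growth and fallback visible:

```python
import numpy as np

from deepmd.utils.batch_size import AutoBatchSize
from deepmd.utils.errors import OutOfMemoryError

NATOMS = 2

def limited_eval(batch_size, start_index):
    # hypothetical evaluator that "runs out of memory" above 512 total atoms
    if batch_size * NATOMS > 512:
        raise OutOfMemoryError
    return batch_size, np.zeros((batch_size, 3))

auto_batch_size = AutoBatchSize(initial_batch_size=64, factor=2.)
for _ in range(6):
    n_frames, _ = auto_batch_size.execute(limited_eval, 0, NATOMS)
    print(n_frames, auto_batch_size.current_batch_size)
# The batch size doubles 64 -> 128 -> 256 -> 512 -> 1024, the 1024 attempt
# hits the ceiling (0 frames executed), and it settles back at 512.
```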

deepmd/utils/errors.py

Lines changed: 3 additions & 0 deletions
@@ -3,3 +3,6 @@ class GraphTooLargeError(Exception):

 class GraphWithoutTensorError(Exception):
     pass
+
+class OutOfMemoryError(Exception):
+    """This error is caused by out-of-memory (OOM)."""

deepmd/utils/sess.py

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 import os

 from deepmd.env import tf
+from deepmd.utils.errors import OutOfMemoryError


 def run_sess(sess: tf.Session, *args, **kwargs):
@@ -35,4 +36,4 @@ def run_sess(sess: tf.Session, *args, **kwargs):
             "variable (current value: %s).\n" % (
                 os.getenv("CUDA_VISIBLE_DEVICES", None),
             ))
-        raise RuntimeError(MESSAGE) from e
+        raise OutOfMemoryError(MESSAGE) from e
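Most of the body of `run_sess` lies outside this hunk. As a hedged sketch of the pattern it implements (not the actual function, with `run_sess_sketch` and the shortened message as placeholders), the idea is to translate TensorFlow's resource-exhausted failure into the package-level `OutOfMemoryError` that `AutoBatchSize` catches:

```python
from deepmd.env import tf

from deepmd.utils.errors import OutOfMemoryError

def run_sess_sketch(sess, *args, **kwargs):
    """Run a TF session, converting an OOM failure into OutOfMemoryError."""
    try:
        return sess.run(*args, **kwargs)
    except tf.errors.ResourceExhaustedError as e:
        raise OutOfMemoryError("your system is out of GPU/CPU memory") from e
```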

doc/troubleshooting/model-compatability.md

Lines changed: 2 additions & 2 deletions
@@ -4,13 +4,13 @@ When the version of DeePMD-kit used to training model is different from the that

 DeePMD-kit guarantees that the codes with the same major and minor revisions are compatible. That is to say v0.12.5 is compatible to v0.12.0, but is not compatible to v0.11.0 nor v1.0.0.

-One can execuate `dp convert-from` to convert an old model to a new one.
+One can execute `dp convert-from` to convert an old model to a new one.

 | Model version | v0.12 | v1.0 | v1.1 | v1.2 | v1.3 | v2.0 |
 |:-:|:-----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
 | Compatibility | 😢 | 😢 | 😢 | 😊 | 😊 | 😄 |

 **Legend**:
 - 😄: The model is compatible with the DeePMD-kit package.
-- 😊: The model is incompatible with the DeePMD-kit package, but one can execuate `dp convert-from` to convert an old model to v2.0.
+- 😊: The model is incompatible with the DeePMD-kit package, but one can execute `dp convert-from` to convert an old model to v2.0.
 - 😢: The model is incompatible with the DeePMD-kit package, and there is no way to convert models.
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import unittest
+
+import numpy as np
+
+from deepmd.utils.batch_size import AutoBatchSize
+from deepmd.utils.errors import OutOfMemoryError
+
+class TestAutoBatchSize(unittest.TestCase):
+    def oom(self, batch_size, start_index):
+        if batch_size >= 512:
+            raise OutOfMemoryError
+        return batch_size, np.zeros((batch_size, 2))
+
+    def test_execute_oom(self):
+        # initial batch size 256 = 128 * 2
+        auto_batch_size = AutoBatchSize(256, 2.)
+        # no error - 128
+        nb, result = auto_batch_size.execute(self.oom, 1, 2)
+        self.assertEqual(nb, 128)
+        self.assertEqual(result.shape, (128, 2))
+        # no error - 256
+        nb, result = auto_batch_size.execute(self.oom, 1, 2)
+        self.assertEqual(nb, 256)
+        self.assertEqual(result.shape, (256, 2))
+        # error - 512 returns 0, None
+        nb, result = auto_batch_size.execute(self.oom, 1, 2)
+        self.assertEqual(nb, 0)
+        self.assertIsNone(result)
+        # 256 again
+        nb, result = auto_batch_size.execute(self.oom, 1, 2)
+        self.assertEqual(nb, 256)
+        self.assertEqual(result.shape, (256, 2))
+        # 256 again
+        nb, result = auto_batch_size.execute(self.oom, 1, 2)
+        self.assertEqual(nb, 256)
+        self.assertEqual(result.shape, (256, 2))
+
+    def test_execute_all(self):
+        dd1 = np.zeros((10000, 2, 1))
+        auto_batch_size = AutoBatchSize(256, 2.)
+        dd2 = auto_batch_size.execute_all(np.array, 10000, 2, dd1)
+        np.testing.assert_equal(dd1, dd2)
