Commit 539e4ab

add cross-platform AutoBatchSize (#3143)
See #3118 and dptech-corp/deepmd-pytorch#137. Subclasses need to implement `is_gpu_available` and `is_oom_error`. Signed-off-by: Jinzhe Zeng <[email protected]>
1 parent 2096b80 commit 539e4ab
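
The refactor follows a template-method pattern: the batch-size search itself now lives in the shared `deepmd_utils` base class, and each backend supplies only the two hooks named above. As a hedged illustration of that contract, a non-TensorFlow backend's subclass could look roughly like this (the PyTorch calls are illustrative assumptions, not the actual code from dptech-corp/deepmd-pytorch#137):

```python
# Illustrative sketch only: how a PyTorch backend might implement the two
# hooks required by the shared base class; NOT the actual deepmd-pytorch code.
import torch

from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase


class AutoBatchSize(AutoBatchSizeBase):
    def is_gpu_available(self) -> bool:
        # report whether a CUDA device is visible to this backend
        return torch.cuda.is_available()

    def is_oom_error(self, e: Exception) -> bool:
        # PyTorch signals CUDA OOM as a RuntimeError carrying this message
        return isinstance(e, RuntimeError) and "CUDA out of memory" in str(e)
```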

File tree

2 files changed: +250 −191

deepmd/utils/batch_size.py

Lines changed: 17 additions & 191 deletions
@@ -1,12 +1,4 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
-import logging
-import os
-from typing import (
-    Callable,
-    Tuple,
-)
-
-import numpy as np
 from packaging.version import (
     Version,
 )
@@ -18,197 +10,31 @@
 from deepmd.utils.errors import (
     OutOfMemoryError,
 )
+from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase
 
-log = logging.getLogger(__name__)
-
-
-class AutoBatchSize:
-    """This class allows DeePMD-kit to automatically decide the maximum
-    batch size that will not cause an OOM error.
-
-    Notes
-    -----
-    In some CPU environments, the program may be directly killed when OOM. In
-    this case, by default the batch size will not be increased for CPUs. The
-    environment variable `DP_INFER_BATCH_SIZE` can be set as the batch size.
-
-    In other cases, we assume all OOM error will raise :class:`OutOfMemoryError`.
-
-    Parameters
-    ----------
-    initial_batch_size : int, default: 1024
-        initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE
-        is not set
-    factor : float, default: 2.
-        increased factor
-
-    Attributes
-    ----------
-    current_batch_size : int
-        current batch size (number of total atoms)
-    maximum_working_batch_size : int
-        maximum working batch size
-    minimal_not_working_batch_size : int
-        minimal not working batch size
-    """
-
-    def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None:
-        # See also PyTorchLightning/pytorch-lightning#1638
-        # TODO: discuss a proper initial batch size
-        self.current_batch_size = initial_batch_size
-        DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0))
-        if DP_INFER_BATCH_SIZE > 0:
-            self.current_batch_size = DP_INFER_BATCH_SIZE
-            self.maximum_working_batch_size = DP_INFER_BATCH_SIZE
-            self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1
-        else:
-            self.maximum_working_batch_size = initial_batch_size
-            if (
-                Version(TF_VERSION) >= Version("1.14")
-                and tf.config.experimental.get_visible_devices("GPU")
-            ) or tf.test.is_gpu_available():
-                self.minimal_not_working_batch_size = 2**31
-            else:
-                self.minimal_not_working_batch_size = (
-                    self.maximum_working_batch_size + 1
-                )
-                log.warning(
-                    "You can use the environment variable DP_INFER_BATCH_SIZE to"
-                    "control the inference batch size (nframes * natoms). "
-                    "The default value is %d." % initial_batch_size
-                )
 
-        self.factor = factor
-
-    def execute(
-        self, callable: Callable, start_index: int, natoms: int
-    ) -> Tuple[int, tuple]:
-        """Excuate a method with given batch size.
-
-        Parameters
-        ----------
-        callable : Callable
-            The method should accept the batch size and start_index as parameters,
-            and returns executed batch size and data.
-        start_index : int
-            start index
-        natoms : int
-            natoms
+class AutoBatchSize(AutoBatchSizeBase):
+    def is_gpu_available(self) -> bool:
+        """Check if GPU is available.
 
         Returns
         -------
-        int
-            executed batch size * number of atoms
-        tuple
-            result from callable, None if failing to execute
-
-        Raises
-        ------
-        OutOfMemoryError
-            OOM when batch size is 1
+        bool
+            True if GPU is available
         """
-        if natoms > 0:
-            batch_nframes = self.current_batch_size // natoms
-        else:
-            batch_nframes = self.current_batch_size
-        try:
-            n_batch, result = callable(max(batch_nframes, 1), start_index)
-        except OutOfMemoryError as e:
-            # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
-            # but luckily we only need to catch once
-            self.minimal_not_working_batch_size = min(
-                self.minimal_not_working_batch_size, self.current_batch_size
-            )
-            if self.maximum_working_batch_size >= self.minimal_not_working_batch_size:
-                self.maximum_working_batch_size = int(
-                    self.minimal_not_working_batch_size / self.factor
-                )
-            if self.minimal_not_working_batch_size <= natoms:
-                raise OutOfMemoryError(
-                    "The callable still throws an out-of-memory (OOM) error even when batch size is 1!"
-                ) from e
-            # adjust the next batch size
-            self._adjust_batch_size(1.0 / self.factor)
-            return 0, None
-        else:
-            n_tot = n_batch * natoms
-            self.maximum_working_batch_size = max(
-                self.maximum_working_batch_size, n_tot
-            )
-            # adjust the next batch size
-            if (
-                n_tot + natoms > self.current_batch_size
-                and self.current_batch_size * self.factor
-                < self.minimal_not_working_batch_size
-            ):
-                self._adjust_batch_size(self.factor)
-            return n_batch, result
+        return (
+            Version(TF_VERSION) >= Version("1.14")
+            and tf.config.experimental.get_visible_devices("GPU")
+        ) or tf.test.is_gpu_available()
 
-    def _adjust_batch_size(self, factor: float):
-        old_batch_size = self.current_batch_size
-        self.current_batch_size = int(self.current_batch_size * factor)
-        log.info(
-            "Adjust batch size from %d to %d"
-            % (old_batch_size, self.current_batch_size)
-        )
-
-    def execute_all(
-        self, callable: Callable, total_size: int, natoms: int, *args, **kwargs
-    ) -> Tuple[np.ndarray]:
-        """Excuate a method with all given data.
+    def is_oom_error(self, e: Exception) -> bool:
+        """Check if the exception is an OOM error.
 
         Parameters
         ----------
-        callable : Callable
-            The method should accept *args and **kwargs as input and return the similiar array.
-        total_size : int
-            Total size
-        natoms : int
-            The number of atoms
-        *args
-            Variable length argument list.
-        **kwargs
-            If 2D np.ndarray, assume the first axis is batch; otherwise do nothing.
+        e : Exception
+            Exception
         """
-
-        def execute_with_batch_size(
-            batch_size: int, start_index: int
-        ) -> Tuple[int, Tuple[np.ndarray]]:
-            end_index = start_index + batch_size
-            end_index = min(end_index, total_size)
-            return (end_index - start_index), callable(
-                *[
-                    (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for vv in args
-                ],
-                **{
-                    kk: (
-                        vv[start_index:end_index]
-                        if isinstance(vv, np.ndarray) and vv.ndim > 1
-                        else vv
-                    )
-                    for kk, vv in kwargs.items()
-                },
-            )
-
-        index = 0
-        results = []
-        while index < total_size:
-            n_batch, result = self.execute(execute_with_batch_size, index, natoms)
-            if not isinstance(result, tuple):
-                result = (result,)
-            index += n_batch
-            if n_batch:
-                for rr in result:
-                    rr.reshape((n_batch, -1))
-                results.append(result)
-
-        r = tuple([np.concatenate(r, axis=0) for r in zip(*results)])
-        if len(r) == 1:
-            # avoid returning tuple if callable doesn't return tuple
-            r = r[0]
-        return r
+        # TODO: it's very slow to catch OOM error; I don't know what TF is doing here
+        # but luckily we only need to catch once
+        return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError))
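
For context, the deleted `execute` and `execute_all` methods were moved into the shared base class rather than dropped, so the TensorFlow subclass above keeps the same public API. A minimal usage sketch, assuming a working TensorFlow build; the `evaluate` function and the array shapes are invented for illustration:

```python
# Hedged usage sketch of the batch-size search; only the AutoBatchSize API
# comes from this codebase, the model stand-in below is invented.
import numpy as np

from deepmd.utils.batch_size import AutoBatchSize

nframes, natoms = 100, 192
coords = np.random.rand(nframes, natoms * 3)  # first axis is the batch axis


def evaluate(coords_batch: np.ndarray) -> np.ndarray:
    # stand-in for a real model call; returns one row per input frame
    return coords_batch.sum(axis=1, keepdims=True)


auto_batch = AutoBatchSize(initial_batch_size=1024)
# 2D ndarray arguments are sliced along the first axis; on an OOM error the
# batch size is divided by `factor`, and it grows by `factor` while batches succeed
energies = auto_batch.execute_all(evaluate, nframes, natoms, coords)
assert energies.shape == (nframes, 1)
```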
