1 | 1 | # SPDX-License-Identifier: LGPL-3.0-or-later |
2 | | -import logging |
3 | | -import os |
4 | | -from typing import ( |
5 | | - Callable, |
6 | | - Tuple, |
7 | | -) |
8 | | - |
9 | | -import numpy as np |
10 | 2 | from packaging.version import ( |
11 | 3 | Version, |
12 | 4 | ) |
18 | 10 | from deepmd.utils.errors import ( |
19 | 11 | OutOfMemoryError, |
20 | 12 | ) |
| 13 | +from deepmd_utils.utils.batch_size import AutoBatchSize as AutoBatchSizeBase |
21 | 14 |
22 | | -log = logging.getLogger(__name__) |
23 | | - |
24 | | - |
25 | | -class AutoBatchSize: |
26 | | - """This class allows DeePMD-kit to automatically decide the maximum |
27 | | - batch size that will not cause an OOM error. |
28 | | - |
29 | | - Notes |
30 | | - ----- |
31 | | - In some CPU environments, the program may be directly killed when OOM. In |
32 | | - this case, by default the batch size will not be increased for CPUs. The |
33 | | - environment variable `DP_INFER_BATCH_SIZE` can be set as the batch size. |
34 | | - |
35 | | - In other cases, we assume all OOM error will raise :class:`OutOfMemoryError`. |
36 | | - |
37 | | - Parameters |
38 | | - ---------- |
39 | | - initial_batch_size : int, default: 1024 |
40 | | - initial batch size (number of total atoms) when DP_INFER_BATCH_SIZE |
41 | | - is not set |
42 | | - factor : float, default: 2. |
43 | | - increased factor |
44 | | - |
45 | | - Attributes |
46 | | - ---------- |
47 | | - current_batch_size : int |
48 | | - current batch size (number of total atoms) |
49 | | - maximum_working_batch_size : int |
50 | | - maximum working batch size |
51 | | - minimal_not_working_batch_size : int |
52 | | - minimal not working batch size |
53 | | - """ |
54 | | - |
55 | | - def __init__(self, initial_batch_size: int = 1024, factor: float = 2.0) -> None: |
56 | | - # See also PyTorchLightning/pytorch-lightning#1638 |
57 | | - # TODO: discuss a proper initial batch size |
58 | | - self.current_batch_size = initial_batch_size |
59 | | - DP_INFER_BATCH_SIZE = int(os.environ.get("DP_INFER_BATCH_SIZE", 0)) |
60 | | - if DP_INFER_BATCH_SIZE > 0: |
61 | | - self.current_batch_size = DP_INFER_BATCH_SIZE |
62 | | - self.maximum_working_batch_size = DP_INFER_BATCH_SIZE |
63 | | - self.minimal_not_working_batch_size = self.maximum_working_batch_size + 1 |
64 | | - else: |
65 | | - self.maximum_working_batch_size = initial_batch_size |
66 | | - if ( |
67 | | - Version(TF_VERSION) >= Version("1.14") |
68 | | - and tf.config.experimental.get_visible_devices("GPU") |
69 | | - ) or tf.test.is_gpu_available(): |
70 | | - self.minimal_not_working_batch_size = 2**31 |
71 | | - else: |
72 | | - self.minimal_not_working_batch_size = ( |
73 | | - self.maximum_working_batch_size + 1 |
74 | | - ) |
75 | | - log.warning( |
76 | | - "You can use the environment variable DP_INFER_BATCH_SIZE to" |
77 | | - "control the inference batch size (nframes * natoms). " |
78 | | - "The default value is %d." % initial_batch_size |
79 | | - ) |
80 | 15 |
81 | | - self.factor = factor |
82 | | - |
83 | | - def execute( |
84 | | - self, callable: Callable, start_index: int, natoms: int |
85 | | - ) -> Tuple[int, tuple]: |
86 | | - """Excuate a method with given batch size. |
87 | | - |
88 | | - Parameters |
89 | | - ---------- |
90 | | - callable : Callable |
91 | | - The method should accept the batch size and start_index as parameters, |
92 | | - and returns executed batch size and data. |
93 | | - start_index : int |
94 | | - start index |
95 | | - natoms : int |
96 | | - natoms |
| 16 | +class AutoBatchSize(AutoBatchSizeBase): |
| 17 | + def is_gpu_available(self) -> bool: |
| 18 | + """Check if GPU is available. |
97 | 19 |
98 | 20 | Returns |
99 | 21 | ------- |
100 | | - int |
101 | | - executed batch size * number of atoms |
102 | | - tuple |
103 | | - result from callable, None if failing to execute |
104 | | - |
105 | | - Raises |
106 | | - ------ |
107 | | - OutOfMemoryError |
108 | | - OOM when batch size is 1 |
| 22 | + bool |
| 23 | + True if GPU is available |
109 | 24 | """ |
110 | | - if natoms > 0: |
111 | | - batch_nframes = self.current_batch_size // natoms |
112 | | - else: |
113 | | - batch_nframes = self.current_batch_size |
114 | | - try: |
115 | | - n_batch, result = callable(max(batch_nframes, 1), start_index) |
116 | | - except OutOfMemoryError as e: |
117 | | - # TODO: it's very slow to catch OOM error; I don't know what TF is doing here |
118 | | - # but luckily we only need to catch once |
119 | | - self.minimal_not_working_batch_size = min( |
120 | | - self.minimal_not_working_batch_size, self.current_batch_size |
121 | | - ) |
122 | | - if self.maximum_working_batch_size >= self.minimal_not_working_batch_size: |
123 | | - self.maximum_working_batch_size = int( |
124 | | - self.minimal_not_working_batch_size / self.factor |
125 | | - ) |
126 | | - if self.minimal_not_working_batch_size <= natoms: |
127 | | - raise OutOfMemoryError( |
128 | | - "The callable still throws an out-of-memory (OOM) error even when batch size is 1!" |
129 | | - ) from e |
130 | | - # adjust the next batch size |
131 | | - self._adjust_batch_size(1.0 / self.factor) |
132 | | - return 0, None |
133 | | - else: |
134 | | - n_tot = n_batch * natoms |
135 | | - self.maximum_working_batch_size = max( |
136 | | - self.maximum_working_batch_size, n_tot |
137 | | - ) |
138 | | - # adjust the next batch size |
139 | | - if ( |
140 | | - n_tot + natoms > self.current_batch_size |
141 | | - and self.current_batch_size * self.factor |
142 | | - < self.minimal_not_working_batch_size |
143 | | - ): |
144 | | - self._adjust_batch_size(self.factor) |
145 | | - return n_batch, result |
| 25 | + return ( |
| 26 | + Version(TF_VERSION) >= Version("1.14") |
| 27 | + and tf.config.experimental.get_visible_devices("GPU") |
| 28 | + ) or tf.test.is_gpu_available() |
146 | 29 |
147 | | - def _adjust_batch_size(self, factor: float): |
148 | | - old_batch_size = self.current_batch_size |
149 | | - self.current_batch_size = int(self.current_batch_size * factor) |
150 | | - log.info( |
151 | | - "Adjust batch size from %d to %d" |
152 | | - % (old_batch_size, self.current_batch_size) |
153 | | - ) |
154 | | - |
155 | | - def execute_all( |
156 | | - self, callable: Callable, total_size: int, natoms: int, *args, **kwargs |
157 | | - ) -> Tuple[np.ndarray]: |
158 | | - """Excuate a method with all given data. |
| 30 | + def is_oom_error(self, e: Exception) -> bool: |
| 31 | + """Check if the exception is an OOM error. |
159 | 32 |
160 | 33 | Parameters |
161 | 34 | ---------- |
162 | | - callable : Callable |
163 | | - The method should accept *args and **kwargs as input and return the similar array. |
164 | | - total_size : int |
165 | | - Total size |
166 | | - natoms : int |
167 | | - The number of atoms |
168 | | - *args |
169 | | - Variable length argument list. |
170 | | - **kwargs |
171 | | - If 2D np.ndarray, assume the first axis is batch; otherwise do nothing. |
| 35 | + e : Exception |
| 36 | + Exception |
172 | 37 | """ |
173 | | - |
174 | | - def execute_with_batch_size( |
175 | | - batch_size: int, start_index: int |
176 | | - ) -> Tuple[int, Tuple[np.ndarray]]: |
177 | | - end_index = start_index + batch_size |
178 | | - end_index = min(end_index, total_size) |
179 | | - return (end_index - start_index), callable( |
180 | | - *[ |
181 | | - ( |
182 | | - vv[start_index:end_index] |
183 | | - if isinstance(vv, np.ndarray) and vv.ndim > 1 |
184 | | - else vv |
185 | | - ) |
186 | | - for vv in args |
187 | | - ], |
188 | | - **{ |
189 | | - kk: ( |
190 | | - vv[start_index:end_index] |
191 | | - if isinstance(vv, np.ndarray) and vv.ndim > 1 |
192 | | - else vv |
193 | | - ) |
194 | | - for kk, vv in kwargs.items() |
195 | | - }, |
196 | | - ) |
197 | | - |
198 | | - index = 0 |
199 | | - results = [] |
200 | | - while index < total_size: |
201 | | - n_batch, result = self.execute(execute_with_batch_size, index, natoms) |
202 | | - if not isinstance(result, tuple): |
203 | | - result = (result,) |
204 | | - index += n_batch |
205 | | - if n_batch: |
206 | | - for rr in result: |
207 | | - rr.reshape((n_batch, -1)) |
208 | | - results.append(result) |
209 | | - |
210 | | - r = tuple([np.concatenate(r, axis=0) for r in zip(*results)]) |
211 | | - if len(r) == 1: |
212 | | - # avoid returning tuple if callable doesn't return tuple |
213 | | - r = r[0] |
214 | | - return r |
| 38 | + # TODO: it's very slow to catch OOM error; I don't know what TF is doing here |
| 39 | + # but luckily we only need to catch once |
| 40 | + return isinstance(e, (tf.errors.ResourceExhaustedError, OutOfMemoryError)) |
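For context beyond the diff itself: the generic batching logic removed here (`execute`, `execute_all`, the `DP_INFER_BATCH_SIZE` handling) now lives in the backend-independent `AutoBatchSizeBase`, and this module keeps only the TensorFlow-specific `is_gpu_available` and `is_oom_error` hooks. A minimal usage sketch, assuming the subclass stays importable as `deepmd.utils.batch_size.AutoBatchSize` and that the base class keeps the `execute_all(callable, total_size, natoms, *args, **kwargs)` signature shown in the removed code; the `evaluate` callable and array shapes are illustrative only:

```python
import numpy as np

from deepmd.utils.batch_size import AutoBatchSize  # the subclass defined in this diff


def evaluate(coords: np.ndarray) -> np.ndarray:
    # Stand-in for a TF inference call; returns one value per frame.
    return coords.sum(axis=1)


nframes, natoms = 100, 192
coords = np.random.rand(nframes, natoms * 3)  # 2D: the first axis is the batch axis

auto_batch = AutoBatchSize(initial_batch_size=1024)
# execute_all (inherited from AutoBatchSizeBase) slices 2D arrays along the
# first axis, grows the batch while it fits in memory, and consults
# is_oom_error() on raised exceptions to decide whether to shrink and retry.
energies = auto_batch.execute_all(evaluate, nframes, natoms, coords)
assert energies.shape == (nframes,)
```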