Skip to content

Commit 5ae5503

Browse files
authored
Fix deadlock in CallbackList. (#21701)
A memory leak related to the executor in `CallbackList` was fixed in #20779 However, calling `Executor.shutdown` within `__del__` is intrisincally unsafe and can create deadlocks because the garbage collector can be called in different contexts. This new approach uses the `on_train/test/predict_begin` and `on_train/test/predict_end` callbacks to detect when we're done with the executor. - it counts the number of "begin"s and "end"s to handle the case of `evaluate` within `fit` (we do not shutdown the executor at the end of `evaluate` but instead keep it around for the rest of the training) - it also handles `CallbackList` being reused between calls to `fit`, `evaluate` or `predict` even though Keras doesn't reuse. Also renamed `_clear_futures` to `_flush_futures` to make it clear futures are not discarded, but exectuted.
1 parent cc56474 commit 5ae5503

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

keras/src/callbacks/callback_list.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(
3939
via `Callback.set_params`.
4040
"""
4141
self.callbacks = tree.flatten(callbacks) if callbacks else []
42+
self._in_begin_end_block_count = 0
4243
self._executor = None
4344
self._async_train = False
4445
self._async_test = False
@@ -78,9 +79,6 @@ def _configure_async_dispatch(self, callbacks):
7879
if not utils.is_default(cbk.on_predict_batch_end):
7980
async_predict = False
8081

81-
if async_train or async_test or async_predict:
82-
self._executor = concurrent.futures.ThreadPoolExecutor()
83-
8482
self._async_train = async_train
8583
self._async_test = async_test
8684
self._async_predict = async_predict
@@ -113,6 +111,33 @@ def set_model(self, model):
113111
for callback in self.callbacks:
114112
callback.set_model(model)
115113

114+
def _on_begin(self):
115+
"""Called by `on_train/test/predict_begin`.
116+
117+
Start the executor for async calls if needed.
118+
"""
119+
self._in_begin_end_block_count += 1
120+
if (
121+
self._in_begin_end_block_count == 1
122+
and (self._async_train or self._async_test or self._async_predict)
123+
and self._executor is None
124+
):
125+
self._executor = concurrent.futures.ThreadPoolExecutor()
126+
127+
def _on_end(self):
128+
"""Called by `on_train/test/predict_end`.
129+
130+
Shutdown the executor for async calls if all begin/end blocks completed.
131+
"""
132+
self._in_begin_end_block_count -= 1
133+
if self._in_begin_end_block_count < 0:
134+
raise ValueError(
135+
"`on_xxx_end` called without corresponding `on_xxx_begin`"
136+
)
137+
if self._in_begin_end_block_count == 0 and self._executor is not None:
138+
self._executor.shutdown()
139+
self._executor = None
140+
116141
def _async_dispatch(self, fn, *args):
117142
for future in self._futures:
118143
if future.done():
@@ -121,7 +146,8 @@ def _async_dispatch(self, fn, *args):
121146
future = self._executor.submit(fn, *args)
122147
self._futures.append(future)
123148

124-
def _clear_futures(self):
149+
def _flush_futures(self):
150+
"""Waits for all futures to complete and clears the list."""
125151
for future in self._futures:
126152
future.result()
127153
self._futures = []
@@ -138,7 +164,7 @@ def on_epoch_begin(self, epoch, logs=None):
138164

139165
def on_epoch_end(self, epoch, logs=None):
140166
if self._async_train:
141-
self._clear_futures()
167+
self._flush_futures()
142168

143169
logs = python_utils.pythonify_logs(logs)
144170
for callback in self.callbacks:
@@ -204,44 +230,52 @@ def _on_predict_batch_end(self, batch, logs=None):
204230
callback.on_predict_batch_end(batch, logs=logs)
205231

206232
def on_train_begin(self, logs=None):
233+
self._on_begin()
234+
207235
logs = python_utils.pythonify_logs(logs)
208236
for callback in self.callbacks:
209237
callback.on_train_begin(logs)
210238

211239
def on_train_end(self, logs=None):
212240
if self._async_train:
213-
self._clear_futures()
241+
self._flush_futures()
214242

215243
logs = python_utils.pythonify_logs(logs)
216244
for callback in self.callbacks:
217245
callback.on_train_end(logs)
218246

247+
self._on_end()
248+
219249
def on_test_begin(self, logs=None):
250+
self._on_begin()
251+
220252
logs = python_utils.pythonify_logs(logs)
221253
for callback in self.callbacks:
222254
callback.on_test_begin(logs)
223255

224256
def on_test_end(self, logs=None):
225257
if self._async_test:
226-
self._clear_futures()
258+
self._flush_futures()
227259

228260
logs = python_utils.pythonify_logs(logs)
229261
for callback in self.callbacks:
230262
callback.on_test_end(logs)
231263

264+
self._on_end()
265+
232266
def on_predict_begin(self, logs=None):
267+
self._on_begin()
268+
233269
logs = python_utils.pythonify_logs(logs)
234270
for callback in self.callbacks:
235271
callback.on_predict_begin(logs)
236272

237273
def on_predict_end(self, logs=None):
238274
if self._async_predict:
239-
self._clear_futures()
275+
self._flush_futures()
240276

241277
logs = python_utils.pythonify_logs(logs)
242278
for callback in self.callbacks:
243279
callback.on_predict_end(logs)
244280

245-
def __del__(self):
246-
if self._executor is not None:
247-
self._executor.shutdown(cancel_futures=True)
281+
self._on_end()

0 commit comments

Comments
 (0)