Skip to content

Commit e8044b2

Browse files
committed
1 parent 1cbcdfe commit e8044b2

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

hanlp/common/keras_component.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import tensorflow as tf
1212

1313
import hanlp.utils
14-
from hanlp_common.io import save_json,load_json
14+
from hanlp_common.io import save_json, load_json
1515
from hanlp.callbacks.fine_csv_logger import FineCSVLogger
1616
from hanlp.common.component import Component
1717
from hanlp.common.transform_tf import Transform
@@ -255,7 +255,8 @@ def build_optimizer(self, optimizer, **kwargs):
255255
if isinstance(optimizer, (str, dict)):
256256
custom_objects = {'AdamWeightDecay': AdamWeightDecay}
257257
optimizer: tf.keras.optimizers.Optimizer = tf.keras.utils.deserialize_keras_object(optimizer,
258-
module_objects=vars(tf.keras.optimizers),
258+
module_objects=vars(
259+
tf.keras.optimizers),
259260
custom_objects=custom_objects)
260261
self.config.optimizer = tf.keras.utils.serialize_keras_object(optimizer)
261262
return optimizer
@@ -437,6 +438,7 @@ def predict(self, data: Any, batch_size=None, **kwargs):
437438
for output in self.predict_batch(batch, inputs=inputs, **kwargs):
438439
results.append(output)
439440
num_samples += samples_in_batch
441+
self.transform.cleanup()
440442

441443
if flat:
442444
return results[0]

hanlp/common/transform_tf.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kw
2828
self.output_types = None
2929
self.output_shapes = None
3030
self.padding_values = None
31+
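# Workaround for a TensorFlow memory leak (https://github.com/tensorflow/tensorflow/issues/37653#issuecomment-1000517720): remember the py_funcs registered on the default graph while building datasets so that cleanup() can remove them again.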
32+
self.py_func_set_to_cleanup = set()
3133

3234
@abstractmethod
3335
def fit(self, trn_path: str, **kwargs) -> int:
@@ -170,6 +172,9 @@ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_s
170172
padding_values]), 'Your create_types_shapes_values returns None, which is not allowed'
171173
# if not callable(samples):
172174
# samples = Transform.generator_to_callable(samples)
175+
if not hasattr(tf.compat.v1.get_default_graph(), '_py_funcs_used_in_graph'):
176+
tf.compat.v1.get_default_graph()._py_funcs_used_in_graph = []
177+
py_func_set_before = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph)
173178
dataset = tf.data.Dataset.from_generator(samples, output_types=output_types, output_shapes=output_shapes)
174179
if cache:
175180
logger.debug('Dataset cache enabled')
@@ -197,6 +202,8 @@ def mapper(X, Y):
197202
return X, Y
198203

199204
dataset = dataset.map(mapper, num_parallel_calls=tf.data.experimental.AUTOTUNE)
205+
py_func_set_after = set(tf.compat.v1.get_default_graph()._py_funcs_used_in_graph) - py_func_set_before
206+
self.py_func_set_to_cleanup |= py_func_set_after
200207
return dataset
201208

202209
@abstractmethod
@@ -237,7 +244,8 @@ def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
237244
def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
    """Render every element of ``X`` as its ``repr`` string.

    Args:
        X: The inputs; iterated element by element.

    Returns:
        A list holding ``repr(x)`` for each ``x`` in ``X``, in iteration order.
    """
    return list(map(repr, X))
239246

240-
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
247+
def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
                 batch=None) -> Iterable:
    """Render every element of ``Y`` as its ``repr`` string.

    Args:
        Y: The outputs; iterated element by element.
        gold: Not used by this implementation.
        inputs: Not used by this implementation.
        X: Not used by this implementation.
        batch: Not used by this implementation.

    Returns:
        A list holding ``repr(y)`` for each ``y`` in ``Y``, in iteration order.
    """
    return list(map(repr, Y))
242250

243251
def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]],
@@ -295,3 +303,8 @@ def input_truth_output_to_str(self, input, truth, output):
295303
296304
"""
297305
return '\t'.join([input, truth, output]) + '\n'
306+
307+
def cleanup(self):
    """Drop the py_funcs recorded by this transform from the default graph.

    Every entry of ``self.py_func_set_to_cleanup`` is removed from the
    default graph's ``_py_funcs_used_in_graph`` list, and the recorded set
    is then reset to empty. This is the release half of the workaround for
    the TensorFlow memory leak discussed in tensorflow/tensorflow#37653.

    Calling this when the default graph has no ``_py_funcs_used_in_graph``
    attribute only resets the recorded set.
    """
    # Look the graph up once so the read and the write hit the same object.
    graph = tf.compat.v1.get_default_graph()
    registered = getattr(graph, '_py_funcs_used_in_graph', None)
    if registered is not None:
        stale = self.py_func_set_to_cleanup
        # Filter rather than round-trip through a set, so the surviving
        # entries keep their original order.
        graph._py_funcs_used_in_graph = [func for func in registered if func not in stale]
    self.py_func_set_to_cleanup = set()

0 commit comments

Comments
 (0)