@@ -28,6 +28,8 @@ def __init__(self, config: SerializableDict = None, map_x=True, map_y=True, **kw
2828 self .output_types = None
2929 self .output_shapes = None
3030 self .padding_values = None
31+ # Fix tf memory leak: https://github.com/tensorflow/tensorflow/issues/37653#issuecomment-1000517720
32+ self .py_func_set_to_cleanup = set ()
3133
3234 @abstractmethod
3335 def fit (self , trn_path : str , ** kwargs ) -> int :
@@ -170,6 +172,9 @@ def samples_to_dataset(self, samples: Generator, map_x=None, map_y=None, batch_s
170172 padding_values ]), 'Your create_types_shapes_values returns None, which is not allowed'
171173 # if not callable(samples):
172174 # samples = Transform.generator_to_callable(samples)
175+ if not hasattr (tf .compat .v1 .get_default_graph (), '_py_funcs_used_in_graph' ):
176+ tf .compat .v1 .get_default_graph ()._py_funcs_used_in_graph = []
177+ py_func_set_before = set (tf .compat .v1 .get_default_graph ()._py_funcs_used_in_graph )
173178 dataset = tf .data .Dataset .from_generator (samples , output_types = output_types , output_shapes = output_shapes )
174179 if cache :
175180 logger .debug ('Dataset cache enabled' )
@@ -197,6 +202,8 @@ def mapper(X, Y):
197202 return X , Y
198203
199204 dataset = dataset .map (mapper , num_parallel_calls = tf .data .experimental .AUTOTUNE )
205+ py_func_set_after = set (tf .compat .v1 .get_default_graph ()._py_funcs_used_in_graph ) - py_func_set_before
206+ self .py_func_set_to_cleanup |= py_func_set_after
200207 return dataset
201208
202209 @abstractmethod
@@ -237,7 +244,8 @@ def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
237244 def X_to_inputs (self , X : Union [tf .Tensor , Tuple [tf .Tensor ]]) -> Iterable :
238245 return [repr (x ) for x in X ]
239246
240- def Y_to_outputs (self , Y : Union [tf .Tensor , Tuple [tf .Tensor ]], gold = False , inputs = None , X = None , batch = None ) -> Iterable :
247+ def Y_to_outputs (self , Y : Union [tf .Tensor , Tuple [tf .Tensor ]], gold = False , inputs = None , X = None ,
248+ batch = None ) -> Iterable :
241249 return [repr (y ) for y in Y ]
242250
243251 def XY_to_inputs_outputs (self , X : Union [tf .Tensor , Tuple [tf .Tensor ]],
@@ -295,3 +303,8 @@ def input_truth_output_to_str(self, input, truth, output):
295303
296304 """
297305 return '\t ' .join ([input , truth , output ]) + '\n '
306+
307+ def cleanup (self ):
308+ new_py_funcs = set (tf .compat .v1 .get_default_graph ()._py_funcs_used_in_graph ) - self .py_func_set_to_cleanup
309+ tf .compat .v1 .get_default_graph ()._py_funcs_used_in_graph = list (new_py_funcs )
310+ self .py_func_set_to_cleanup = set ()
0 commit comments