@@ -257,87 +257,6 @@ def get_batch(self) -> tuple:
257257 return tuple (batch )
258258
259259
260- class TensorFlowDataGenerator (DataGenerator ): # pragma: no cover
261- """
262- Wrapper class on top of the TensorFlow native iterators :class:`tf.data.Iterator`.
263- """
264-
265- def __init__ (
266- self ,
267- sess : "tf.Session" ,
268- iterator : "tf.data.Iterator" ,
269- iterator_type : str ,
270- iterator_arg : dict | tuple | "tf.Operation" ,
271- size : int ,
272- batch_size : int ,
273- ) -> None :
274- """
275- Create a data generator wrapper for TensorFlow. Supported iterators: initializable, reinitializable, feedable.
276-
277- :param sess: TensorFlow session.
278- :param iterator: Data iterator from TensorFlow.
279- :param iterator_type: Type of the iterator. Supported types: `initializable`, `reinitializable`, `feedable`.
280- :param iterator_arg: Argument to initialize the iterator. It is either a feed_dict used for the initializable
281- and feedable mode, or an init_op used for the reinitializable mode.
282- :param size: Total size of the dataset.
283- :param batch_size: Size of the minibatches.
284- :raises `TypeError`, `ValueError`: If input parameters are not valid.
285- """
286-
287- import tensorflow .compat .v1 as tf
288-
289- super ().__init__ (size = size , batch_size = batch_size )
290- self .sess = sess
291- self ._iterator = iterator
292- self .iterator_type = iterator_type
293- self .iterator_arg = iterator_arg
294-
295- if not isinstance (iterator , tf .data .Iterator ):
296- raise TypeError ("Only support object tf.data.Iterator" )
297-
298- if iterator_type == "initializable" :
299- if not isinstance (iterator_arg , dict ):
300- raise TypeError (f"Need to pass a dictionary for iterator type { iterator_type } " )
301- elif iterator_type == "reinitializable" :
302- if not isinstance (iterator_arg , tf .Operation ):
303- raise TypeError (f"Need to pass a TensorFlow operation for iterator type { iterator_type } " )
304- elif iterator_type == "feedable" :
305- if not isinstance (iterator_arg , tuple ):
306- raise TypeError (f"Need to pass a tuple for iterator type { iterator_type } " )
307- else :
308- raise TypeError (f"Iterator type { iterator_type } not supported" )
309-
310- def get_batch (self ) -> tuple :
311- """
312- Provide the next batch for training in the form of a tuple `(x, y)`. The generator should loop over the data
313- indefinitely.
314-
315- :return: A tuple containing a batch of data `(x, y)`.
316- :raises `ValueError`: If the iterator has reached the end.
317- """
318- import tensorflow as tf
319-
320- # Get next batch
321- next_batch = self .iterator .get_next ()
322-
323- # Process to get the batch
324- try :
325- if self .iterator_type in ("initializable" , "reinitializable" ):
326- return self .sess .run (next_batch )
327- return self .sess .run (next_batch , feed_dict = self .iterator_arg [1 ])
328- except (tf .errors .FailedPreconditionError , tf .errors .OutOfRangeError ):
329- if self .iterator_type == "initializable" :
330- self .sess .run (self .iterator .initializer , feed_dict = self .iterator_arg )
331- return self .sess .run (next_batch )
332-
333- if self .iterator_type == "reinitializable" :
334- self .sess .run (self .iterator_arg )
335- return self .sess .run (next_batch )
336-
337- self .sess .run (self .iterator_arg [0 ].initializer )
338- return self .sess .run (next_batch , feed_dict = self .iterator_arg [1 ])
339-
340-
341260class TensorFlowV2DataGenerator (DataGenerator ):
342261 """
343262 Wrapper class on top of the TensorFlow v2 native iterators :class:`tf.data.Iterator`.
0 commit comments