@@ -365,142 +365,142 @@ def __init__(
365365 self .training_data = training_data
366366 self .labels = labels
367367
368- self .cast_as_dataset = cast_as_dataset
368+ self .cast_as_dataset = cast_as_dataset
369369
370- if not cast_as_dataset :
371- self .dataset = None
372- else :
373- # Warn about experimental feature
374- warn ("Casting data to a tf.data.Dataset is EXPERIMENTAL! We do not recommend this for production use." )
375-
376- # Check if training_data is a list/tuple and has elements
377- if isinstance (training_data , (list , tuple )) and len (training_data ) > 0 :
378- first_element = training_data [0 ]
379- if isinstance (first_element , (np .ndarray , tf .Tensor )):
380- # training_data is [np.array()] or [tf.tensor()]
381- self .dataset = tf .data .Dataset .from_tensor_slices ((training_data , labels ))
370+ if not cast_as_dataset :
371+ self .dataset = None
382372 else :
383- # training_data consists of generators/iterables
384- def nested_input_generator ():
385- """
386- Combines the training data generators and labels generator and yields data in the
387- format (sample_tensor, label_tensor).
388- """
389- # Assuming training_data contains generator functions or iterables
390- # and labels is also a generator function or iterable
391- if hasattr (training_data [0 ], '__call__' ):
392- # If they are generator functions, call them
393- data_iterators = [data_gen () for data_gen in training_data ]
394- label_iterator = labels () if hasattr (labels , '__call__' ) else labels
373+ # Warn about experimental feature
374+ warn ("Casting data to a tf.data.Dataset is EXPERIMENTAL! We do not recommend this for production use." )
375+
376+ # Check if training_data is a list/tuple and has elements
377+ if isinstance (training_data , (list , tuple )) and len (training_data ) > 0 :
378+ first_element = training_data [0 ]
379+ if isinstance (first_element , (np .ndarray , tf .Tensor )):
380+ # training_data is [np.array()] or [tf.tensor()]
381+ self .dataset = tf .data .Dataset .from_tensor_slices ((training_data , labels ))
395382 else :
396- # If they are already iterators/iterables
397- data_iterators = training_data
398- label_iterator = labels
383+ # training_data consists of generators/iterables
384+ def nested_input_generator ():
385+ """
386+ Combines the training data generators and labels generator and yields data in the
387+ format (sample_tensor, label_tensor).
388+ """
389+ # Assuming training_data contains generator functions or iterables
390+ # and labels is also a generator function or iterable
391+ if hasattr (training_data [0 ], '__call__' ):
392+ # If they are generator functions, call them
393+ data_iterators = [data_gen () for data_gen in training_data ]
394+ label_iterator = labels () if hasattr (labels , '__call__' ) else labels
395+ else :
396+ # If they are already iterators/iterables
397+ data_iterators = training_data
398+ label_iterator = labels
399399
400- for data_items , label_item in zip (zip (* data_iterators ), label_iterator ):
401- # data_items is a tuple of inputs for multimodal case
402- # For single input, it's a tuple with one element
403- if len (data_items ) == 1 :
404- yield data_items [0 ], label_item
405- else :
406- yield list (data_items ), label_item
400+ for data_items , label_item in zip (zip (* data_iterators ), label_iterator ):
401+ # data_items is a tuple of inputs for multimodal case
402+ # For single input, it's a tuple with one element
403+ if len (data_items ) == 1 :
404+ yield data_items [0 ], label_item
405+ else :
406+ yield list (data_items ), label_item
407407
408- def create_dataset_from_generator (generator_func ):
409- """
410- Creates a tf.data.Dataset by automatically inferring the output signature.
411-
412- This function inspects the first item yielded by the generator to determine the
413- shapes and data types of the elements, removing the need to specify them manually.
414-
415- Args:
416- generator_func: A callable that returns a generator object.
417-
418- Returns:
419- A tf.data.Dataset instance.
420- """
408+ def create_dataset_from_generator (generator_func ):
409+ """
410+ Creates a tf.data.Dataset by automatically inferring the output signature.
411+
412+ This function inspects the first item yielded by the generator to determine the
413+ shapes and data types of the elements, removing the need to specify them manually.
414+
415+ Args:
416+ generator_func: A callable that returns a generator object.
417+
418+ Returns:
419+ A tf.data.Dataset instance.
420+ """
421+
422+ # 1. Get the first item from the generator to see what it looks like.
423+ first_item = next (generator_func ())
424+
425+ # 2. A recursive helper function to build the TensorSpec signature from the first item.
426+ def _to_tensor_spec (element ):
427+ """Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
428+ if isinstance (element , (tuple , list )):
429+ # If it's a list or tuple, recursively process each of its items.
430+ return tuple (_to_tensor_spec (e ) for e in element )
431+ else :
432+ # For arrays or scalars, create a TensorSpec.
433+ # `tf.convert_to_tensor` correctly handles Python/Numpy types.
434+ tensor = tf .convert_to_tensor (element )
435+ return tf .TensorSpec (shape = tensor .shape , dtype = tensor .dtype )
436+
437+ # 3. Build the complete signature from the first item's structure.
438+ output_signature = _to_tensor_spec (first_item )
439+
440+ print ("✅ Successfully inferred the following signature:" )
441+ print (output_signature )
442+ print ("-" * 30 )
443+ # 4. Create the dataset using the generator function and the inferred signature.
444+ return tf .data .Dataset .from_generator (
445+ generator_func ,
446+ output_signature = output_signature
447+ )
448+
449+ self .dataset = create_dataset_from_generator (nested_input_generator )
450+ else :
451+ # Assume training_data is a single generator/iterable
452+ def nested_input_generator ():
453+ """
454+ Handles single generator/iterable case.
455+ """
456+ data_iterator = training_data () if hasattr (training_data , '__call__' ) else training_data
457+ label_iterator = labels () if hasattr (labels , '__call__' ) else labels
458+
459+ for data_item , label_item in zip (data_iterator , label_iterator ):
460+ yield data_item , label_item
461+
462+ def create_dataset_from_generator (generator_func ):
463+ """
464+ Creates a tf.data.Dataset by automatically inferring the output signature.
421465
422- # 1. Get the first item from the generator to see what it looks like.
423- first_item = next (generator_func ())
424-
425- # 2. A recursive helper function to build the TensorSpec signature from the first item.
426- def _to_tensor_spec (element ):
427- """Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
428- if isinstance (element , (tuple , list )):
429- # If it's a list or tuple, recursively process each of its items.
430- return tuple (_to_tensor_spec (e ) for e in element )
431- else :
432- # For arrays or scalars, create a TensorSpec.
433- # `tf.convert_to_tensor` correctly handles Python/Numpy types.
434- tensor = tf .convert_to_tensor (element )
435- return tf .TensorSpec (shape = tensor .shape , dtype = tensor .dtype )
436-
437- # 3. Build the complete signature from the first item's structure.
438- output_signature = _to_tensor_spec (first_item )
466+ This function inspects the first item yielded by the generator to determine the
467+ shapes and data types of the elements, removing the need to specify them manually.
439468
440- print ("✅ Successfully inferred the following signature:" )
441- print (output_signature )
442- print ("-" * 30 )
443- # 4. Create the dataset using the generator function and the inferred signature.
444- return tf .data .Dataset .from_generator (
445- generator_func ,
446- output_signature = output_signature
447- )
448-
449- self .dataset = create_dataset_from_generator (nested_input_generator )
450- else :
451- # Assume training_data is a single generator/iterable
452- def nested_input_generator ():
453- """
454- Handles single generator/iterable case.
455- """
456- data_iterator = training_data () if hasattr (training_data , '__call__' ) else training_data
457- label_iterator = labels () if hasattr (labels , '__call__' ) else labels
458-
459- for data_item , label_item in zip (data_iterator , label_iterator ):
460- yield data_item , label_item
461-
462- def create_dataset_from_generator (generator_func ):
463- """
464- Creates a tf.data.Dataset by automatically inferring the output signature.
465-
466- This function inspects the first item yielded by the generator to determine the
467- shapes and data types of the elements, removing the need to specify them manually.
468-
469- Args:
470- generator_func: A callable that returns a generator object.
471-
472- Returns:
473- A tf.data.Dataset instance.
474- """
475-
476- # 1. Get the first item from the generator to see what it looks like.
477- first_item = next (generator_func ())
478-
479- # 2. A recursive helper function to build the TensorSpec signature from the first item.
480- def _to_tensor_spec (element ):
481- """Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
482- if isinstance (element , (tuple , list )):
483- # If it's a list or tuple, recursively process each of its items.
484- return tuple (_to_tensor_spec (e ) for e in element )
485- else :
486- # For arrays or scalars, create a TensorSpec.
487- # `tf.convert_to_tensor` correctly handles Python/Numpy types.
488- tensor = tf .convert_to_tensor (element )
489- return tf .TensorSpec (shape = tensor .shape , dtype = tensor .dtype )
490-
491- # 3. Build the complete signature from the first item's structure.
492- output_signature = _to_tensor_spec (first_item )
493-
494- print ("✅ Successfully inferred the following signature:" )
495- print (output_signature )
496- print ("-" * 30 )
497- # 4. Create the dataset using the generator function and the inferred signature.
498- return tf .data .Dataset .from_generator (
499- generator_func ,
500- output_signature = output_signature
501- )
502-
503- self .dataset = create_dataset_from_generator (nested_input_generator )
469+ Args:
470+ generator_func: A callable that returns a generator object.
471+
472+ Returns:
473+ A tf.data.Dataset instance.
474+ """
475+
476+ # 1. Get the first item from the generator to see what it looks like.
477+ first_item = next (generator_func ())
478+
479+ # 2. A recursive helper function to build the TensorSpec signature from the first item.
480+ def _to_tensor_spec (element ):
481+ """Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
482+ if isinstance (element , (tuple , list )):
483+ # If it's a list or tuple, recursively process each of its items.
484+ return tuple (_to_tensor_spec (e ) for e in element )
485+ else :
486+ # For arrays or scalars, create a TensorSpec.
487+ # `tf.convert_to_tensor` correctly handles Python/Numpy types.
488+ tensor = tf .convert_to_tensor (element )
489+ return tf .TensorSpec (shape = tensor .shape , dtype = tensor .dtype )
490+
491+ # 3. Build the complete signature from the first item's structure.
492+ output_signature = _to_tensor_spec (first_item )
493+
494+ print ("✅ Successfully inferred the following signature:" )
495+ print (output_signature )
496+ print ("-" * 30 )
497+ # 4. Create the dataset using the generator function and the inferred signature.
498+ return tf .data .Dataset .from_generator (
499+ generator_func ,
500+ output_signature = output_signature
501+ )
502+
503+ self .dataset = create_dataset_from_generator (nested_input_generator )
504504
505505
506506 self .validation_split = validation_split
0 commit comments