Skip to content

Commit d98634f

Browse files
Update simple_cerebros_random_search.py
Fix AI generated indentation screw - up...
1 parent 6ca8c0b commit d98634f

File tree

1 file changed

+129
-129
lines changed

1 file changed

+129
-129
lines changed

cerebros/simplecerebrosrandomsearch/simple_cerebros_random_search.py

Lines changed: 129 additions & 129 deletions
Original file line numberDiff line numberDiff line change
@@ -365,142 +365,142 @@ def __init__(
365365
self.training_data = training_data
366366
self.labels = labels
367367

368-
self.cast_as_dataset = cast_as_dataset
368+
self.cast_as_dataset = cast_as_dataset
369369

370-
if not cast_as_dataset:
371-
self.dataset = None
372-
else:
373-
# Warn about experimental feature
374-
warn("Casting data to a tf.data.Dataset is EXPERIMENTAL! We do not recommend this for production use.")
375-
376-
# Check if training_data is a list/tuple and has elements
377-
if isinstance(training_data, (list, tuple)) and len(training_data) > 0:
378-
first_element = training_data[0]
379-
if isinstance(first_element, (np.ndarray, tf.Tensor)):
380-
# training_data is [np.array()] or [tf.tensor()]
381-
self.dataset = tf.data.Dataset.from_tensor_slices((training_data, labels))
370+
if not cast_as_dataset:
371+
self.dataset = None
382372
else:
383-
# training_data consists of generators/iterables
384-
def nested_input_generator():
385-
"""
386-
Combines the training data generators and labels generator and yields data in the
387-
format (sample_tensor, label_tensor).
388-
"""
389-
# Assuming training_data contains generator functions or iterables
390-
# and labels is also a generator function or iterable
391-
if hasattr(training_data[0], '__call__'):
392-
# If they are generator functions, call them
393-
data_iterators = [data_gen() for data_gen in training_data]
394-
label_iterator = labels() if hasattr(labels, '__call__') else labels
373+
# Warn about experimental feature
374+
warn("Casting data to a tf.data.Dataset is EXPERIMENTAL! We do not recommend this for production use.")
375+
376+
# Check if training_data is a list/tuple and has elements
377+
if isinstance(training_data, (list, tuple)) and len(training_data) > 0:
378+
first_element = training_data[0]
379+
if isinstance(first_element, (np.ndarray, tf.Tensor)):
380+
# training_data is [np.array()] or [tf.tensor()]
381+
self.dataset = tf.data.Dataset.from_tensor_slices((training_data, labels))
395382
else:
396-
# If they are already iterators/iterables
397-
data_iterators = training_data
398-
label_iterator = labels
383+
# training_data consists of generators/iterables
384+
def nested_input_generator():
385+
"""
386+
Combines the training data generators and labels generator and yields data in the
387+
format (sample_tensor, label_tensor).
388+
"""
389+
# Assuming training_data contains generator functions or iterables
390+
# and labels is also a generator function or iterable
391+
if hasattr(training_data[0], '__call__'):
392+
# If they are generator functions, call them
393+
data_iterators = [data_gen() for data_gen in training_data]
394+
label_iterator = labels() if hasattr(labels, '__call__') else labels
395+
else:
396+
# If they are already iterators/iterables
397+
data_iterators = training_data
398+
label_iterator = labels
399399

400-
for data_items, label_item in zip(zip(*data_iterators), label_iterator):
401-
# data_items is a tuple of inputs for multimodal case
402-
# For single input, it's a tuple with one element
403-
if len(data_items) == 1:
404-
yield data_items[0], label_item
405-
else:
406-
yield list(data_items), label_item
400+
for data_items, label_item in zip(zip(*data_iterators), label_iterator):
401+
# data_items is a tuple of inputs for multimodal case
402+
# For single input, it's a tuple with one element
403+
if len(data_items) == 1:
404+
yield data_items[0], label_item
405+
else:
406+
yield list(data_items), label_item
407407

408-
def create_dataset_from_generator(generator_func):
409-
"""
410-
Creates a tf.data.Dataset by automatically inferring the output signature.
411-
412-
This function inspects the first item yielded by the generator to determine the
413-
shapes and data types of the elements, removing the need to specify them manually.
414-
415-
Args:
416-
generator_func: A callable that returns a generator object.
417-
418-
Returns:
419-
A tf.data.Dataset instance.
420-
"""
408+
def create_dataset_from_generator(generator_func):
409+
"""
410+
Creates a tf.data.Dataset by automatically inferring the output signature.
411+
412+
This function inspects the first item yielded by the generator to determine the
413+
shapes and data types of the elements, removing the need to specify them manually.
414+
415+
Args:
416+
generator_func: A callable that returns a generator object.
417+
418+
Returns:
419+
A tf.data.Dataset instance.
420+
"""
421+
422+
# 1. Get the first item from the generator to see what it looks like.
423+
first_item = next(generator_func())
424+
425+
# 2. A recursive helper function to build the TensorSpec signature from the first item.
426+
def _to_tensor_spec(element):
427+
"""Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
428+
if isinstance(element, (tuple, list)):
429+
# If it's a list or tuple, recursively process each of its items.
430+
return tuple(_to_tensor_spec(e) for e in element)
431+
else:
432+
# For arrays or scalars, create a TensorSpec.
433+
# `tf.convert_to_tensor` correctly handles Python/Numpy types.
434+
tensor = tf.convert_to_tensor(element)
435+
return tf.TensorSpec(shape=tensor.shape, dtype=tensor.dtype)
436+
437+
# 3. Build the complete signature from the first item's structure.
438+
output_signature = _to_tensor_spec(first_item)
439+
440+
print("✅ Successfully inferred the following signature:")
441+
print(output_signature)
442+
print("-" * 30)
443+
# 4. Create the dataset using the generator function and the inferred signature.
444+
return tf.data.Dataset.from_generator(
445+
generator_func,
446+
output_signature=output_signature
447+
)
448+
449+
self.dataset = create_dataset_from_generator(nested_input_generator)
450+
else:
451+
# Assume training_data is a single generator/iterable
452+
def nested_input_generator():
453+
"""
454+
Handles single generator/iterable case.
455+
"""
456+
data_iterator = training_data() if hasattr(training_data, '__call__') else training_data
457+
label_iterator = labels() if hasattr(labels, '__call__') else labels
458+
459+
for data_item, label_item in zip(data_iterator, label_iterator):
460+
yield data_item, label_item
461+
462+
def create_dataset_from_generator(generator_func):
463+
"""
464+
Creates a tf.data.Dataset by automatically inferring the output signature.
421465
422-
# 1. Get the first item from the generator to see what it looks like.
423-
first_item = next(generator_func())
424-
425-
# 2. A recursive helper function to build the TensorSpec signature from the first item.
426-
def _to_tensor_spec(element):
427-
"""Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
428-
if isinstance(element, (tuple, list)):
429-
# If it's a list or tuple, recursively process each of its items.
430-
return tuple(_to_tensor_spec(e) for e in element)
431-
else:
432-
# For arrays or scalars, create a TensorSpec.
433-
# `tf.convert_to_tensor` correctly handles Python/Numpy types.
434-
tensor = tf.convert_to_tensor(element)
435-
return tf.TensorSpec(shape=tensor.shape, dtype=tensor.dtype)
436-
437-
# 3. Build the complete signature from the first item's structure.
438-
output_signature = _to_tensor_spec(first_item)
466+
This function inspects the first item yielded by the generator to determine the
467+
shapes and data types of the elements, removing the need to specify them manually.
439468
440-
print("✅ Successfully inferred the following signature:")
441-
print(output_signature)
442-
print("-" * 30)
443-
# 4. Create the dataset using the generator function and the inferred signature.
444-
return tf.data.Dataset.from_generator(
445-
generator_func,
446-
output_signature=output_signature
447-
)
448-
449-
self.dataset = create_dataset_from_generator(nested_input_generator)
450-
else:
451-
# Assume training_data is a single generator/iterable
452-
def nested_input_generator():
453-
"""
454-
Handles single generator/iterable case.
455-
"""
456-
data_iterator = training_data() if hasattr(training_data, '__call__') else training_data
457-
label_iterator = labels() if hasattr(labels, '__call__') else labels
458-
459-
for data_item, label_item in zip(data_iterator, label_iterator):
460-
yield data_item, label_item
461-
462-
def create_dataset_from_generator(generator_func):
463-
"""
464-
Creates a tf.data.Dataset by automatically inferring the output signature.
465-
466-
This function inspects the first item yielded by the generator to determine the
467-
shapes and data types of the elements, removing the need to specify them manually.
468-
469-
Args:
470-
generator_func: A callable that returns a generator object.
471-
472-
Returns:
473-
A tf.data.Dataset instance.
474-
"""
475-
476-
# 1. Get the first item from the generator to see what it looks like.
477-
first_item = next(generator_func())
478-
479-
# 2. A recursive helper function to build the TensorSpec signature from the first item.
480-
def _to_tensor_spec(element):
481-
"""Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
482-
if isinstance(element, (tuple, list)):
483-
# If it's a list or tuple, recursively process each of its items.
484-
return tuple(_to_tensor_spec(e) for e in element)
485-
else:
486-
# For arrays or scalars, create a TensorSpec.
487-
# `tf.convert_to_tensor` correctly handles Python/Numpy types.
488-
tensor = tf.convert_to_tensor(element)
489-
return tf.TensorSpec(shape=tensor.shape, dtype=tensor.dtype)
490-
491-
# 3. Build the complete signature from the first item's structure.
492-
output_signature = _to_tensor_spec(first_item)
493-
494-
print("✅ Successfully inferred the following signature:")
495-
print(output_signature)
496-
print("-" * 30)
497-
# 4. Create the dataset using the generator function and the inferred signature.
498-
return tf.data.Dataset.from_generator(
499-
generator_func,
500-
output_signature=output_signature
501-
)
502-
503-
self.dataset = create_dataset_from_generator(nested_input_generator)
469+
Args:
470+
generator_func: A callable that returns a generator object.
471+
472+
Returns:
473+
A tf.data.Dataset instance.
474+
"""
475+
476+
# 1. Get the first item from the generator to see what it looks like.
477+
first_item = next(generator_func())
478+
479+
# 2. A recursive helper function to build the TensorSpec signature from the first item.
480+
def _to_tensor_spec(element):
481+
"""Converts a nested structure of numpy arrays/scalars to a TensorSpec structure."""
482+
if isinstance(element, (tuple, list)):
483+
# If it's a list or tuple, recursively process each of its items.
484+
return tuple(_to_tensor_spec(e) for e in element)
485+
else:
486+
# For arrays or scalars, create a TensorSpec.
487+
# `tf.convert_to_tensor` correctly handles Python/Numpy types.
488+
tensor = tf.convert_to_tensor(element)
489+
return tf.TensorSpec(shape=tensor.shape, dtype=tensor.dtype)
490+
491+
# 3. Build the complete signature from the first item's structure.
492+
output_signature = _to_tensor_spec(first_item)
493+
494+
print("✅ Successfully inferred the following signature:")
495+
print(output_signature)
496+
print("-" * 30)
497+
# 4. Create the dataset using the generator function and the inferred signature.
498+
return tf.data.Dataset.from_generator(
499+
generator_func,
500+
output_signature=output_signature
501+
)
502+
503+
self.dataset = create_dataset_from_generator(nested_input_generator)
504504

505505

506506
self.validation_split = validation_split

0 commit comments

Comments
 (0)