Skip to content

Commit ce11359

Browse files
Update simple_cerebros_random_search.py
Add support for a tf.Dataset
1 parent 1ff1e75 commit ce11359

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

cerebros/simplecerebrosrandomsearch/simple_cerebros_random_search.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def __init__(
288288
maximum_units_per_level: int,
289289
minimum_neurons_per_unit: int,
290290
maximum_neurons_per_unit: int,
291+
dataset: tf.data.Dataset=None,
291292
validation_data: tuple=None,
292293
activation='elu',
293294
final_activation=None,
@@ -356,6 +357,7 @@ def __init__(
356357
self.maximum_units_per_level = maximum_units_per_level
357358
self.minimum_neurons_per_unit = minimum_neurons_per_unit
358359
self.maximum_neurons_per_unit = maximum_neurons_per_unit
360+
self.data_set = data_set
359361
self.activation = activation
360362
self.final_activation = final_activation
361363
self.unit_type = unit_type
@@ -493,15 +495,23 @@ def run_moity_permutations(self, spec, subtrial_number, lock):
493495
print(nnf.materialized_neural_network.summary())
494496
if self.chart_network_graph:
495497
nnf.get_graph()
496-
497-
history = neural_network.fit(x=self.training_data,
498-
y=self.labels,
499-
epochs=self.epochs,
500-
batch_size=self.batch_size,
501-
# callbacks=[early_stopping,
502-
# tensor_board],
503-
validation_split=self.validation_split,
504-
validation_data=self.validation_data)
498+
if self.dataset is not None:
499+
history = neural_network.fit(x=self.training_data,
500+
y=self.labels,
501+
epochs=self.epochs,
502+
batch_size=self.batch_size,
503+
# callbacks=[early_stopping,
504+
# tensor_board],
505+
validation_split=self.validation_split,
506+
validation_data=self.validation_data)
507+
else:
508+
history = neural_network.fit(dataset=self.dataset,
509+
epochs=self.epochs,
510+
batch_size=self.batch_size,
511+
# callbacks=[early_stopping,
512+
# tensor_board],
513+
validation_split=self.validation_split,
514+
validation_data=self.validation_data)
505515
oracle_0 = pd.DataFrame(history.history)
506516

507517
model_architectures_folder = f"{self.project_name}/model_architectures"

0 commit comments

Comments
 (0)