diff --git a/Train/config_trainer.py b/Train/config_trainer.py index 74b96613..652ea354 100644 --- a/Train/config_trainer.py +++ b/Train/config_trainer.py @@ -24,7 +24,7 @@ from Layers import DistanceWeightedMessagePassing from Layers import LLFillSpace from Layers import LLExtendedObjectCondensation -from Layers import DictModel,RaggedDictModel +from Layers import DictModel from Layers import RaggedGlobalExchange from Layers import SphereActivation from Layers import Multi @@ -163,7 +163,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000): ### Loop over GravNet Layers ############################################## ########################################################################### - gravnet_regs = [0.01, 0.01, 0.01] + gravnet_reg = 0.01 for i in range(GRAVNET_ITERATIONS): @@ -189,14 +189,14 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000): )([x, rs]) gndist = LLRegulariseGravNetSpace( - scale=gravnet_regs[i], + scale=gravnet_reg, record_metrics=False, name=f'regularise_gravnet_{i}')([gndist, prime_coords, gnnidx]) - x_rand = random_sampling_block( - xgn, rs, gncoords, gnnidx, gndist, is_track, - reduction=6, layer_norm=True, name=f"RSU_{i}") - x_rand = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_rand) + #x_rand = random_sampling_block( + # xgn, rs, gncoords, gnnidx, gndist, is_track, + # reduction=6, layer_norm=True, name=f"RSU_{i}") + #x_rand = ScaledGooeyBatchNorm2(**BATCHNORM_OPTIONS)(x_rand) gndist = AverageDistanceRegularizer( strength=1e-3, @@ -214,7 +214,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000): # x_rand = ScalarMultiply(0.1)(x_rand) # gndist = ScalarMultiply(0.01)(gndist) # gncoords = ScalarMultiply(0.01)(gncoords) - x = Concatenate()([x_pre, xgn, x_rand, gndist, gncoords]) + x = Concatenate()([x_pre, xgn, gndist, gncoords]) x = Dense(d_shape, name=f"dense_post_gravnet_1_iteration_{i}", activation=DENSE_ACTIVATION, @@ -270,7 +270,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000): pred_beta = LLExtendedObjectCondensation(scale=1., use_energy_weights=True, - record_metrics=False, + record_metrics=True, print_loss=True, name="ExtendedOCLoss", implementation = loss_implementation, @@ -304,7 +304,7 @@ def config_model(Inputs, td, debug_outdir=None, plot_debug_every=2000): # 'no_noise_rs': pre_processed['no_noise_rs'], } - return RaggedDictModel(inputs=Inputs, outputs=model_outputs) + return DictModel(inputs=Inputs, outputs=model_outputs) #return DictModel(inputs=Inputs, outputs=model_outputs) diff --git a/modules/GravNetLayersRagged.py b/modules/GravNetLayersRagged.py index 7e2189b1..b8375482 100644 --- a/modules/GravNetLayersRagged.py +++ b/modules/GravNetLayersRagged.py @@ -3321,8 +3321,9 @@ def priv_call(self, inputs, training=None): row_splits = inputs[1] tf.assert_equal(x.shape.ndims, 2) tf.assert_equal(row_splits.shape.ndims, 1) + #print(row_splits, row_splits[-1], tf.shape(x)[0]) if row_splits.shape[0] is not None: - tf.assert_equal(row_splits[-1], x.shape[0]) + tf.assert_equal(row_splits[-1], tf.shape(x)[0]) x_coord = x if len(inputs) == 3: diff --git a/modules/Layers.py b/modules/Layers.py index ce13ae2c..0b88d0e3 100644 --- a/modules/Layers.py +++ b/modules/Layers.py @@ -953,36 +953,4 @@ def __init__(self, super(DictModel, self).__init__(inputs,outputs=outputs, *args, **kwargs) - -class RaggedDictModel(tf.keras.Model): - def __init__(self, - inputs, - outputs: dict, #force to be dict - *args, **kwargs): - """ - Just forces dictionary output - """ - - 
super(RaggedDictModel, self).__init__(inputs,outputs=outputs, *args, **kwargs) - - def call(self, inputs, *args, **kwargs): - return super(RaggedDictModel, self).call(self.unpack_ragged(inputs), *args, **kwargs) - - def train_step(self, inputs, *args, **kwargs): - return super(RaggedDictModel, self).train_step(self.unpack_ragged(inputs), *args, **kwargs) - #super(RaggedDictModel, self).train_step(inputs, *args, **kwargs) - - def unpack_ragged(self, inputs): - output = [] - for i in inputs: - print("Type of i is", type(i)) - print("Hasattr", hasattr(i, "row_splits")) - if type(i) == tf.RaggedTensor: - print("Inside") - output.append((i.values, i.row_splts)) - else: - output.append(i) - - return output - -global_layers_list['DictModel']=DictModel +global_layers_list['DictModel']=DictModel \ No newline at end of file diff --git a/modules/training_base_hgcal.py b/modules/training_base_hgcal.py index 215e7132..2487818f 100644 --- a/modules/training_base_hgcal.py +++ b/modules/training_base_hgcal.py @@ -1,75 +1,357 @@ -from DeepJetCore.training.training_base import training_base +#from DeepJetCore.training.training_base import training_base + + +import concurrent.futures +import numpy as np + +## to call it from cammand lines +import sys +import os from argparse import ArgumentParser +import shutil +from DeepJetCore import DataCollection +import tensorflow.keras as keras import tensorflow as tf +import copy +from DeepJetCore.training.gpuTools import DJCSetGPUs +from DeepJetCore.training.training_base import training_base as training_base_djc +import time -class HGCalTraining(training_base): - def __init__(self, *args, - parser = None, - **kwargs): - ''' - Adds file logging - ''' - #use the DJC training base option to pass a parser - if parser is None: - parser = ArgumentParser('Run the training') - parser.add_argument("--interactive", help="prints output to screen", default=False, action="store_true") + + +#for multi-gpu we need to overwrite a few things here + +### +# +# this will become a cleaned-up version of DJC training_base at some point +# +### +class training_base(object): + + def __init__( + self, splittrainandtest=0.85, + useweights=False, testrun=False, + testrun_fraction=0.1, + resumeSilently=False, + renewtokens=False, #compat + collection_class=DataCollection, + parser=None, + recreate_silently=False + ): + + scriptname=sys.argv[0] + + if parser is None: parser = ArgumentParser('Run the training') + parser.add_argument('inputDataCollection') + parser.add_argument('outputDir') + #parser.add_argument('--modelMethod', help='Method to be used to instantiate model in derived training class', metavar='OPT', default=None) + parser.add_argument("--gpu", help="select specific GPU", metavar="OPT", default="") + #parser.add_argument("--gpufraction", help="select memory fraction for GPU", type=float, metavar="OPT", default=-1) + #parser.add_argument("--submitbatch", help="submits the job to condor" , default=False, action="store_true") + #parser.add_argument("--walltime", help="sets the wall time for the batch job, format: 1d5h or 2d or 3h etc" , default='1d') + #parser.add_argument("--isbatchrun", help="is batch run", default=False, action="store_true") + parser.add_argument("--valdata", help="set validation dataset (optional)", default="") + parser.add_argument("--takeweights", help="Applies weights from the model given as relative or absolute path. 
Matches by names and skips layers that don't match.", default="") + + + args = parser.parse_args() + self.args = args + self.argstring = sys.argv + #sanity check + + + import matplotlib + #if no X11 use below + matplotlib.use('Agg') + DJCSetGPUs(args.gpu) - #no reason for a lot of validation samples usually - super().__init__(*args, resumeSilently=True,parser=parser,splittrainandtest=0.95,**kwargs) - if not self.args.interactive: - print('>>> redirecting the following stdout and stderr to logs in',self.outputDir) - import sys - sys.stdout = open(self.outputDir+'/stdout.txt', 'w') - sys.stderr = open(self.outputDir+'/stderr.txt', 'w') + self.ngpus=1 + + if len(args.gpu): + self.ngpus=len([i for i in args.gpu.split(',')]) + print('running on '+str(self.ngpus)+ ' gpus') + self.keras_inputs=[] + self.keras_inputsshapes=[] + self.keras_model=None + self.mgpu_keras_models = [] + self.keras_weight_model_path=args.takeweights + self.train_data=None + self.val_data=None + self.startlearningrate=None + self.optimizer=None + self.trainedepoches=0 + self.compiled=False + self.checkpointcounter=0 + self.callbacks=None + self.custom_optimizer=False + self.copied_script="" + + self.inputData = os.path.abspath(args.inputDataCollection) \ + if ',' not in args.inputDataCollection else \ + [os.path.abspath(i) for i in args.inputDataCollection.split(',')] + self.outputDir=args.outputDir + # create output dir + + isNewTraining=True + if os.path.isdir(self.outputDir): + if not (resumeSilently or recreate_silently): + var = input('output dir exists. To recover a training, please type "yes"\n') + if not var == 'yes': + raise Exception('output directory must not exists yet') + isNewTraining=False + if recreate_silently: + isNewTraining=True + else: + os.mkdir(self.outputDir) + self.outputDir = os.path.abspath(self.outputDir) + self.outputDir+='/' + + if recreate_silently: + os.system('rm -rf '+ self.outputDir +'*') + + #copy configuration to output dir + try: + shutil.copyfile(scriptname,self.outputDir+os.path.basename(scriptname)) + except shutil.SameFileError: + pass + except BaseException as e: + raise e + + self.copied_script = self.outputDir+os.path.basename(scriptname) + + self.train_data = collection_class() + self.train_data.readFromFile(self.inputData) + self.train_data.useweights=useweights + + if len(args.valdata): + print('using validation data from ',args.valdata) + self.val_data = DataCollection(args.valdata) + + else: + if testrun: + if len(self.train_data)>1: + self.train_data.split(testrun_fraction) + + self.train_data.dataclass_instance=None #can't be pickled + self.val_data=copy.deepcopy(self.train_data) + + else: + self.val_data=self.train_data.split(splittrainandtest) - from config_saver import copyModules - copyModules(self.outputDir)#save the modules with indexing for overwrites - @tf.function - def compute_per_replica_loss(replica_data): - with tf.GradientTape() as tape: - logits = self.keras_model(replica_data) - primary_loss_value = loss_fn(replica_data, logits) - total_loss_value = primary_loss_value + tf.add_n(self.keras_model.losses) - grads = tape.gradient(total_loss_value, self.keras_model.trainable_variables) - optimizer.apply_gradients(zip(grads, self.keras_model.trainable_variables)) - return total_loss_value - - def to_ragged_tensor(self, data_list): - for e in data_list: - rt = tf.RaggedTensor.from_row_splits(values=e[0][0], row_splits=e[0][1].flatten()) - yield rt - def compileModel(self, **kwargs): - super().compileModel(is_eager=True, - loss=None, - **kwargs) - + shapes = 
self.train_data.getNumpyFeatureShapes() + inputdtypes = self.train_data.getNumpyFeatureDTypes() + inputnames= self.train_data.getNumpyFeatureArrayNames() + for i in range(len(inputnames)): #in case they are not named + if inputnames[i]=="" or inputnames[i]=="_rowsplits": + inputnames[i]="input_"+str(i)+inputnames[i] + + + print("shapes", shapes) + print("inputdtypes", inputdtypes) + print("inputnames", inputnames) + + self.keras_inputs=[] + self.keras_inputsshapes=[] + + for s,dt,n in zip(shapes,inputdtypes,inputnames): + self.keras_inputs.append(keras.layers.Input(shape=s, dtype=dt, name=n)) + self.keras_inputsshapes.append(s) + + #bookkeeping + self.train_data.writeToFile(self.outputDir+'trainsamples.djcdc',abspath=True) + self.val_data.writeToFile(self.outputDir+'valsamples.djcdc',abspath=True) + + if not isNewTraining: + kfile = self.outputDir+'/KERAS_check_model_last.h5' + if not os.path.isfile(kfile): + kfile = self.outputDir+'/KERAS_check_model_last' #savedmodel format + if not os.path.isdir(kfile): + kfile='' + if len(kfile): + print('loading model',kfile) + self.loadModel(kfile) + self.trainedepoches=0 + if os.path.isfile(self.outputDir+'losses.log'): + for line in open(self.outputDir+'losses.log'): + valloss = line.split(' ')[1][:-1] + if not valloss == "None": + self.trainedepoches+=1 + else: + print('incomplete epochs, starting from the beginning but with pretrained model') + else: + print('no model found in existing output dir, starting training from scratch') + + def modelSet(self): + return (not self.keras_model==None) and not len(self.keras_weight_model_path) + + def syncModelWeights(self): + if len(self.mgpu_keras_models) < 2: + return + weights = self.mgpu_keras_models[0].get_weights() + for model in self.mgpu_keras_models[1:]: + model.set_weights(weights) + + def setModel(self,model,**modelargs): + if len(self.keras_inputs)<1: + raise Exception('setup data first') + + with tf.device('/GPU:0'): + self.keras_model=model(self.keras_inputs,**modelargs) + + if len(self.keras_weight_model_path): + from DeepJetCore.modeltools import apply_weights_where_possible, load_model + self.keras_model = apply_weights_where_possible(self.keras_model, + load_model(self.keras_weight_model_path)) + if not self.keras_model: + raise Exception('Setting model not successful') + + self.mgpu_keras_models = [self.keras_model] #zero model + if self.ngpus > 1: + for i in range(self.ngpus-1): + with tf.device(f'/GPU:{i+1}'): + self.mgpu_keras_models.append(model(self.keras_inputs,**modelargs)) + #sync initial or loaded weights + self.syncModelWeights() + + def saveCheckPoint(self,addstring=''): + self.checkpointcounter=self.checkpointcounter+1 + self.saveModel("KERAS_model_checkpoint_"+str(self.checkpointcounter)+"_"+addstring) + + def _loadModel(self,filename): + from tensorflow.keras.models import load_model + keras_model=load_model(filename, custom_objects=custom_objects_list) + optimizer=keras_model.optimizer + return keras_model, optimizer + + def loadModel(self,filename): + self.keras_model, self.optimizer = self._loadModel(filename) + self.compiled=True + + def setCustomOptimizer(self,optimizer): + self.optimizer = optimizer + self.custom_optimizer=True + + def compileModel(self, + learningrate, + clipnorm=None, + print_models=False, + metrics=None, + is_eager=False, + **compileargs): + + if not self.keras_model: + raise Exception('set model first') + + print('Model being compiled for '+str(self.ngpus)+' gpus') + + self.startlearningrate=learningrate + + if not self.custom_optimizer: + from 
tensorflow.keras.optimizers import Adam + if clipnorm: + self.optimizer = Adam(lr=self.startlearningrate,clipnorm=clipnorm) + else: + self.optimizer = Adam(lr=self.startlearningrate) + + def run_compile(model, device): + with tf.device(device): + model.compile(optimizer=self.optimizer,metrics=metrics,**compileargs) + if is_eager: + #call on one batch to fully build it + model(self.train_data.getExampleFeatureBatch()) + + if print_models: + print(model.summary()) + + for i, m in enumerate(self.mgpu_keras_models): + run_compile(m, f'/GPU:{i}') + + self.compiled=True + + + def saveModel(self,outfile): + self.keras_model.save(self.outputDir+outfile) + + + # add some of the multi-gpu initialisation here? + def _initTraining(self, + nepochs, + batchsize, + use_sum_of_squares=False): + + + self.train_data.setBatchSize(batchsize) + self.val_data.setBatchSize(batchsize) + self.train_data.batch_uses_sum_of_squares=use_sum_of_squares + self.val_data.batch_uses_sum_of_squares=use_sum_of_squares + + self.train_data.setBatchSize(batchsize) + self.val_data.setBatchSize(batchsize) + + ## create multi-gpu models + + #now this is hgcal specific because of missing truth, think later how to do that better + def compute_gradients(self, model, data, i): + with tf.device(f'/GPU:{i}'): + with tf.GradientTape() as tape: + predictions = model(data, training=True) + loss = tf.add_n(model.losses) + return tape.gradient(loss, model.trainable_variables) + + def average_gradients(self, all_gradients): + # Average the gradients across GPUs + if len(all_gradients) < 2: + return all_gradients[0] + avg_grads = [] + for grad_list_tuple in zip(*all_gradients): + grads = [g for g in grad_list_tuple if g is not None] + avg_grads.append(tf.reduce_mean(grads, axis=0)) + return avg_grads + + def trainstep_parallel(self, split_data): + + batch_gradients = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=self.ngpus) as executor: + futures = [executor.submit(self.compute_gradients, self.mgpu_keras_models[i], split_data[i], i) for i in range(self.ngpus)] + for future in concurrent.futures.as_completed(futures): + gradients = future.result() + batch_gradients.append(gradients) + + # Average gradients across GPUs - just see if it executes for now + avg_grads = self.average_gradients(batch_gradients) + self.optimizer.apply_gradients(zip(avg_grads, self.mgpu_keras_models[0].trainable_variables)) + self.syncModelWeights() # weights synced + + + def trainModel(self, nepochs, batchsize, - run_eagerly=True, - verbose=1, + run_eagerly=False, batchsize_use_sum_of_squares = False, - fake_truth=True,#extend the truth list with dummies. Useful when adding more prediction outputs than truth inputs - backup_after_batches=500, - checkperiod=1, + fake_truth=False,#extend the truth list with dummies. 
Useful when adding more prediction outputs than truth inputs stop_patience=-1, lr_factor=0.5, lr_patience=-1, lr_epsilon=0.003, lr_cooldown=6, lr_minimum=0.000001, + checkperiod=10, + backup_after_batches=-1, additional_plots=None, additional_callbacks=None, load_in_mem = False, max_files = -1, plot_batch_loss = False, + verbose = 0, **trainargs): - self.keras_model.run_eagerly=run_eagerly + for m in self.mgpu_keras_models: + m.run_eagerly=run_eagerly # write only after the output classes have been added self._initTraining(nepochs,batchsize, batchsize_use_sum_of_squares) @@ -79,9 +361,6 @@ def trainModel(self, pass print('setting up callbacks') from DeepJetCore.training.DeepJet_callbacks import DeepJet_callbacks - minTokenLifetime = 5 - if not self.renewtokens: - minTokenLifetime = -1 self.callbacks=DeepJet_callbacks(self.keras_model, stop_patience=stop_patience, @@ -97,79 +376,190 @@ def trainModel(self, additional_plots=additional_plots, batch_loss = plot_batch_loss, print_summary_after_first_batch=run_eagerly, - minTokenLifetime = minTokenLifetime) + minTokenLifetime = -1) if additional_callbacks is not None: if not isinstance(additional_callbacks, list): additional_callbacks=[additional_callbacks] self.callbacks.callbacks.extend(additional_callbacks) + #create callbacks wrapper - print('starting training') - if load_in_mem: - print('make features') - X_train = self.train_data.getAllFeatures(nfiles=max_files) - X_test = self.val_data.getAllFeatures(nfiles=max_files) - print('make truth') - Y_train = self.train_data.getAllLabels(nfiles=max_files) - Y_test = self.val_data.getAllLabels(nfiles=max_files) - self.keras_model.fit(X_train, Y_train, batch_size=batchsize, epochs=nepochs, - callbacks=self.callbacks.callbacks, - validation_data=(X_test, Y_test), - max_queue_size=1, - use_multiprocessing=False, - workers=0, - **trainargs) - else: - #prepare generator - print("setting up generator... can take a while") - use_fake_truth=None - if fake_truth: - if isinstance(self.keras_model.output,dict): - use_fake_truth = [k for k in self.keras_model.output.keys()] - elif isinstance(self.keras_model.output,list): - use_fake_truth = len(self.keras_model.output) - - traingen = self.train_data.invokeGenerator(fake_truth = use_fake_truth) - valgen = self.val_data.invokeGenerator(fake_truth = use_fake_truth) + #prepare generator + + print("setting up generator... 
can take a while") + use_fake_truth=None + if fake_truth: + if isinstance(self.keras_model.output,dict): + use_fake_truth = [k for k in self.keras_model.output.keys()] + elif isinstance(self.keras_model.output,list): + use_fake_truth = len(self.keras_model.output) + + traingen = self.train_data.invokeGenerator(fake_truth = use_fake_truth) + valgen = self.val_data.invokeGenerator(fake_truth = use_fake_truth) + batch_time = 100 - while(self.trainedepoches < nepochs): - - #this can change from epoch to epoch - #calculate steps for this epoch - #feed info below - traingen.prepareNextEpoch() - valgen.prepareNextEpoch() - nbatches_train = traingen.getNBatches() #might have changed due to shuffeling - nbatches_val = valgen.getNBatches() + + + + while(self.trainedepoches < nepochs): - print('>>>> epoch', self.trainedepoches,"/",nepochs) - print('training batches: ',nbatches_train) - print('validation batches: ',nbatches_val) + + #this can change from epoch to epoch + #calculate steps for this epoch + #feed info below + traingen.prepareNextEpoch() + valgen.prepareNextEpoch() + nbatches_train = traingen.getNBatches() #might have changed due to shuffeling + nbatches_val = valgen.getNBatches() + + print('>>>> epoch', self.trainedepoches,"/",nepochs) + print('training batches: ',nbatches_train) + print('validation batches: ',nbatches_val) - data = self.to_ragged_tensor(traingen.feedNumpyData()) - #data = traingen.feedNumpyData() - - self.keras_model.fit(data, - steps_per_epoch=nbatches_train, - epochs=self.trainedepoches + 1, - initial_epoch=self.trainedepoches, - callbacks=self.callbacks.callbacks, - validation_data=valgen.feedNumpyData(), - validation_steps=nbatches_val, - max_queue_size=1, - use_multiprocessing=False, - workers=0, - **trainargs + + #this is in here because it needs steps count + callbacks = tf.keras.callbacks.CallbackList( + self.callbacks.callbacks, + add_history=True, + add_progbar=verbose != 0, + model=self.keras_model, #only run them on the main module! 
+ verbose=verbose, + epochs=1, + steps=nbatches_train, ) - self.trainedepoches += 1 - traingen.shuffleFileList() - # - - self.saveModel("KERAS_model.h5") + if self.trainedepoches == 0: + callbacks.on_train_begin() + + callbacks.on_epoch_begin(self.trainedepoches) + + ### + + nbatches_in = 0 + single_counter = 0 + time_sum = 0 + + while nbatches_in < nbatches_train: + + if batch_time > 0.1: + st = time.time() + + thisbatch = [] + while len(thisbatch) < self.ngpus and nbatches_in < nbatches_train: + #only 'feature' part matters for HGCAL + thisbatch.append(next(traingen.feedNumpyData())[0]) + nbatches_in += 1 + + if len(thisbatch) != self.ngpus: #last batch might not be enough + break + + callbacks.on_train_batch_begin(single_counter) + + self.trainstep_parallel(thisbatch) + + if batch_time > 0.1: + batch_time = time.time() - st + time_sum += batch_time + if not nbatches_in % int(20/(time_sum/nbatches_in)): # print less when it's faster + print('avg batch time', time_sum/nbatches_in, 's') + + logs = { m.name: m.result() for m in self.keras_model.metrics } #only for main model + + callbacks.on_train_batch_end(single_counter, logs) + + single_counter += 1 + + callbacks.on_epoch_end(self.trainedepoches, logs) #use same logs here + self.trainedepoches += 1 + traingen.shuffleFileList() + # + + self.saveModel("KERAS_model.h5") return self.keras_model, self.callbacks.history + + + + def change_learning_rate(self, new_lr): + import tensorflow.keras.backend as K + K.set_value(self.keras_model.optimizer.lr, new_lr) + + + +class HGCalTraining(training_base): + def __init__(self, *args, + **kwargs): + ''' + Adds file logging + ''' + #no reason for a lot of validation samples usually + super().__init__(*args, resumeSilently=True,splittrainandtest=0.95,**kwargs) + + from config_saver import copyModules + copyModules(self.outputDir)#save the modules with indexing for overwrites + + def compileModel(self, **kwargs): + super().compileModel(is_eager=True, + loss=None, + **kwargs) + + def trainModel(self, + nepochs, + batchsize, + backup_after_batches=500, + checkperiod=1, + **kwargs): + ''' + Just implements some defaults + ''' + return super().trainModel(nepochs=nepochs, + batchsize=batchsize, + run_eagerly=True, + verbose=2, + batchsize_use_sum_of_squares=False, + fake_truth=True, + backup_after_batches=backup_after_batches, + checkperiod=checkperiod, + **kwargs) + + + +class HGCalTraining_compat(training_base_djc): + def __init__(self, *args, + **kwargs): + ''' + Adds file logging + ''' + #no reason for a lot of validation samples usually + super().__init__(*args, resumeSilently=True,splittrainandtest=0.95,**kwargs) + + from config_saver import copyModules + copyModules(self.outputDir)#save the modules with indexing for overwrites + + def compileModel(self, **kwargs): + super().compileModel(is_eager=True, + loss=None, + **kwargs) + + def trainModel(self, + nepochs, + batchsize, + backup_after_batches=500, + checkperiod=1, + **kwargs): + ''' + Just implements some defaults + ''' + return super().trainModel(nepochs=nepochs, + batchsize=batchsize, + run_eagerly=True, + verbose=2, + batchsize_use_sum_of_squares=False, + fake_truth=True, + backup_after_batches=backup_after_batches, + checkperiod=checkperiod, + **kwargs) \ No newline at end of file
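
Note on the GravNetLayersRagged.py hunk: the sanity check now compares row_splits[-1] against tf.shape(x)[0] instead of x.shape[0], because the static batch dimension is typically unknown (None) when the layer is traced inside a tf.function, while tf.shape(x) is evaluated at run time. The snippet below is a minimal, self-contained illustration of that difference, not the layer code itself; check_row_splits and its input signature are made up for the example.

import tensorflow as tf

# Hypothetical helper, only to show static vs. dynamic first dimension.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 4], dtype=tf.float32),  # features, batch dim unknown
    tf.TensorSpec(shape=[None], dtype=tf.int32),       # row splits
])
def check_row_splits(x, row_splits):
    print('static shape at trace time:', x.shape)      # (None, 4): x.shape[0] is None
    tf.assert_equal(row_splits[-1], tf.shape(x)[0])    # dynamic size, known at run time
    return x

check_row_splits(tf.zeros([7, 4]), tf.constant([0, 3, 7], dtype=tf.int32))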
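
Note on the custom training loop in training_base_hgcal.py: compute_gradients() builds its loss purely from tf.add_n(model.losses) and never sees truth arrays, so the objective has to be registered inside the model via add_loss() by a loss layer, which is why only the 'feature' part of each batch is used and fake_truth exists at all. The sketch below shows that pattern with a toy stand-in; ToyLossLayer and the random features are assumptions for illustration, not one of the real LL* layers.

import tensorflow as tf

class ToyLossLayer(tf.keras.layers.Layer):
    """Illustrative stand-in for a layer that registers its loss on the model."""
    def call(self, inputs):
        # attach a differentiable penalty as a model loss; no labels required
        self.add_loss(tf.reduce_mean(tf.square(inputs)))
        return inputs

inp = tf.keras.Input(shape=(3,))
out = ToyLossLayer()(tf.keras.layers.Dense(4)(inp))
model = tf.keras.Model(inp, out)

features = tf.random.normal([8, 3])        # only the 'feature' part is needed
with tf.GradientTape() as tape:
    _ = model(features, training=True)
    loss = tf.add_n(model.losses)          # same reduction as compute_gradients()
grads = tape.gradient(loss, model.trainable_variables)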
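
Note on the multi-GPU scheme (setModel / compute_gradients / average_gradients / trainstep_parallel / syncModelWeights): one model replica is built per GPU, each replica computes gradients on its own batch in a thread pool, the per-variable gradients are averaged, the average is applied to replica 0, and the updated weights are copied back to the other replicas. Below is a condensed, CPU-only sketch of that update under toy assumptions: make_model, the two replicas, the explicit MSE loss and the random batches are placeholders, while the real class pins each replica to /GPU:i, uses tf.add_n(model.losses) and pulls batches from the DeepJetCore generator.

import concurrent.futures
import tensorflow as tf

def make_model():
    # placeholder model; the real code calls the user-supplied model function
    m = tf.keras.Sequential([tf.keras.layers.Dense(1)])
    m.build((None, 3))
    return m

replicas = [make_model() for _ in range(2)]            # one per (logical) GPU
replicas[1].set_weights(replicas[0].get_weights())     # start in sync
opt = tf.keras.optimizers.Adam(1e-3)

def replica_gradients(model, x, y):
    # toy explicit loss; the real compute_gradients() reduces tf.add_n(model.losses)
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    return tape.gradient(loss, model.trainable_variables)

def average_gradients(all_grads):
    # mean over replicas, per variable, mirroring training_base.average_gradients
    return [tf.reduce_mean([g for g in grads if g is not None], axis=0)
            for grads in zip(*all_grads)]

def train_step(batches):                               # one (x, y) batch per replica
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(replicas)) as ex:
        futures = [ex.submit(replica_gradients, m, x, y)
                   for m, (x, y) in zip(replicas, batches)]
        grads = [f.result() for f in futures]
    opt.apply_gradients(zip(average_gradients(grads),
                            replicas[0].trainable_variables))
    for m in replicas[1:]:                             # syncModelWeights equivalent
        m.set_weights(replicas[0].get_weights())

train_step([(tf.random.normal([8, 3]), tf.random.normal([8, 1])),
            (tf.random.normal([8, 3]), tf.random.normal([8, 1]))])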