diff --git a/src/dryml/models/torch/generic.py b/src/dryml/models/torch/generic.py
index 76b31be..1338866 100644
--- a/src/dryml/models/torch/generic.py
+++ b/src/dryml/models/torch/generic.py
@@ -10,7 +10,6 @@
 import zipfile
 import torch
 import tqdm
-from dryml.utils import validate_class
 
 
 class Model(TorchModel):
@@ -56,13 +55,13 @@ class ModelWrapper(Model):
     @Meta.collect_args
     @Meta.collect_kwargs
     def __init__(self, cls, *args, **kwargs):
-        self.cls = validate_class(cls)
+        self.cls = cls
         self.args = args
         self.kwargs = kwargs
         self.mdl = None
 
     def compute_prepare_imp(self):
         self.mdl = self.cls(*self.args, **self.kwargs)
 
 
 class Sequential(Model):
@@ -122,6 +121,45 @@ def compute_cleanup_imp(self):
         self.opt = None
 
 
+class TorchScheduler(Object):
+    @Meta.collect_args
+    @Meta.collect_kwargs
+    def __init__(self, cls, optimizer: TorchOptimizer, *args, **kwargs):
+        if type(cls) is not type:
+            raise TypeError("first argument must be a class!")
+        self.cls = cls
+        self.optimizer = optimizer
+        self.args = args
+        self.kwargs = kwargs
+        self.sched = None
+
+    def compute_prepare_imp(self):
+        self.sched = self.cls(
+            self.optimizer.opt,
+            *self.args,
+            **self.kwargs)
+
+    def load_compute_imp(self, file: zipfile.ZipFile) -> bool:
+        try:
+            with file.open('state.pth', 'r') as f:
+                self.sched.load_state_dict(torch.load(f))
+            return True
+        except Exception:
+            return False
+
+    def save_compute_imp(self, file: zipfile.ZipFile) -> bool:
+        try:
+            with file.open('state.pth', 'w') as f:
+                torch.save(self.sched.state_dict(), f)
+            return True
+        except Exception:
+            return False
+
+    def compute_cleanup_imp(self):
+        del self.sched
+        self.sched = None
+
+
 class Trainable(TorchTrainable):
     def __init__(
             self,
@@ -182,16 +220,13 @@ def __call__(
         start_epoch = 0
         if train_spec is not None:
            start_epoch = train_spec.level_step()
-
        # Type checking training data, and converting if necessary
        batch_size = 32
        data = data.torch().batch(batch_size=batch_size)
        total_batches = data.count()
-
        # Move variables to same device as model
        devs = context().get_torch_devices()
        data = data.map_el(lambda el: el.to(devs[0]))
-
        # Check data is supervised.
        if not data.supervised:
            raise RuntimeError(
@@ -200,23 +235,76 @@
                f"{__class__} requires supervised data")
        optimizer = self.optimizer.opt
        loss = self.loss.obj
        model = trainable.model
-
        for i in range(start_epoch, self.epochs):
-
            running_loss = 0.
            num_batches = 0
            t_data = tqdm.tqdm(data, total=total_batches)
            for X, Y in t_data:
                optimizer.zero_grad()
-
                outputs = model(X)
                loss_val = loss(outputs, Y)
                loss_val.backward()
                optimizer.step()
-
                running_loss += loss_val.item()
                num_batches += 1
                av_loss = running_loss/(num_batches*batch_size)
                t_data.set_postfix(loss=av_loss)
+            print(f"Epoch {i+1} - Average Loss: {av_loss}")
+
+
+class LRBasicTraining(TrainFunction):
+    def __init__(
+            self,
+            optimizer: Wrapper = None,
+            loss: Wrapper = None,
+            scheduler: Wrapper = None,
+            epochs=1):
+        self.optimizer = optimizer
+        self.loss = loss
+        self.epochs = epochs
+        self.scheduler = scheduler
+        self.training_loss = []
+
+    def __call__(
+            self, trainable: Model, data: Dataset, train_spec=None,
+            train_callbacks=[]):
+
+        # Pop the epoch to resume from
+        start_epoch = 0
+        if train_spec is not None:
+            start_epoch = train_spec.level_step()
+        # Type checking training data, and converting if necessary
+        batch_size = 32
+        data = data.torch().batch(batch_size=batch_size)
+        total_batches = data.count()
+        # Move variables to same device as model
+        devs = context().get_torch_devices()
+        data = data.map_el(lambda el: el.to(devs[0]))
+        # Check data is supervised.
+        if not data.supervised:
+            raise RuntimeError(
+                f"{__class__} requires supervised data")
+        optimizer = self.optimizer.opt
+        loss = self.loss.obj
+        scheduler = self.scheduler.sched
+        model = trainable.model
+        for i in range(start_epoch, self.epochs):
+            running_loss = 0.
+            num_batches = 0
+            t_data = tqdm.tqdm(data, total=total_batches)
+            for X, Y in t_data:
+                optimizer.zero_grad()
+                outputs = model(X)
+                loss_val = loss(outputs, Y)
+                loss_val.backward()
+                optimizer.step()
+                running_loss += loss_val.item()
+                num_batches += 1
+                av_loss = running_loss/(num_batches*batch_size)
+                t_data.set_postfix(loss=av_loss)
+            scheduler.step(av_loss)
+            self.training_loss.append(av_loss)
+            print(f"Epoch {i+1} - Average Loss: {av_loss}")
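
Usage note (not part of the patch): TorchScheduler wraps a torch LR scheduler class around an existing TorchOptimizer, and LRBasicTraining calls scheduler.step(av_loss) with each epoch's average loss, so metric-driven schedulers such as torch.optim.lr_scheduler.ReduceLROnPlateau are the natural fit. The sketch below shows one way the pieces might be wired together; opt_wrapper, loss_wrapper, trainable and train_data stand in for pre-built dryml objects whose construction is not shown in this diff, and the scheduler keyword arguments are illustrative only.

    import torch
    from dryml.models.torch.generic import TorchScheduler, LRBasicTraining

    # opt_wrapper: an existing TorchOptimizer instance (exposes .opt once prepared).
    # loss_wrapper: an existing wrapper whose .obj is a torch loss module.
    # Both are assumed to be constructed elsewhere.
    sched_wrapper = TorchScheduler(
        torch.optim.lr_scheduler.ReduceLROnPlateau,  # steps on a metric, matching scheduler.step(av_loss)
        opt_wrapper,
        factor=0.5, patience=2)                      # illustrative hyperparameters

    train_fn = LRBasicTraining(
        optimizer=opt_wrapper,
        loss=loss_wrapper,
        scheduler=sched_wrapper,
        epochs=10)
    # train_fn(trainable, train_data) runs the loop above, stepping the
    # scheduler once per epoch and recording av_loss in train_fn.training_loss.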