108 changes: 98 additions & 10 deletions src/dryml/models/torch/generic.py
@@ -10,7 +10,6 @@
import zipfile
import torch
import tqdm
from dryml.utils import validate_class


class Model(TorchModel):
@@ -56,13 +55,13 @@ class ModelWrapper(Model):
@Meta.collect_args
@Meta.collect_kwargs
def __init__(self, cls, *args, **kwargs):
self.cls = validate_class(cls)
self.cls = cls
self.args = args
self.kwargs = kwargs
self.mdl = None

def compute_prepare_imp(self):
self.mdl = self.cls(*self.args, **self.kwargs)
        self.mdl = self.cls(*self.args, **self.kwargs)
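
For reference, a minimal usage sketch of the updated ModelWrapper; the nn.Linear choice and the direct compute_prepare_imp() call are illustrative assumptions, not taken from this diff:

import torch.nn as nn

# Hypothetical usage: wrap a plain torch.nn module class; the underlying
# module is only instantiated when compute_prepare_imp() runs.
wrapper = ModelWrapper(nn.Linear, 784, 10, bias=True)
wrapper.compute_prepare_imp()        # builds nn.Linear(784, 10, bias=True)
assert isinstance(wrapper.mdl, nn.Linear)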


class Sequential(Model):
@@ -122,6 +121,45 @@ def compute_cleanup_imp(self):
self.opt = None


class TorchScheduler(Object):
@Meta.collect_args
@Meta.collect_kwargs
def __init__(self, cls, optimizer: TorchOptimizer, *args, **kwargs):
        if not isinstance(cls, type):
raise TypeError("first argument must be a class!")
self.cls = cls
self.optimizer = optimizer
self.args = args
self.kwargs = kwargs
self.sched = None

def compute_prepare_imp(self):
self.sched = self.cls(
self.optimizer.opt,
*self.args,
            **self.kwargs)

def load_compute_imp(self, file: zipfile.ZipFile) -> bool:
try:
with file.open('state.pth', 'r') as f:
self.sched.load_state_dict(torch.load(f))
return True
except Exception:
return False

def save_compute_imp(self, file: zipfile.ZipFile) -> bool:
try:
with file.open('state.pth', 'w') as f:
torch.save(self.sched.state_dict(), f)
return True
except Exception:
return False

def compute_cleanup_imp(self):
del self.sched
self.sched = None
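
For orientation, a plain-torch sketch of what TorchScheduler.compute_prepare_imp ends up constructing, i.e. cls(optimizer.opt, *args, **kwargs); the Adam and ReduceLROnPlateau choices here are illustrative assumptions, not part of this diff:

import torch

# Stand-in for optimizer.opt as built by a TorchOptimizer wrapper.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-3)

# Equivalent of self.cls(self.optimizer.opt, *self.args, **self.kwargs)
# with cls = ReduceLROnPlateau and kwargs = {'mode': 'min', ...}.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='min', factor=0.5, patience=2)
sched.step(0.42)   # plateau schedulers step on a monitored metric value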


class Trainable(TorchTrainable):
def __init__(
self,
@@ -182,16 +220,13 @@ def __call__(
start_epoch = 0
if train_spec is not None:
start_epoch = train_spec.level_step()

# Type checking training data, and converting if necessary
batch_size = 32
data = data.torch().batch(batch_size=batch_size)
total_batches = data.count()

# Move variables to same device as model
devs = context().get_torch_devices()
data = data.map_el(lambda el: el.to(devs[0]))

# Check data is supervised.
if not data.supervised:
raise RuntimeError(
@@ -200,23 +235,76 @@
optimizer = self.optimizer.opt
loss = self.loss.obj
model = trainable.model

for i in range(start_epoch, self.epochs):

running_loss = 0.
num_batches = 0
t_data = tqdm.tqdm(data, total=total_batches)
for X, Y in t_data:
optimizer.zero_grad()

outputs = model(X)
loss_val = loss(outputs, Y)
loss_val.backward()
optimizer.step()

running_loss += loss_val.item()
num_batches += 1
av_loss = running_loss/(num_batches*batch_size)
t_data.set_postfix(loss=av_loss)
print(f"Epoch {i+1} - Average Loss: {av_loss}")


class LRBasicTraining(TrainFunction):
def __init__(
self,
optimizer: Wrapper = None,
loss: Wrapper = None,
scheduler: Wrapper = None,
epochs=1):
self.optimizer = optimizer
self.loss = loss
self.epochs = epochs
self.scheduler = scheduler
self.training_loss = []

def __call__(
self, trainable: Model, data: Dataset, train_spec=None,
train_callbacks=[]):

# Pop the epoch to resume from
start_epoch = 0
if train_spec is not None:
start_epoch = train_spec.level_step()
# Type checking training data, and converting if necessary
batch_size = 32
data = data.torch().batch(batch_size=batch_size)
total_batches = data.count()
# Move variables to same device as model
devs = context().get_torch_devices()
data = data.map_el(lambda el: el.to(devs[0]))
# Check data is supervised.
if not data.supervised:
raise RuntimeError(
f"{__class__} requires supervised data")
optimizer = self.optimizer.opt
loss = self.loss.obj
scheduler = self.scheduler.sched
model = trainable.model
for i in range(start_epoch, self.epochs):
running_loss = 0.
num_batches = 0
t_data = tqdm.tqdm(data, total=total_batches)
for X, Y in t_data:
optimizer.zero_grad()
outputs = model(X)
loss_val = loss(outputs, Y)
loss_val.backward()
optimizer.step()
running_loss += loss_val.item()
num_batches += 1
av_loss = running_loss/(num_batches*batch_size)
t_data.set_postfix(loss=av_loss)
scheduler.step(av_loss)
self.training_loss.append(av_loss)
print(f"Epoch {i+1} - Average Loss: {av_loss}")