19 changes: 13 additions & 6 deletions deeplay/applications/application.py
@@ -124,9 +124,9 @@ def fit(
         )

         history = LogHistory()
-        progressbar = RichProgressBar()
+        aux_callbacks = [history]

-        callbacks = callbacks + [history, progressbar]
+        callbacks = callbacks + aux_callbacks
         trainer = dl.Trainer(max_epochs=max_epochs, callbacks=callbacks, **kwargs)

         train_dataloader = torch.utils.data.DataLoader(
@@ -243,12 +243,15 @@ def configure_optimizers(self):
             ) from e

     def training_step(self, batch, batch_idx):
+
         x, y = self.train_preprocess(batch)
         y_hat = self(x)
         loss = self.compute_loss(y_hat, y)
         if not isinstance(loss, dict):
             loss = {"loss": loss}

+        assert "loss" in loss, "the output of compute_loss should contain a 'loss' key"
+
         for name, v in loss.items():
             self.log(
                 f"train_{name}",
@@ -263,7 +266,7 @@ def training_step(self, batch, batch_idx):
             "train", y_hat, y, on_step=True, on_epoch=True, prog_bar=True, logger=True
         )

-        return sum(loss.values())
+        return loss["loss"]

     def validation_step(self, batch, batch_idx):
         x, y = self.val_preprocess(batch)
@@ -290,7 +293,7 @@ def validation_step(self, batch, batch_idx):
                 prog_bar=True,
                 logger=True,
             )
-        return sum(loss.values())
+        return loss["loss"] if "loss" in loss else 0

     def test_step(self, batch, batch_idx):
         x, y = self.test_preprocess(batch)
@@ -318,7 +321,7 @@ def test_step(self, batch, batch_idx):
                 logger=True,
             )

-        return sum(loss.values())
+        return loss["loss"] if "loss" in loss else 0

     def predict_step(self, batch, batch_idx, dataloader_idx=None):
         if isinstance(batch, (list, tuple)):
@@ -356,12 +359,16 @@ def trainer(self, trainer):
             if module is self:
                 continue
             try:
-                if hasattr(module, "trainer") and module.trainer is not trainer:
+                if isinstance(module, L.LightningModule) or hasattr(module, "trainer"):
+
                     module.trainer = trainer

             except RuntimeError:
                 # hasattr can raise RuntimeError if the module is not attached to a trainer
+                if isinstance(module, L.LightningModule):
+                    print("Attaching trainer to", module)
                 module.trainer = trainer
+                print("Attached trainer to", module)

     @staticmethod
     def clone_metrics(metrics: T) -> T:
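Note on the new loss contract in this file: `compute_loss` may return either a bare tensor or a dict, and a dict must contain a `"loss"` key. Only that entry is returned for backpropagation (previously all entries were summed), while every entry is still logged with a stage prefix. A minimal sketch of an application relying on this contract (the subclass and loss-term names are illustrative, not part of this PR):

```python
import deeplay as dl
import torch.nn.functional as F

class MyApplication(dl.Application):  # hypothetical subclass
    def compute_loss(self, y_hat, y):
        recon = F.mse_loss(y_hat, y)
        reg = 1e-2 * y_hat.abs().mean()  # placeholder regularizer
        return {
            "loss": recon + reg,  # the only value backpropagated
            "recon": recon,       # logged as train_recon / val_recon / ...
            "reg": reg,           # logged as train_reg / val_reg / ...
        }
```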
15 changes: 12 additions & 3 deletions deeplay/applications/detection/lodestar/transforms.py
@@ -90,6 +90,15 @@ def _backward(x, angle, indices):
         mat2d[:, indices[1], indices[0]] = -torch.sin(-angle)
         mat2d[:, indices[0], indices[1]] = torch.sin(-angle)
         mat2d[:, indices[0], indices[0]] = torch.cos(-angle)
-        out = torch.matmul(x.unsqueeze(1), mat2d).squeeze(1)
-
-        return out
+
+        if len(x.size()) == 2:
+            # (B, C) -> (B, 1, C)
+            x = x.unsqueeze(1)
+            return torch.matmul(x, mat2d).squeeze(1)
+
+        x_expanded = x.view(x.size(0), 1, x.size(1), -1)
+        y = torch.einsum("bijm,bjk->bikm", x_expanded, mat2d)
+
+        return y.view(x.size())
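The einsum path generalizes the previous matmul to inputs with extra trailing dimensions: every axis after the coordinate axis is flattened into a single axis `m`, rotated per batch element, and reshaped back. A standalone shape check of this contraction (the rotation batch is hand-built here purely for illustration):

```python
import torch

B, M = 4, 5
angle = torch.rand(B)

# (B, 2, 2) batch of rotation matrices, mirroring the mat2d built in _backward.
cos, sin = torch.cos(-angle), torch.sin(-angle)
mat2d = torch.stack(
    [torch.stack([cos, sin], dim=-1), torch.stack([-sin, cos], dim=-1)], dim=-2
)

x = torch.rand(B, 2, M)                            # coordinates plus an extra axis M
x_expanded = x.view(x.size(0), 1, x.size(1), -1)   # (B, 1, 2, M)
y = torch.einsum("bijm,bjk->bikm", x_expanded, mat2d)
assert y.view(x.size()).shape == (B, 2, M)         # same shape, rotated coordinates
```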


41 changes: 41 additions & 0 deletions deeplay/module.py
@@ -723,6 +723,39 @@ def replace(self, target: str, replacement: nn.Module):

         self._modules[target] = replacement

+    @after_init
+    def schedule(self, **schedulers):
+        for attr, scheduler in schedulers.items():
+            setattr(self, attr, scheduler)
+
+    @after_init
+    def schedule_linear(
+        self,
+        attr: str,
+        start_value: float,
+        end_value: float,
+        n_steps: int,
+        on_epoch: bool = False,
+    ):
+        from deeplay.schedulers import LinearScheduler
+
+        setattr(self, attr, LinearScheduler(start_value, end_value, n_steps, on_epoch))
+
+    @after_init
+    def schedule_loglinear(
+        self,
+        attr: str,
+        start_value: float,
+        end_value: float,
+        n_steps: int,
+        on_epoch: bool = False,
+    ):
+        from deeplay.schedulers import LogLinearScheduler
+
+        setattr(
+            self, attr, LogLinearScheduler(start_value, end_value, n_steps, on_epoch)
+        )
+
     @stateful
     def configure(self, *args: Any, **kwargs: Any):
         """
@@ -1335,6 +1368,14 @@ def __setattr__(self, name, value):
         # # ensure that logs are stored in the correct place
         # value.set_root_module(self.root_module)

+    def __getattr__(self, name):
+        from deeplay.schedulers import BaseScheduler
+
+        x = super().__getattr__(name)
+        if self._has_built and isinstance(x, BaseScheduler):
+            return x.__get__(self, type(self))
+        return x
+
     @stateful
     def _set_submodule(self, name, module, tags):

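Usage sketch for the scheduling helpers added above: any scalar attribute of a module can be pointed at a scheduler, and after build the new `__getattr__` resolves it to the scheduler's current value on every access. The module and attribute below are hypothetical:

```python
from deeplay.module import DeeplayModule

class TemperatureModule(DeeplayModule):  # hypothetical module
    def __init__(self):
        super().__init__()
        self.temperature = 1.0  # scalar hyperparameter to anneal

module = TemperatureModule()
# Anneal temperature 1.0 -> 0.1 over the first 1000 global steps.
module.schedule_linear("temperature", start_value=1.0, end_value=0.1, n_steps=1000)
# After build, reading module.temperature returns the value for the trainer's
# current global step (or current epoch, if on_epoch=True).
```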
5 changes: 5 additions & 0 deletions deeplay/schedulers/__init__.py
@@ -0,0 +1,5 @@
+from .scheduler import BaseScheduler
+from .linear import LinearScheduler
+from .constant import ConstantScheduler
+from .loglinear import LogLinearScheduler
+from .sequence import SchedulerSequence
18 changes: 18 additions & 0 deletions deeplay/schedulers/constant.py
@@ -0,0 +1,18 @@
+from . import BaseScheduler
+
+
+class ConstantScheduler(BaseScheduler):
+    """Scheduler that returns a constant value."""
+
+    def __init__(self, value, on_epoch=False):
+        super().__init__(on_epoch)
+        self.value = value
+
+    def __call__(self, step):
+        return self.value
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.value})"
+
+    def __str__(self):
+        return repr(self)
42 changes: 42 additions & 0 deletions deeplay/schedulers/linear.py
@@ -0,0 +1,42 @@
+from .scheduler import BaseScheduler
+
+
+class LinearScheduler(BaseScheduler):
+    """Scheduler that returns a linearly changing value from start_value to end_value.
+
+    For steps beyond n_steps, returns end_value.
+    For steps before 0, returns start_value.
+
+    Parameters
+    ----------
+    start_value : float
+        Initial value of the scheduler.
+    end_value : float
+        Final value of the scheduler.
+    n_steps : int
+        Number of steps to reach end_value.
+    on_epoch : bool
+        If True, the step is taken from the epoch counter of the trainer.
+        Otherwise, the step is taken from the global step counter of the trainer.
+    """
+
+    def __init__(self, start_value, end_value, n_steps, on_epoch=False):
+        super().__init__(on_epoch)
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+
+    def __call__(self, step):
+        if step < 0:
+            return self.start_value
+        if step >= self.n_steps:
+            return self.end_value
+        return (
+            self.start_value + (self.end_value - self.start_value) * step / self.n_steps
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.start_value}, {self.end_value}, {self.n_steps})"
+
+    def __str__(self):
+        return repr(self)
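A quick sanity check of the boundary behavior (a sketch, assuming the scheduler can be instantiated standalone): over `n_steps=4`, the value moves in increments of `(end_value - start_value) / n_steps` and first reaches `end_value` exactly at `step == n_steps`:

```python
from deeplay.schedulers import LinearScheduler

s = LinearScheduler(start_value=0.0, end_value=1.0, n_steps=4)
print([s(step) for step in range(-1, 6)])
# [0.0, 0.0, 0.25, 0.5, 0.75, 1.0, 1.0]
```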
33 changes: 33 additions & 0 deletions deeplay/schedulers/loglinear.py
@@ -0,0 +1,33 @@
+from . import BaseScheduler
+import numpy as np
+
+
+class LogLinearScheduler(BaseScheduler):
+    """Scheduler that returns a log-linearly changing value from start_value to end_value.
+
+    For steps beyond n_steps, returns end_value."""
+
+    def __init__(self, start_value, end_value, n_steps, on_epoch=False):
+        super().__init__(on_epoch)
+        assert np.sign(start_value) == np.sign(
+            end_value
+        ), "Start and end values must have the same sign"
+        assert start_value != 0, "Start value must be non-zero"
+        assert end_value != 0, "End value must be non-zero"
+        assert n_steps > 0, "Number of steps must be greater than 0"
+        self.start_value = start_value
+        self.end_value = end_value
+        self.n_steps = n_steps
+
+    def __call__(self, step):
+        if step >= self.n_steps:
+            return self.end_value
+        return self.start_value * (self.end_value / self.start_value) ** (
+            step / self.n_steps
+        )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.start_value}, {self.end_value}, {self.n_steps})"
+
+    def __str__(self):
+        return repr(self)
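Log-linear means the value is interpolated linearly in log space, i.e., multiplied by a constant factor per step, the usual choice for annealing quantities that span orders of magnitude. A sketch under the same standalone-construction assumption:

```python
from deeplay.schedulers import LogLinearScheduler

s = LogLinearScheduler(start_value=1.0, end_value=1e-4, n_steps=4)
print([s(step) for step in range(5)])
# [1.0, 0.1, 0.01, 0.001, 0.0001]  (up to floating-point rounding)
```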
47 changes: 47 additions & 0 deletions deeplay/schedulers/scheduler.py
@@ -0,0 +1,47 @@
+import lightning as L
+
+from deeplay.module import DeeplayModule
+from deeplay.trainer import Trainer
+
+
+class BaseScheduler(DeeplayModule, L.LightningModule):
+    """Base class for schedulers."""
+
+    step: int
+
+    def __init__(self, on_epoch=False):
+        super().__init__()
+        self.on_epoch = on_epoch
+        self._step = 0
+        self._x = None
+
+    def set_step(self, step):
+        self._step = step
+        self._x = self(step)
+
+    def update(self):
+        current_step = self._step
+
+        if self._trainer:
+            updated_step = (
+                self.trainer.current_epoch
+                if self.on_epoch
+                else self.trainer.global_step
+            )
+        else:
+            updated_step = self._step
+
+        if updated_step != current_step or self._x is None:
+            self.set_step(updated_step)
+
+    def __get__(self, obj, objtype=None):
+        if obj is None:
+            return self
+        self.update()
+        return self._x
+
+    def __set__(self, obj, value):
+        self._x = value
+
+    def __call__(self, step):
+        raise NotImplementedError
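`__get__`/`__set__` follow the descriptor protocol, but Python only invokes descriptors found on the *class*, never on instance attributes; since schedulers are stored on module instances (as submodules), the `__getattr__` override in `deeplay/module.py` has to call `__get__` by hand. A minimal illustration with plain classes (nothing Deeplay-specific):

```python
class FortyTwo:
    def __get__(self, obj, objtype=None):
        return 42

class WithClassAttr:
    value = FortyTwo()           # descriptor on the class

class WithInstanceAttr:
    def __init__(self):
        self.value = FortyTwo()  # descriptor on the instance

print(WithClassAttr().value)              # 42: descriptor protocol fires
print(WithInstanceAttr().value)           # <FortyTwo object ...>: it does not
obj = WithInstanceAttr()
print(obj.value.__get__(obj, type(obj)))  # 42: what the __getattr__ hook does
```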
36 changes: 36 additions & 0 deletions deeplay/schedulers/sequence.py
@@ -0,0 +1,36 @@
+from .scheduler import BaseScheduler
+
+
+class SchedulerSequence(BaseScheduler):
+    """Scheduler that returns the value from one of the schedulers in the chain.
+
+    The scheduler is chosen based on the current step.
+    """
+
+    def __init__(self, on_epoch=False):
+        super().__init__(on_epoch)
+        self.schedulers = []
+
+    def add(self, scheduler, n_steps=None):
+        if n_steps is None:
+            assert hasattr(
+                scheduler, "n_steps"
+            ), "For a scheduler without n_steps, n_steps must be specified"
+            n_steps = scheduler.n_steps
+
+        self.schedulers.append((n_steps, scheduler))
+
+    def __call__(self, step):
+        for n_steps, scheduler in self.schedulers:
+            if step < n_steps:
+                return scheduler(step)
+            step -= n_steps
+
+        final_step, final_scheduler = self.schedulers[-1]
+        return final_scheduler(final_step + step)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self.schedulers})"
+
+    def __str__(self):
+        return repr(self)
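Schedulers can be chained: each entry consumes its `n_steps` of the step budget, and once the budget is exhausted the last scheduler keeps receiving steps past its own range, so it stays clamped at its end value. A usage sketch, again assuming standalone construction:

```python
from deeplay.schedulers import ConstantScheduler, LinearScheduler, SchedulerSequence

seq = SchedulerSequence()
seq.add(ConstantScheduler(1.0), n_steps=10)     # warmup: hold at 1.0 for 10 steps
seq.add(LinearScheduler(1.0, 0.0, n_steps=90))  # then decay to 0.0 over 90 steps

print(seq(5), seq(10), seq(55), seq(150))
# 1.0 1.0 0.5 0.0
```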