Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion lightning_action/data/video_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,4 +810,3 @@ def input_size(self) -> int:
def num_classes(self) -> int:
"""Get the number of action classes."""
return self.dataset.num_classes
return self.dataset.num_classes
5 changes: 3 additions & 2 deletions lightning_action/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Models for action segmentation."""

from .heads import RNN, DilatedTCN, TemporalMLP
from .segmenter import BaseModel, Segmenter
from lightning_action.models.heads import RNN, DilatedTCN, TemporalMLP
from lightning_action.models.segmenter import BaseModel, Segmenter

__all__ = [
'BaseModel',
Expand All @@ -10,3 +10,4 @@
'RNN',
'TemporalMLP',
]

11 changes: 11 additions & 0 deletions lightning_action/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""Backbone architectures for action segmentation models."""

from lightning_action.models.backbones.resnet import ResNetBackbone
from lightning_action.models.backbones.resnet_beast import ResNetBeastBackbone
from lightning_action.models.backbones.vitmae import ViTMAEBackbone

__all__ = [
'ResNetBackbone',
'ResNetBeastBackbone',
'ViTMAEBackbone',
]
8 changes: 5 additions & 3 deletions lightning_action/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Head architectures for action segmentation models."""

from .rnn import RNN
from .tcn import DilatedTCN
from .temporalmlp import TemporalMLP
from lightning_action.models.heads.rnn import RNN
from lightning_action.models.heads.tcn import DilatedTCN
from lightning_action.models.heads.temporalmlp import TemporalMLP
from lightning_action.models.necks.mha_pooling import MultiheadAttentionPooling

__all__ = [
'DilatedTCN',
'RNN',
'TemporalMLP',
'MultiheadAttentionPooling'
]
8 changes: 5 additions & 3 deletions lightning_action/models/video_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from typeguard import typechecked

from lightning_action.models.segmenter import BaseModel
from lightning_action.models.backbones.vitmae import ViTMAEBackbone
from lightning_action.models.backbones.resnet import ResNetBackbone
from lightning_action.models.backbones.resnet_beast import ResNetBeastBackbone
from lightning_action.models.backbones import (
ResNetBackbone,
ResNetBeastBackbone,
ViTMAEBackbone,
)
from lightning_action.models.necks.mha_pooling import MultiheadAttentionPooling
from lightning_action.models.heads import DilatedTCN, TemporalMLP, RNN

Expand Down
Loading