diff --git a/lightning_action/data/video_datamodule.py b/lightning_action/data/video_datamodule.py index 9cf21c6..fb9df41 100644 --- a/lightning_action/data/video_datamodule.py +++ b/lightning_action/data/video_datamodule.py @@ -810,4 +810,3 @@ def input_size(self) -> int: def num_classes(self) -> int: """Get the number of action classes.""" return self.dataset.num_classes - return self.dataset.num_classes diff --git a/lightning_action/models/__init__.py b/lightning_action/models/__init__.py index f64977d..c5ccf71 100644 --- a/lightning_action/models/__init__.py +++ b/lightning_action/models/__init__.py @@ -1,7 +1,7 @@ """Models for action segmentation.""" -from .heads import RNN, DilatedTCN, TemporalMLP -from .segmenter import BaseModel, Segmenter +from lightning_action.models.heads import RNN, DilatedTCN, TemporalMLP +from lightning_action.models.segmenter import BaseModel, Segmenter __all__ = [ 'BaseModel', @@ -10,3 +10,4 @@ 'RNN', 'TemporalMLP', ] + diff --git a/lightning_action/models/backbones/__init__.py b/lightning_action/models/backbones/__init__.py new file mode 100644 index 0000000..6b80d6e --- /dev/null +++ b/lightning_action/models/backbones/__init__.py @@ -0,0 +1,11 @@ +"""Backbone architectures for action segmentation models.""" + +from lightning_action.models.backbones.resnet import ResNetBackbone +from lightning_action.models.backbones.resnet_beast import ResNetBeastBackbone +from lightning_action.models.backbones.vitmae import ViTMAEBackbone + +__all__ = [ + 'ResNetBackbone', + 'ResNetBeastBackbone', + 'ViTMAEBackbone', +] diff --git a/lightning_action/models/heads/__init__.py b/lightning_action/models/heads/__init__.py index e0da814..f478baa 100644 --- a/lightning_action/models/heads/__init__.py +++ b/lightning_action/models/heads/__init__.py @@ -1,11 +1,13 @@ """Head architectures for action segmentation models.""" -from .rnn import RNN -from .tcn import DilatedTCN -from .temporalmlp import TemporalMLP +from lightning_action.models.heads.rnn import RNN +from lightning_action.models.heads.tcn import DilatedTCN +from lightning_action.models.heads.temporalmlp import TemporalMLP +from lightning_action.models.necks.mha_pooling import MultiheadAttentionPooling __all__ = [ 'DilatedTCN', 'RNN', 'TemporalMLP', + 'MultiheadAttentionPooling' ] diff --git a/lightning_action/models/video_segmenter.py b/lightning_action/models/video_segmenter.py index 77b8db2..2b4ca11 100644 --- a/lightning_action/models/video_segmenter.py +++ b/lightning_action/models/video_segmenter.py @@ -20,9 +20,11 @@ from typeguard import typechecked from lightning_action.models.segmenter import BaseModel -from lightning_action.models.backbones.vitmae import ViTMAEBackbone -from lightning_action.models.backbones.resnet import ResNetBackbone -from lightning_action.models.backbones.resnet_beast import ResNetBeastBackbone +from lightning_action.models.backbones import ( + ResNetBackbone, + ResNetBeastBackbone, + ViTMAEBackbone, +) from lightning_action.models.necks.mha_pooling import MultiheadAttentionPooling from lightning_action.models.heads import DilatedTCN, TemporalMLP, RNN