
Commit 70495c7

Improve support for multi-task training (#36)
1 parent 967e553 commit 70495c7

File tree

10 files changed: 505 additions, 113 deletions


mmlearn/datasets/nyuv2.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -7,7 +7,7 @@
 import torch
 from hydra_zen import MISSING, store
 from lightning_utilities.core.imports import RequirementCache
-from PIL.Image import Image as PILImage
+from PIL import Image as PILImage
 from torch.utils.data import Dataset
 from torchvision.transforms.v2.functional import to_pil_image
 
@@ -120,7 +120,7 @@ def __init__(
 
         root_dir = os.path.join(root_dir, split)
         depth_files = [os.path.join(root_dir, "depth", f"{f}.png") for f in file_ids]
-        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.jpg") for f in file_ids]
+        rgb_files = [os.path.join(root_dir, "rgb", f"{f}.png") for f in file_ids]
 
         label_files = [
             os.path.join(root_dir, "scene_class", f"{f}.txt") for f in file_ids
```

mmlearn/datasets/sunrgbd.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -7,7 +7,7 @@
 import torch
 from hydra_zen import MISSING, store
 from lightning_utilities.core.imports import RequirementCache
-from PIL.Image import Image as PILImage
+from PIL import Image as PILImage
 from torch.utils.data import Dataset
 from torchvision.transforms.v2.functional import to_pil_image
 
```
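Both dataset files make the same import fix: `from PIL.Image import Image` binds the `Image` *class*, whereas `from PIL import Image` binds the `PIL.Image` *module*, which is what exposes factory functions such as `new`, `open`, and `fromarray`. The snippet below illustrates the distinction; it is a standalone Pillow example, not mmlearn code.

```python
# Standalone illustration of the import change; requires only Pillow.
from PIL import Image as PILImage              # the PIL.Image module
from PIL.Image import Image as PILImageClass   # the Image class inside that module

img = PILImage.new("RGB", (4, 4))      # new(), open(), fromarray() live on the module
assert isinstance(img, PILImageClass)  # the class is mainly useful for isinstance checks
# PILImageClass.new(...) would raise AttributeError: the class has no factory methods
```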

mmlearn/modules/encoders/vision.py

Lines changed: 118 additions & 34 deletions

```diff
@@ -2,13 +2,14 @@
 
 import math
 from functools import partial
-from typing import Any, Callable, Optional, Union, cast
+from typing import Any, Callable, Literal, Optional, Union, cast
 
 import timm
 import torch
 from hydra_zen import store
 from peft import PeftConfig
 from timm.models.vision_transformer import VisionTransformer as TimmVisionTransformer
+from timm.models.vision_transformer import global_pool_nlc
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutput
 
```

```diff
@@ -33,6 +34,9 @@ class TimmViT(nn.Module):
     ----------
     model_name : str
         The name of the model to use.
+    modality : str, default="RGB"
+        The modality of the input data. This allows this model to be used with different
+        image modalities e.g. RGB, Depth, etc.
     projection_dim : int, default=768
         The dimension of the projection head.
     pretrained : bool, default=True
@@ -51,6 +55,7 @@ class TimmViT(nn.Module):
     def __init__(
         self,
         model_name: str,
+        modality: str = "RGB",
         projection_dim: int = 768,
         pretrained: bool = True,
         freeze_layers: Union[int, float, list[int], bool] = False,
@@ -59,6 +64,7 @@ def __init__(
         model_kwargs: Optional[dict[str, Any]] = None,
     ) -> None:
         super().__init__()
+        self.modality = Modalities.get_modality(modality)
         if model_kwargs is None:
             model_kwargs = {}
 
@@ -124,7 +130,7 @@ def forward(self, inputs: dict[str, Any]) -> BaseModelOutput:
         BaseModelOutput
             The output of the model.
         """
-        x = inputs[Modalities.RGB.name]
+        x = inputs[self.modality.name]
         last_hidden_state, hidden_states = self.model.forward_intermediates(
             x, output_fmt="NLC"
         )
```
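The new `modality` argument lets a single `TimmViT` configuration be reused for non-RGB image modalities: the modality is resolved once in `__init__`, and `forward` then looks its tensor up in the batch dict by that modality's name. Below is a minimal sketch of this dispatch pattern for a multi-task batch; `Modality` and `ModalityKeyedEncoder` are simplified stand-ins, not mmlearn's `Modalities` registry or the real `TimmViT`.

```python
# Simplified sketch of modality-keyed dispatch; stand-in classes, not mmlearn's API.
from dataclasses import dataclass
from typing import Any, Optional

import torch
from torch import nn


@dataclass(frozen=True)
class Modality:
    name: str                    # key used to fetch this modality's tensor from a batch dict
    mask: Optional[str] = None   # key for an optional mask (mirrors `self.modality.mask`)


class ModalityKeyedEncoder(nn.Module):
    def __init__(self, modality: str = "rgb") -> None:
        super().__init__()
        self.modality = Modality(name=modality)  # resolved once, like Modalities.get_modality(...)
        self.proj = nn.Linear(16, 8)

    def forward(self, inputs: dict[str, Any]) -> torch.Tensor:
        x = inputs[self.modality.name]           # same lookup as `inputs[self.modality.name]` above
        return self.proj(x)


# One multi-task batch can feed a separate encoder per modality:
batch = {"rgb": torch.randn(2, 16), "depth": torch.randn(2, 16)}
rgb_features = ModalityKeyedEncoder("rgb")(batch)
depth_features = ModalityKeyedEncoder("depth")(batch)
```
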
```diff
@@ -175,8 +181,11 @@ class VisionTransformer(nn.Module):
 
     Parameters
     ----------
-    img_size : Optional[list[int]], optional, default=None
-        list of input image sizes.
+    modality : str, optional, default="RGB"
+        The modality of the input data. This allows this model to be used with different
+        image modalities e.g. RGB, Depth, etc.
+    img_size : List[int], optional, default=None
+        List of input image sizes.
     patch_size : int, optional, default=16
         Size of each patch.
     in_chans : int, optional, default=3
@@ -209,6 +218,7 @@ class VisionTransformer(nn.Module):
 
     def __init__(
         self,
+        modality: str = "RGB",
         img_size: Optional[list[int]] = None,
         patch_size: int = 16,
         in_chans: int = 3,
@@ -218,14 +228,17 @@ def __init__(
         mlp_ratio: float = 4.0,
         qkv_bias: bool = True,
         qk_scale: Optional[float] = None,
+        global_pool: Literal["", "avg", "avgmax", "max", "token"] = "",
         drop_rate: float = 0.0,
         attn_drop_rate: float = 0.0,
         drop_path_rate: float = 0.0,
         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
         init_std: float = 0.02,
-        **kwargs: Any,
     ) -> None:
         super().__init__()
+        assert global_pool in ("", "avg", "avgmax", "max", "token")
+
+        self.modality = Modalities.get_modality(modality)
         self.num_features = self.embed_dim = embed_dim
         self.num_heads = num_heads
         img_size = [224, 224] if img_size is None else img_size
@@ -272,6 +285,8 @@ def __init__(
         )
         self.norm = norm_layer(embed_dim)
 
+        self.global_pool = global_pool
+
         # Weight Initialization
         self.init_std = init_std
         self.apply(self._init_weights)
@@ -301,15 +316,14 @@ def _init_weights(self, m: nn.Module) -> None:
             nn.init.constant_(m.bias, 0)
 
     def forward(
-        self,
-        x: torch.Tensor,
-        masks: Optional[Union[torch.Tensor, list[torch.Tensor]]] = None,
-        return_hidden_states: bool = False,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, list[torch.Tensor]]]:
+        self, inputs: dict[str, Any], return_hidden_states: bool = False
+    ) -> tuple[torch.Tensor, Optional[list[torch.Tensor]]]:
         """Forward pass through the Vision Transformer."""
+        masks = inputs.get(self.modality.mask)
         if masks is not None and not isinstance(masks, list):
             masks = [masks]
 
+        x = inputs[self.modality.name]
         # -- Patchify x
         x = self.patch_embed(x)
 
```
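With this change, `VisionTransformer.forward` consumes the same kind of batch dict as `TimmViT` (optionally carrying a mask under the modality's mask key) and always returns a `(features, hidden_states)` tuple, with `hidden_states` set to `None` unless `return_hidden_states=True`. A hedged usage sketch follows; the `"rgb"` batch key is an assumption about what `Modalities.get_modality("RGB").name` resolves to, and the constructor defaults are taken from this diff.

```python
# Hypothetical call pattern for the new forward signature; the "rgb" key is a
# placeholder for the name the Modalities registry actually assigns.
import torch
from mmlearn.modules.encoders.vision import VisionTransformer

model = VisionTransformer(modality="RGB")  # defaults per this diff: img_size=[224, 224], patch_size=16

batch = {"rgb": torch.randn(2, 3, 224, 224)}
features, hidden_states = model(batch)                             # hidden_states is None
features, hidden_states = model(batch, return_hidden_states=True)  # per-block hidden states
```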

```diff
@@ -336,10 +350,13 @@ def forward(
         if self.norm is not None:
             x = self.norm(x)
 
+        # -- Apply global pooling
+        x = global_pool_nlc(x, pool_type=self.global_pool)
+
         # -- Return both final output and hidden states if requested
         if return_hidden_states:
             return x, hidden_states
-        return x
+        return (x, None)
 
     def interpolate_pos_encoding(
         self, x: torch.Tensor, pos_embed: torch.Tensor
```
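`global_pool_nlc`, imported from `timm.models.vision_transformer` in the first hunk, pools the token dimension of an `(N, L, C)` sequence according to `pool_type`; with the default `""` the sequence passes through unpooled, so existing callers see no behaviour change. The sketch below illustrates the pooling semantics in plain torch; it is an approximation of the behaviour, not timm's implementation (which also handles prefix/class tokens).

```python
import torch


def pool_nlc_sketch(x: torch.Tensor, pool_type: str = "") -> torch.Tensor:
    """Approximate the pool_type options accepted by the new global_pool argument."""
    # x has shape (batch, num_tokens, dim)
    if pool_type == "":
        return x                                      # no pooling: full token sequence
    if pool_type == "token":
        return x[:, 0]                                # first (class) token
    if pool_type == "avg":
        return x.mean(dim=1)
    if pool_type == "max":
        return x.amax(dim=1)
    if pool_type == "avgmax":
        return 0.5 * (x.mean(dim=1) + x.amax(dim=1))
    raise ValueError(f"unknown pool_type: {pool_type}")


tokens = torch.randn(2, 197, 768)
print(pool_nlc_sketch(tokens, "avg").shape)  # torch.Size([2, 768])
```
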
```diff
@@ -586,8 +603,7 @@ def _trunc_normal(
 def _no_grad_trunc_normal_(
     tensor: torch.Tensor, mean: float, std: float, a: float, b: float
 ) -> torch.Tensor:
-    """
-    Apply truncated normal initialization to a tensor.
+    """Apply truncated normal initialization to a tensor.
 
     Parameters
     ----------
```

```diff
@@ -633,15 +649,23 @@ def norm_cdf(x: float) -> float:
         provider="mmlearn",
     ),
 )
-def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
-    """
-    Create a VisionTransformerPredictor model.
+def vit_predictor(
+    kwargs: Optional[dict[str, Any]] = None,
+) -> VisionTransformerPredictor:
+    """Create a VisionTransformerPredictor model.
+
+    Parameters
+    ----------
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the predictor.
 
     Returns
     -------
     VisionTransformerPredictor
         An instance of VisionTransformerPredictor.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformerPredictor(
         mlp_ratio=4, qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs
     )
@@ -654,15 +678,25 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
         provider="mmlearn",
     ),
 )
-def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with tiny configuration.
+def vit_tiny(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with tiny configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=192,
```
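All of the `vit_*` factories (and `vit_predictor`) now accept extra model arguments as an explicit optional `kwargs` dict rather than `**kwargs`; a plausible motivation is making the signature expressible as a structured config for the hydra-zen store, though the diff does not state this. Callers move their extra arguments into the dict, as in this sketch (the `modality` value and `img_size` are illustrative; the same pattern applies to `vit_small` through `vit_giant`).

```python
from mmlearn.modules.encoders.vision import vit_tiny

# Old signature (before this change): extras were passed through **kwargs
# model = vit_tiny(patch_size=16, modality="Depth", img_size=[224, 224])

# New signature: the same extras are packed into a single optional dict
model = vit_tiny(patch_size=16, kwargs={"modality": "Depth", "img_size": [224, 224]})
```
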
```diff
@@ -682,15 +716,25 @@ def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with small configuration.
+def vit_small(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with small configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=384,
@@ -710,15 +754,25 @@ def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with base configuration.
+def vit_base(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with base configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=768,
@@ -738,15 +792,25 @@ def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with large configuration.
+def vit_large(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with large configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1024,
@@ -766,15 +830,25 @@ def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with huge configuration.
+def vit_huge(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with huge configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
    -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1280,
@@ -794,15 +868,25 @@ def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
         provider="mmlearn",
     ),
 )
-def vit_giant(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
-    """
-    Create a VisionTransformer model with giant configuration.
+def vit_giant(
+    patch_size: int = 16, kwargs: Optional[dict[str, Any]] = None
+) -> VisionTransformer:
+    """Create a VisionTransformer model with giant configuration.
+
+    Parameters
+    ----------
+    patch_size : int, default=16
+        Size of each patch.
+    kwargs : dict[str, Any], optional, default=None
+        Keyword arguments for the model variant.
 
     Returns
     -------
     VisionTransformer
         An instance of VisionTransformer.
     """
+    if kwargs is None:
+        kwargs = {}
     return VisionTransformer(
         patch_size=patch_size,
         embed_dim=1408,
```
