Commit ecbdac2

implement SAM backbone with huggingface transformers package (#301)
* non-context vit test passing
* all vit tests passing
* all tests passing
* rebuild docs
1 parent 063e9ef commit ecbdac2

38 files changed: 353 additions, 284 deletions

docs/api/lightning_pose.models.backbones.vit_img_encoder.resample_abs_pos_embed_nhwc.rst

Lines changed: 0 additions & 6 deletions
This file was deleted.

docs/api/lightning_pose.models.backbones.vit_img_encoder.ImageEncoderViT_FT.rst renamed to docs/api/lightning_pose.models.backbones.vit_sam.SamVisionEncoderHF.rst

Lines changed: 4 additions & 4 deletions
@@ -1,16 +1,16 @@
-ImageEncoderViT_FT
+SamVisionEncoderHF
 ==================

-.. currentmodule:: lightning_pose.models.backbones.vit_img_encoder
+.. currentmodule:: lightning_pose.models.backbones.vit_sam

-.. autoclass:: ImageEncoderViT_FT
+.. autoclass:: SamVisionEncoderHF
    :show-inheritance:

    .. rubric:: Methods Summary

    .. autosummary::

-      ~ImageEncoderViT_FT.forward
+      ~SamVisionEncoderHF.forward

    .. rubric:: Methods Documentation

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+load_sam_vision_encoder_hf
+==========================
+
+.. currentmodule:: lightning_pose.models.backbones.vit_sam
+
+.. autofunction:: load_sam_vision_encoder_hf

docs/modules/lightning_pose.models.backbones.rst

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ lightning\_pose.models.backbones
 .. automodapi:: lightning_pose.models.backbones.torchvision
     :no-inheritance-diagram:

-.. automodapi:: lightning_pose.models.backbones.vit_img_encoder
+.. automodapi:: lightning_pose.models.backbones.vit_sam
     :no-inheritance-diagram:

 .. automodapi:: lightning_pose.models.backbones.vits

docs/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ streamlit
 tensorboard
 torchtyping
 torchvision
+transformers
 typeguard
 typing
 nvidia-dali-cuda110
-segment_anything @ git+https://github.com/facebookresearch/segment-anything.git

docs/source/user_guide/config_file.rst

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ The following parameters relate to model architecture and unsupervised losses.
 * efficientnet_b0: EfficientNet-B0 pretrained on ImageNet
 * efficientnet_b1: EfficientNet-B1 pretrained on ImageNet
 * efficientnet_b2: EfficientNet-B2 pretrained on ImageNet
-* vit_b_sam: Segment Anything Model (Vision Transformer Base)
+* vitb_sam: Segment Anything Model (Vision Transformer Base)

 Note: the file size for a single ResNet-50 network is approximately 275 MB.
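Because the hunk above renames the documented backbone value from `vit_b_sam` to `vitb_sam`, older config files may still carry the previous spelling. The following is a minimal sketch of how a caller could translate the old name; `RENAMED_BACKBONES` and `normalize_backbone` are hypothetical names that do not appear in this commit, and the mapping contains only the one rename visible in the diff.

```python
# Hypothetical helper, not part of lightning-pose: maps a backbone name from an
# older config file to the spelling documented after this commit.
RENAMED_BACKBONES = {
    "vit_b_sam": "vitb_sam",  # the only rename shown in this diff
}


def normalize_backbone(name: str) -> str:
    """Return the current backbone name, leaving unknown names unchanged."""
    return RENAMED_BACKBONES.get(name, name)


if __name__ == "__main__":
    for old in ("vit_b_sam", "vitb_sam", "efficientnet_b0"):
        print(old, "->", normalize_backbone(old))
```

Names that are not in the mapping pass through untouched, so the helper is safe to apply to every config value.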

lightning_pose/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -1,4 +1,6 @@
+import importlib.metadata
 from pathlib import Path
+from typing import Any

 from omegaconf import OmegaConf

@@ -8,9 +10,6 @@

 # Hacky way to get version from pypackage.toml.
 # Adapted from: https://github.com/python-poetry/poetry/issues/273#issuecomment-1877789967
-from typing import Any
-import importlib.metadata
-from pathlib import Path

 __package_version = "unknown"

@@ -33,9 +32,10 @@ def __get_package_version() -> str:
         # Fall back on getting it from a local pyproject.toml.
         # This works in a development environment where the
         # package has not been installed from a distribution.
-        import toml
         import warnings

+        import toml
+
         warnings.warn(
             "lightning-pose not pip-installed, getting version from pyproject.toml."
         )
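The comments in the last hunk describe a two-stage version lookup: ask the installed distribution first, then fall back when the package was not pip-installed. The sketch below illustrates only the first stage plus a plain string fallback, using the standard-library `importlib.metadata`; it does not reproduce the `pyproject.toml` parsing, and `get_package_version` is a hypothetical name rather than the function in `lightning_pose/__init__.py`.

```python
import importlib.metadata


def get_package_version(dist_name: str, fallback: str = "unknown") -> str:
    """Return the installed version of dist_name, or fallback if not installed."""
    try:
        # Works when the package was installed from a distribution (e.g. via pip).
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        # Not installed: a real implementation might read pyproject.toml here,
        # as the commented code in the diff does. This sketch just falls back.
        return fallback


if __name__ == "__main__":
    print(get_package_version("surely-not-installed-dist-0xdeadbeef"))
```

Returning a sentinel instead of raising keeps `import` of the package from failing in a development checkout.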

lightning_pose/api/model.py

Lines changed: 1 addition & 4 deletions
@@ -7,10 +7,7 @@
 from omegaconf import DictConfig, OmegaConf

 from lightning_pose.api.model_config import ModelConfig
-from lightning_pose.data.datatypes import (
-    MultiviewPredictionResult,
-    PredictionResult,
-)
+from lightning_pose.data.datatypes import MultiviewPredictionResult, PredictionResult
 from lightning_pose.models import ALLOWED_MODELS
 from lightning_pose.utils import io as io_utils
 from lightning_pose.utils.predictions import (

lightning_pose/apps/labeled_frame_diagnostics.py

Lines changed: 5 additions & 1 deletion
@@ -17,7 +17,11 @@
 import seaborn as sns
 import streamlit as st

-from lightning_pose.apps.plots import get_y_label, make_plotly_catplot, make_plotly_scatterplot
+from lightning_pose.apps.plots import (
+    get_y_label,
+    make_plotly_catplot,
+    make_plotly_scatterplot,
+)
 from lightning_pose.apps.utils import (
     build_precomputed_metrics_df,
     get_df_box,

lightning_pose/apps/video_diagnostics.py

Lines changed: 5 additions & 1 deletion
@@ -16,7 +16,11 @@
 import pandas as pd
 import streamlit as st

-from lightning_pose.apps.plots import get_y_label, make_plotly_catplot, plot_precomputed_traces
+from lightning_pose.apps.plots import (
+    get_y_label,
+    make_plotly_catplot,
+    plot_precomputed_traces,
+)
 from lightning_pose.apps.utils import (
     build_precomputed_metrics_df,
     concat_dfs,
