
Commit cc598aa

Add next set of GA models
Summary: Add a few more tasks:
1. Image-Text Understanding (OpenCLIP)
2. Semantic Text Search (Sentence Transformers)
3. Document Q&A (DistilBERT QA)
4. Practical Image Enhancement (Real-ESRGAN)
5. Audio Classification (AST)
6. Text Sentiment Analysis (RoBERTa)
7. Depth Estimation (Depth Anything V2)
1 parent cf78305 commit cc598aa
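
For context, a minimal sketch of how one of the new entries could be lowered end-to-end, assuming the standard torch.export / executorch.exir flow applies and that the model traces cleanly; neither is shown in this commit, so treat this as illustrative only:

# Hedged sketch: export the new "clip" entry to an ExecuTorch program.
# The registry class and its example inputs come from this commit; the
# torch.export / executorch.exir calls are standard APIs assumed to apply here.
import torch
from executorch.exir import to_edge
from executorch.examples.models.clip import CLIPModel

entry = CLIPModel()
eager_model = entry.get_eager_model()
example_inputs = entry.get_example_inputs()

exported_program = torch.export.export(eager_model, example_inputs)
executorch_program = to_edge(exported_program).to_executorch()
with open("clip.pte", "wb") as f:
    f.write(executorch_program.buffer)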

File tree

17 files changed: +538, -2 lines


.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ jobs:
       contents: read
     strategy:
       matrix:
-        model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe, efficientnet_b4, detr_resnet50, segformer_ade, albert, wav2vec2]
+        model: [linear, add, add_mul, ic3, ic4, mv2, mv3, resnet18, resnet50, vit, w2l, mobilebert, emformer_join, emformer_transcribe, efficientnet_b4, detr_resnet50, segformer_ade, albert, wav2vec2, clip, sentence_transformers, distilbert_qa, real_esrgan, audio_spectrogram_transformer, roberta_sentiment, depth_anything_v2]
         backend: [portable, xnnpack-quantization-delegation]
         runner: [linux.arm64.2xlarge]
       include:

examples/models/__init__.py

Lines changed: 20 additions & 1 deletion
@@ -41,10 +41,16 @@ class Model(str, Enum):
     DetrResNet50 = "detr_resnet50"
     SegformerADE = "segformer_ade"
     Albert = "albert"
-    BiLSTM = "bilstm"
     Swin2SR2x = "swin2sr_2x"
     TrOCRHandwritten = "trocr_handwritten"
     Wav2Vec2 = "wav2vec2"
+    CLIP = "clip"
+    SentenceTransformers = "sentence_transformers"
+    DistilBertQA = "distilbert_qa"
+    RealESRGAN = "real_esrgan"
+    AudioSpectrogramTransformer = "audio_spectrogram_transformer"
+    RobertaSentiment = "roberta_sentiment"
+    DepthAnythingV2 = "depth_anything_v2"

     def __str__(self) -> str:
         return self.value

@@ -97,6 +103,19 @@ def __str__(self) -> str:
     str(Model.Swin2SR2x): ("swin2sr_2x", "Swin2SR2xModel"),
     str(Model.TrOCRHandwritten): ("trocr_handwritten", "TrOCRHandwrittenModel"),
     str(Model.Wav2Vec2): ("wav2vec2", "Wav2Vec2Model"),
+    str(Model.CLIP): ("clip", "CLIPModel"),
+    str(Model.SentenceTransformers): (
+        "sentence_transformers",
+        "SentenceTransformersModel",
+    ),
+    str(Model.DistilBertQA): ("distilbert_qa", "DistilBertQAModel"),
+    str(Model.RealESRGAN): ("real_esrgan", "RealESRGANModel"),
+    str(Model.AudioSpectrogramTransformer): (
+        "audio_spectrogram_transformer",
+        "AudioSpectrogramTransformerModel",
+    ),
+    str(Model.RobertaSentiment): ("roberta_sentiment", "RobertaSentimentModel"),
+    str(Model.DepthAnythingV2): ("depth_anything_v2", "DepthAnythingV2Model"),
 }

 __all__ = [
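
For reference, a hypothetical sketch of how the (module, class) pairs registered above could be resolved at runtime; the loader function and registry argument are illustrative assumptions, only the tuple shape comes from this diff:

# Hypothetical loader for a registry entry such as ("clip", "CLIPModel").
import importlib

def load_example_model(registry: dict, model_name: str):
    # e.g. registry[str(Model.CLIP)] == ("clip", "CLIPModel")
    module_name, class_name = registry[model_name]
    module = importlib.import_module(f"executorch.examples.models.{module_name}")
    return getattr(module, class_name)()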

examples/models/audio_spectrogram_transformer/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import AudioSpectrogramTransformerModel

__all__ = ["AudioSpectrogramTransformerModel"]

examples/models/audio_spectrogram_transformer/model.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from transformers import ASTFeatureExtractor, ASTForAudioClassification

from ..model_base import EagerModelBase


class AudioSpectrogramTransformerWrapper(torch.nn.Module):
    """Wrapper for HuggingFace Audio Spectrogram Transformer model to make it torch.export compatible"""

    def __init__(self, model_name="MIT/ast-finetuned-audioset-10-10-0.4593"):
        super().__init__()
        self.model = ASTForAudioClassification.from_pretrained(model_name)
        self.feature_extractor = ASTFeatureExtractor.from_pretrained(model_name)
        self.model.eval()

    def forward(self, input_values):
        # Audio classification with AST
        with torch.no_grad():
            outputs = self.model(input_values)

        # Return classification logits
        return outputs.logits


class AudioSpectrogramTransformerModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading Audio Spectrogram Transformer model from HuggingFace")
        model = AudioSpectrogramTransformerWrapper(
            "MIT/ast-finetuned-audioset-10-10-0.4593"
        )
        model.eval()
        logging.info("Loaded Audio Spectrogram Transformer model")
        return model

    def get_example_inputs(self):
        # Example inputs for AST
        # Audio spectrogram: batch_size=1, time_steps=1024, freq_bins=128
        input_values = torch.randn(1, 1024, 128)

        return (input_values,)
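
A hedged usage sketch for the wrapper above; the silent 16 kHz waveform and the id2label lookup are illustrative assumptions about typical AST usage, not part of this diff:

# Classify one second of silence with the AST wrapper (illustrative only).
import torch

wrapper = AudioSpectrogramTransformerWrapper()
waveform = torch.zeros(16000).numpy()  # 1 s of audio at 16 kHz
features = wrapper.feature_extractor(
    waveform, sampling_rate=16000, return_tensors="pt"
)
logits = wrapper(features["input_values"])  # shape (1, num_audioset_classes)
predicted_id = int(logits.argmax(dim=-1))
print(wrapper.model.config.id2label[predicted_id])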

examples/models/clip/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import CLIPModel

__all__ = ["CLIPModel"]

examples/models/clip/model.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from transformers import CLIPModel as HFCLIPModel, CLIPProcessor

from ..model_base import EagerModelBase


class OpenCLIPWrapper(torch.nn.Module):
    """Wrapper for OpenCLIP model to make it torch.export compatible"""

    def __init__(self, model_name="laion/CLIP-ViT-B-32-laion2B-s34B-b79K"):
        super().__init__()
        self.model = HFCLIPModel.from_pretrained(model_name)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.model.eval()

    def forward(self, pixel_values, input_ids, attention_mask):
        # Extract image and text features
        with torch.no_grad():
            outputs = self.model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                return_loss=False,
            )

        # Return image and text embeddings
        return outputs.image_embeds, outputs.text_embeds


class CLIPModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading OpenCLIP model from HuggingFace")
        model = OpenCLIPWrapper("laion/CLIP-ViT-B-32-laion2B-s34B-b79K")
        model.eval()
        logging.info("Loaded OpenCLIP model")
        return model

    def get_example_inputs(self):
        # Example inputs for CLIP
        # Image: batch_size=1, channels=3, height=224, width=224
        pixel_values = torch.randn(1, 3, 224, 224)

        # Text: batch_size=1, max_length=77 (CLIP's typical context length)
        input_ids = torch.randint(0, 49408, (1, 77))  # CLIP vocab size is ~49408
        attention_mask = torch.ones(1, 77)

        return (pixel_values, input_ids, attention_mask)
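
A hedged usage sketch for the wrapper above, scoring captions against an image from the returned embeddings; the image path and candidate captions are placeholders, not part of this diff:

# Score two candidate captions against a local image (illustrative only).
from PIL import Image

wrapper = OpenCLIPWrapper()
image = Image.open("example.jpg")
captions = ["a photo of a cat", "a photo of a dog"]
inputs = wrapper.processor(
    text=captions, images=image, return_tensors="pt", padding=True
)

image_embeds, text_embeds = wrapper(
    inputs["pixel_values"], inputs["input_ids"], inputs["attention_mask"]
)
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
print(image_embeds @ text_embeds.T)  # cosine similarity per caption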

examples/models/depth_anything_v2/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import DepthAnythingV2Model

__all__ = ["DepthAnythingV2Model"]

examples/models/depth_anything_v2/model.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.examples.models.model_base import EagerModelBase


class DepthAnythingV2Model(EagerModelBase):
    def __init__(self, model_name="depth-anything/Depth-Anything-V2-Small-hf"):
        self.model_name = model_name

    def _load_model(self):
        """Load the Depth Anything V2 model from HuggingFace"""
        try:
            from transformers import AutoImageProcessor, AutoModelForDepthEstimation
        except ImportError:
            raise ImportError(
                "transformers is required for DepthAnythingV2Model. "
                "Install with: pip install transformers"
            )

        # Load model and processor
        self.processor = AutoImageProcessor.from_pretrained(self.model_name)
        model = AutoModelForDepthEstimation.from_pretrained(self.model_name)

        return model

    def get_eager_model(self) -> torch.nn.Module:
        return DepthAnythingV2Wrapper(self.model_name)

    def get_example_inputs(self):
        """Get example inputs for the model"""
        # Standard input size for Depth Anything V2 models
        # The model expects images of size (3, 518, 518) based on the processor configuration
        return (torch.randn(1, 3, 518, 518),)

    def get_dynamic_shapes(self):
        """Dynamic shapes for variable input sizes"""
        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}


class DepthAnythingV2Wrapper(torch.nn.Module):
    """
    Wrapper for Depth Anything V2 model that handles preprocessing and provides a clean interface.
    """

    def __init__(self, model_name="depth-anything/Depth-Anything-V2-Small-hf"):
        super().__init__()
        try:
            from transformers import AutoImageProcessor, AutoModelForDepthEstimation
        except ImportError:
            raise ImportError(
                "transformers is required for DepthAnythingV2Model. "
                "Install with: pip install transformers"
            )

        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModelForDepthEstimation.from_pretrained(model_name)

        # Set to evaluation mode
        self.model.eval()

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for depth estimation.

        Args:
            pixel_values: Input image tensor of shape (batch_size, 3, height, width)
                Values should be normalized to [0, 1] range

        Returns:
            predicted_depth: Depth map tensor of shape (batch_size, height, width)
        """
        # The model expects inputs to be preprocessed
        # pixel_values should already be properly normalized and sized

        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            predicted_depth = outputs.predicted_depth

        # The model outputs depth in a specific format - we may need to interpolate
        # to match the input image size
        if predicted_depth.shape[-2:] != pixel_values.shape[-2:]:
            predicted_depth = torch.nn.functional.interpolate(
                predicted_depth.unsqueeze(1),
                size=pixel_values.shape[-2:],
                mode="bilinear",
                align_corners=False,
            ).squeeze(1)

        return predicted_depth

    def preprocess_image(self, image):
        """
        Preprocess a PIL image for the model.
        This method is not used in the forward pass but can be helpful for testing.
        """
        inputs = self.processor(images=image, return_tensors="pt")
        return inputs["pixel_values"]
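
A hedged usage sketch for the wrapper above, relying on its own preprocess_image helper; the image path is a placeholder, not part of this diff:

# Estimate a depth map for a local image (illustrative only).
from PIL import Image

wrapper = DepthAnythingV2Wrapper()
image = Image.open("room.jpg")
pixel_values = wrapper.preprocess_image(image)  # (1, 3, H, W), already normalized
depth = wrapper(pixel_values)  # (1, H, W), resized to match the input resolution
print(depth.shape, float(depth.min()), float(depth.max()))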

examples/models/distilbert_qa/__init__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import DistilBertQAModel

__all__ = ["DistilBertQAModel"]

examples/models/distilbert_qa/model.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizer

from ..model_base import EagerModelBase


class DistilBertQAWrapper(torch.nn.Module):
    """Wrapper for HuggingFace DistilBERT QA model to make it torch.export compatible"""

    def __init__(self, model_name="distilbert-base-cased-distilled-squad"):
        super().__init__()
        self.model = DistilBertForQuestionAnswering.from_pretrained(model_name)
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model.eval()

    def forward(self, input_ids, attention_mask):
        # Get question answering outputs
        with torch.no_grad():
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # Return start and end logits for answer span
        return outputs.start_logits, outputs.end_logits


class DistilBertQAModel(EagerModelBase):
    def __init__(self):
        pass

    def get_eager_model(self) -> torch.nn.Module:
        logging.info("Loading DistilBERT QA model from HuggingFace")
        model = DistilBertQAWrapper("distilbert-base-cased-distilled-squad")
        model.eval()
        logging.info("Loaded DistilBERT QA model")
        return model

    def get_example_inputs(self):
        # Example inputs for DistilBERT QA
        # Combined question and context: batch_size=1, max_length=512
        input_ids = torch.randint(0, 28996, (1, 512))  # DistilBERT vocab size
        attention_mask = torch.ones(1, 512)

        return (input_ids, attention_mask)
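
A hedged usage sketch for the wrapper above, decoding the start/end logits into an answer span; the question and context strings are placeholders, not part of this diff:

# Extract an answer span with the DistilBERT QA wrapper (illustrative only).
import torch

wrapper = DistilBertQAWrapper()
question = "Where is the LICENSE file located?"
context = "The LICENSE file lives in the root directory of this source tree."
inputs = wrapper.tokenizer(question, context, return_tensors="pt")

start_logits, end_logits = wrapper(inputs["input_ids"], inputs["attention_mask"])
start = int(start_logits.argmax(dim=-1))
end = int(end_logits.argmax(dim=-1))
answer_ids = inputs["input_ids"][0, start : end + 1]
print(wrapper.tokenizer.decode(answer_ids))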
