diff --git a/examples/models/__init__.py b/examples/models/__init__.py index c78106668eb..822d55fc09d 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -33,6 +33,7 @@ "resnet18": ("resnet", "ResNet18Model"), "resnet50": ("resnet", "ResNet50Model"), "llava": ("llava", "LlavaModel"), + "efficient_sam": ("efficient_sam", "EfficientSAM"), } __all__ = [ diff --git a/examples/models/efficient_sam/README.md b/examples/models/efficient_sam/README.md new file mode 100644 index 00000000000..bce1f7c5319 --- /dev/null +++ b/examples/models/efficient_sam/README.md @@ -0,0 +1,49 @@ +# EfficientSAM Model Export + +This example demonstrates how to export the [EfficientSAM](https://github.com/yformer/EfficientSAM) model to Core ML and XNNPACK using ExecuTorch. + +# Instructions + +## 1. Setup + +Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup#) to set up ExecuTorch. + +## 2. Exports + +### Exporting to Core ML + +Make sure to install the [required dependencies](https://pytorch.org/executorch/main/build-run-coreml.html#setting-up-your-developer-environment) for Core ML export. + +To export the model to Core ML, run the following command: + +```bash +cd executorch +python -m examples.apple.coreml.scripts.export -m efficient_sam +``` + +### Exporting to XNNPACK + +To export the model to XNNPACK, run the following command: + +```bash +cd executorch +python -m examples.xnnpack.aot_compiler -m efficient_sam +``` + +# Performance + +Tests were conducted on an Apple M1 Pro chip using the instructions for building and running Executorch with [Core ML](https://pytorch.org/executorch/main/build-run-coreml.html#runtime) and [XNNPACK](https://pytorch.org/executorch/main/tutorial-xnnpack-delegate-lowering.html#running-the-xnnpack-model-with-cmake) backends. + +| Backend Configuration | Average Inference Time (seconds) | +| ---------------------- | -------------------------------- | +| Core ML (CPU, GPU, NE) | 34.8 | +| Core ML (CPU, GPU) | 34.7 | +| Core ML (CPU, NE) | 26.4 | +| Core ML (CPU) | 22.8 | +| XNNPACK | 4.1 | + +All models were tested with `float32` precision. + +# Licensing + +The code in the `efficient_sam_core` directory is licensed under the [Apache License 2.0](./efficient_sam_core/LICENSE.txt). diff --git a/examples/models/efficient_sam/__init__.py b/examples/models/efficient_sam/__init__.py new file mode 100644 index 00000000000..8a767e34ed4 --- /dev/null +++ b/examples/models/efficient_sam/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .model import EfficientSAM + +__all__ = [ + EfficientSAM, +] diff --git a/examples/models/efficient_sam/efficient_sam_core/LICENSE.txt b/examples/models/efficient_sam/efficient_sam_core/LICENSE.txt new file mode 100644 index 00000000000..261eeb9e9f8 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/LICENSE.txt @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/examples/models/efficient_sam/efficient_sam_core/build_efficient_sam.py b/examples/models/efficient_sam/efficient_sam_core/build_efficient_sam.py new file mode 100644 index 00000000000..8d5a2758e53 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/build_efficient_sam.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the same directory. + +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/build_efficient_sam.py + +from .efficient_sam import build_efficient_sam + + +def build_efficient_sam_vitt(): + return build_efficient_sam( + encoder_patch_embed_dim=192, + encoder_num_heads=3, + checkpoint="https://huggingface.co/merve/EfficientSAM/resolve/main/efficient_sam_vitt.pt", + ).eval() + + +def build_efficient_sam_vits(): + return build_efficient_sam( + encoder_patch_embed_dim=384, + encoder_num_heads=6, + checkpoint="weights/efficient_sam_vits.pt", + ).eval() diff --git a/examples/models/efficient_sam/efficient_sam_core/efficient_sam.py b/examples/models/efficient_sam/efficient_sam_core/efficient_sam.py new file mode 100644 index 00000000000..d06db2de434 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/efficient_sam.py @@ -0,0 +1,318 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the same directory. + +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/efficient_sam.py + +from typing import List, Tuple + +import torch +import torch.nn.functional as F + +from torch import nn + +from .efficient_sam_decoder import MaskDecoder, PromptEncoder +from .efficient_sam_encoder import ImageEncoderViT +from .two_way_transformer import TwoWayTransformer + + +class EfficientSam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: ImageEncoderViT, + prompt_encoder: PromptEncoder, + decoder_max_num_input_points: int, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = None, + pixel_std: List[float] = None, + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. 
+ """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.decoder_max_num_input_points = decoder_max_num_input_points + self.mask_decoder = mask_decoder + if pixel_mean is None: + pixel_mean = [0.485, 0.456, 0.406] + if pixel_std is None: + pixel_std = [0.229, 0.224, 0.225] + self.register_buffer( + "pixel_mean", torch.Tensor(pixel_mean).view(1, 3, 1, 1), False + ) + self.register_buffer( + "pixel_std", torch.Tensor(pixel_std).view(1, 3, 1, 1), False + ) + + @torch.jit.export + def predict_masks( + self, + image_embeddings: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + multimask_output: bool, + input_h: int, + input_w: int, + output_h: int = -1, + output_w: int = -1, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks given image embeddings and prompts. This only runs the decoder. + + Arguments: + image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] + batched_points: A tensor of shape [B, max_num_queries, num_pts, 2] + batched_point_labels: A tensor of shape [B, max_num_queries, num_pts] + Returns: + A tuple of two tensors: + low_res_mask: A tensor of shape [B, max_num_queries, 256, 256] of predicted masks + iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores + """ + + batch_size, max_num_queries, num_pts, _ = batched_points.shape + num_pts = batched_points.shape[2] + rescaled_batched_points = self.get_rescaled_pts( + batched_points, input_h, input_w + ) + + if num_pts > self.decoder_max_num_input_points: + rescaled_batched_points = rescaled_batched_points[ + :, :, : self.decoder_max_num_input_points, : + ] + batched_point_labels = batched_point_labels[ + :, :, : self.decoder_max_num_input_points + ] + elif num_pts < self.decoder_max_num_input_points: + rescaled_batched_points = F.pad( + rescaled_batched_points, + (0, 0, 0, self.decoder_max_num_input_points - num_pts), + value=-1.0, + ) + batched_point_labels = F.pad( + batched_point_labels, + (0, self.decoder_max_num_input_points - num_pts), + value=-1.0, + ) + + sparse_embeddings = self.prompt_encoder( + rescaled_batched_points.reshape( + batch_size * max_num_queries, self.decoder_max_num_input_points, 2 + ), + batched_point_labels.reshape( + batch_size * max_num_queries, self.decoder_max_num_input_points + ), + ) + sparse_embeddings = sparse_embeddings.view( + batch_size, + max_num_queries, + sparse_embeddings.shape[1], + sparse_embeddings.shape[2], + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings, + self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + multimask_output=multimask_output, + ) + _, num_predictions, low_res_size, _ = low_res_masks.shape + + if output_w > 0 and output_h > 0: + output_masks = F.interpolate( + low_res_masks, (output_h, output_w), mode="bicubic" + ) + output_masks = torch.reshape( + output_masks, + (batch_size, max_num_queries, num_predictions, output_h, output_w), + ) + else: + output_masks = torch.reshape( + low_res_masks, + ( + batch_size, + max_num_queries, + num_predictions, + low_res_size, + low_res_size, + ), + ) + iou_predictions = torch.reshape( + iou_predictions, (batch_size, max_num_queries, num_predictions) + ) + return output_masks, iou_predictions + + def get_rescaled_pts( + self, batched_points: torch.Tensor, input_h: int, input_w: int + ): + return torch.stack( + [ + torch.where( + batched_points[..., 0] >= 0, + batched_points[..., 0] * self.image_encoder.img_size / input_w, + 
-1.0, + ), + torch.where( + batched_points[..., 1] >= 0, + batched_points[..., 1] * self.image_encoder.img_size / input_h, + -1.0, + ), + ], + dim=-1, + ) + + @torch.jit.export + def get_image_embeddings(self, batched_images) -> torch.Tensor: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_images: A tensor of shape [B, 3, H, W] + Returns: + List of image embeddings each of of shape [B, C(i), H(i), W(i)]. + The last embedding corresponds to the final layer. + """ + batched_images = self.preprocess(batched_images) + return self.image_encoder(batched_images) + + def forward( + self, + batched_images: torch.Tensor, + batched_points: torch.Tensor, + batched_point_labels: torch.Tensor, + scale_to_original_image_size: bool = True, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_images: A tensor of shape [B, 3, H, W] + batched_points: A tensor of shape [B, num_queries, max_num_pts, 2] + batched_point_labels: A tensor of shape [B, num_queries, max_num_pts] + + Returns: + A list tuples of two tensors where the ith element is by considering the first i+1 points. + low_res_mask: A tensor of shape [B, 256, 256] of predicted masks + iou_predictions: A tensor of shape [B, max_num_queries] of estimated IOU scores + """ + batch_size, _, input_h, input_w = batched_images.shape + image_embeddings = self.get_image_embeddings(batched_images) + return self.predict_masks( + image_embeddings, + batched_points, + batched_point_labels, + multimask_output=True, + input_h=input_h, + input_w=input_w, + output_h=input_h if scale_to_original_image_size else -1, + output_w=input_w if scale_to_original_image_size else -1, + ) + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + if ( + x.shape[2] != self.image_encoder.img_size + or x.shape[3] != self.image_encoder.img_size + ): + x = F.interpolate( + x, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + ) + return (x - self.pixel_mean) / self.pixel_std + + +def build_efficient_sam(encoder_patch_embed_dim, encoder_num_heads, checkpoint=None): + img_size = 1024 + encoder_patch_size = 16 + encoder_depth = 12 + encoder_mlp_ratio = 4.0 + encoder_neck_dims = [256, 256] + decoder_max_num_input_points = 6 + decoder_transformer_depth = 2 + decoder_transformer_mlp_dim = 2048 + decoder_num_heads = 8 + decoder_upscaling_layer_dims = [64, 32] + num_multimask_outputs = 3 + iou_head_depth = 3 + iou_head_hidden_dim = 256 + activation = "gelu" + normalization_type = "layer_norm" + normalize_before_activation = False + + assert activation == "relu" or activation == "gelu" + if activation == "relu": + activation_fn = nn.ReLU + else: + activation_fn = nn.GELU + + image_encoder = ImageEncoderViT( + img_size=img_size, + patch_size=encoder_patch_size, + in_chans=3, + patch_embed_dim=encoder_patch_embed_dim, + normalization_type=normalization_type, + depth=encoder_depth, + num_heads=encoder_num_heads, + mlp_ratio=encoder_mlp_ratio, + neck_dims=encoder_neck_dims, + act_layer=activation_fn, + ) + + image_embedding_size = image_encoder.image_embedding_size + encoder_transformer_output_dim = image_encoder.transformer_output_dim + + sam = EfficientSam( + 
image_encoder=image_encoder, + prompt_encoder=PromptEncoder( + embed_dim=encoder_transformer_output_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(img_size, img_size), + ), + decoder_max_num_input_points=decoder_max_num_input_points, + mask_decoder=MaskDecoder( + transformer_dim=encoder_transformer_output_dim, + transformer=TwoWayTransformer( + depth=decoder_transformer_depth, + embedding_dim=encoder_transformer_output_dim, + num_heads=decoder_num_heads, + mlp_dim=decoder_transformer_mlp_dim, + activation=activation_fn, + normalize_before_activation=normalize_before_activation, + ), + num_multimask_outputs=num_multimask_outputs, + activation=activation_fn, + normalization_type=normalization_type, + normalize_before_activation=normalize_before_activation, + iou_head_depth=iou_head_depth - 1, + iou_head_hidden_dim=iou_head_hidden_dim, + upscaling_layer_dims=decoder_upscaling_layer_dims, + ), + pixel_mean=[0.485, 0.456, 0.406], + pixel_std=[0.229, 0.224, 0.225], + ) + + if checkpoint is not None: + state_dict = torch.hub.load_state_dict_from_url( + checkpoint, map_location="cpu", progress=True, weights_only=True + ) + sam.load_state_dict(state_dict["model"]) + + return sam diff --git a/examples/models/efficient_sam/efficient_sam_core/efficient_sam_decoder.py b/examples/models/efficient_sam/efficient_sam_core/efficient_sam_decoder.py new file mode 100644 index 00000000000..61fbf28e276 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/efficient_sam_decoder.py @@ -0,0 +1,373 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the same directory. + +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/efficient_sam_decoder.py + +from typing import List, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn + +from .mlp import MLPBlock + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + self.invalid_points = nn.Embedding(1, embed_dim) + self.point_embeddings = nn.Embedding(1, embed_dim) + self.bbox_top_left_embeddings = nn.Embedding(1, embed_dim) + self.bbox_bottom_right_embeddings = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. 
+ + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + point_embedding = self.pe_layer.forward_with_coords( + points, self.input_image_size + ) + invalid_label_ids = torch.eq(labels, -1)[:, :, None] + point_label_ids = torch.eq(labels, 1)[:, :, None] + topleft_label_ids = torch.eq(labels, 2)[:, :, None] + bottomright_label_ids = torch.eq(labels, 3)[:, :, None] + point_embedding = ( + point_embedding + self.invalid_points.weight[:, None, :] * invalid_label_ids + ) + point_embedding = ( + point_embedding + self.point_embeddings.weight[:, None, :] * point_label_ids + ) + point_embedding = ( + point_embedding + + self.bbox_top_left_embeddings.weight[:, None, :] * topleft_label_ids + ) + point_embedding = ( + point_embedding + + self.bbox_bottom_right_embeddings.weight[:, None, :] + * bottomright_label_ids + ) + return point_embedding + + def forward( + self, + coords, + labels, + ) -> torch.Tensor: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points: A tensor of shape [B, 2] + labels: An integer tensor of shape [B] where each element is 1,2 or 3. + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + """ + return self._embed_points(coords, labels) + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int) -> None: + super().__init__() + self.register_buffer( + "positional_encoding_gaussian_matrix", torch.randn((2, num_pos_feats)) + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... 
x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + # TODO: Remove custom_cumsum implementation once issue #6201 is resolved + def custom_cumsum(self, tensor: torch.Tensor, dim: int) -> torch.Tensor: + """Custom cumulative sum.""" + tensor = tensor.transpose(dim, 0) + original_shape = tensor.shape + n = original_shape[0] + tensor = tensor.reshape(n, -1) + tril = torch.tril(torch.ones(n, n, device=tensor.device, dtype=tensor.dtype)) + tensor = tril @ tensor + tensor = tensor.view(original_shape) + tensor = tensor.transpose(0, dim) + return tensor + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device = self.positional_encoding_gaussian_matrix.device + grid = torch.ones([h, w], device=device, dtype=torch.float32) + # Modification: Use custom_cumsum as a workaround for issue #6201 + y_embed = self.custom_cumsum(grid, dim=0) - 0.5 + x_embed = self.custom_cumsum(grid, dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C + + +# TODO: Remove CustomGroupNorm implementation once issue #6817 is resolved +class CustomGroupNorm(nn.Module): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + """Custom Group Normalization.""" + super(CustomGroupNorm, self).__init__() + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + + if self.affine: + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + + def forward(self, x): + N, C, *rest = x.shape + G = self.num_groups + x = x.view(N, G, C // G, *rest) + shape = (2, *range(3, x.dim())) + mean = x.mean(dim=shape, keepdim=True) + var = ((x - mean) ** 2).mean(dim=shape, keepdim=True) + x = (x - mean) / (var + self.eps).sqrt() + x = x.view(N, C, *rest) + shape = (1, -1, *([1] * len(rest))) + if self.affine: + x = x * self.weight.view(shape) + self.bias.view(shape) + return x + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int, + activation: Type[nn.Module], + normalization_type: str, + normalize_before_activation: bool, + iou_head_depth: int, + iou_head_hidden_dim: int, + upscaling_layer_dims: List[int], + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. 
+ + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + if num_multimask_outputs > 1: + self.num_mask_tokens = num_multimask_outputs + 1 + else: + self.num_mask_tokens = 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + output_dim_after_upscaling = transformer_dim + + self.final_output_upscaling_layers = nn.ModuleList([]) + for idx, layer_dims in enumerate(upscaling_layer_dims): + self.final_output_upscaling_layers.append( + nn.Sequential( + nn.ConvTranspose2d( + output_dim_after_upscaling, + layer_dims, + kernel_size=2, + stride=2, + ), + ( + # Modification: Use CustomGroupNorm as a workaround for issue #6817 + CustomGroupNorm(1, layer_dims) + if idx < len(upscaling_layer_dims) - 1 + else nn.Identity() + ), + activation(), + ) + ) + output_dim_after_upscaling = layer_dims + + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLPBlock( + input_dim=transformer_dim, + hidden_dim=transformer_dim, + output_dim=output_dim_after_upscaling, + num_layers=2, + act=activation, + ) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLPBlock( + input_dim=transformer_dim, + hidden_dim=iou_head_hidden_dim, + output_dim=self.num_mask_tokens, + num_layers=iou_head_depth, + act=activation, + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings: A tensor of shape [B, C, H, W] or [B*max_num_queries, C, H, W] + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings (the batch dimension is broadcastable). + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + + ( + batch_size, + max_num_queries, + sparse_embed_dim_1, + sparse_embed_dim_2, + ) = sparse_prompt_embeddings.shape + + ( + _, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) = image_embeddings.shape + + # Tile the image embedding for all queries. 
+ image_embeddings_tiled = torch.tile( + image_embeddings[:, None, :, :, :], [1, max_num_queries, 1, 1, 1] + ).view( + batch_size * max_num_queries, + image_embed_dim_c, + image_embed_dim_h, + image_embed_dim_w, + ) + sparse_prompt_embeddings = sparse_prompt_embeddings.reshape( + batch_size * max_num_queries, sparse_embed_dim_1, sparse_embed_dim_2 + ) + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings_tiled, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + ) + if multimask_output and self.num_multimask_outputs > 1: + return masks[:, 1:, :], iou_pred[:, 1:] + else: + return masks[:, :1, :], iou_pred[:, :1] + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat( + [self.iou_token.weight, self.mask_tokens.weight], dim=0 + ) + output_tokens = output_tokens.unsqueeze(0).expand( + sparse_prompt_embeddings.size(0), -1, -1 + ) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + # Expand per-image data in batch direction to be per-mask + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = image_embeddings.shape + hs, src = self.transformer(image_embeddings, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + upscaled_embedding = src.transpose(1, 2).view(b, c, h, w) + + for upscaling_layer in self.final_output_upscaling_layers: + upscaled_embedding = upscaling_layer(upscaled_embedding) + hyper_in_list: List[torch.Tensor] = [] + for i, output_hypernetworks_mlp in enumerate(self.output_hypernetworks_mlps): + hyper_in_list.append(output_hypernetworks_mlp(mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + return masks, iou_pred diff --git a/examples/models/efficient_sam/efficient_sam_core/efficient_sam_encoder.py b/examples/models/efficient_sam/efficient_sam_core/efficient_sam_encoder.py new file mode 100644 index 00000000000..d6ea4f5cc09 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/efficient_sam_encoder.py @@ -0,0 +1,262 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the same directory. 
+ +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/efficient_sam_encoder.py + +import math +from typing import List, Type + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class PatchEmbed(nn.Module): + """2D Image to Patch Embedding""" + + def __init__( + self, + img_size, + patch_size, + in_chans, + embed_dim, + ): + super().__init__() + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + bias=True, + ) + + def forward(self, x): + B, C, H, W = x.shape + x = self.proj(x) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + num_heads, + qkv_bias, + qk_scale=None, + ): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.fc2(x) + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=False, + qk_scale=None, + act_layer=nn.GELU, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=1e-6) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + ) + self.norm2 = nn.LayerNorm(dim, eps=1e-6) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + + def forward(self, x): + x = x + self.attn(self.norm1(x)) + x = x + self.mlp(self.norm2(x)) + return x + + +@torch.jit.export +def get_abs_pos( + abs_pos: torch.Tensor, has_cls_token: bool, hw: List[int] +) -> torch.Tensor: + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. 
+ + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h = hw[0] + w = hw[1] + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + # Modification: Change memory format to contiguous + # 1. Makes it exportable to ExecuTorch + # 2. XNNPACK backend only supports contiguous memory format for inputs + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).contiguous(), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +# Image encoder for efficient SAM. +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int, + patch_size: int, + in_chans: int, + patch_embed_dim: int, + normalization_type: str, + depth: int, + num_heads: int, + mlp_ratio: float, + neck_dims: List[int], + act_layer: Type[nn.Module], + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + patch_embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + act_layer (nn.Module): Activation layer. + """ + super().__init__() + + self.img_size = img_size + self.image_embedding_size = img_size // ((patch_size if patch_size > 0 else 1)) + self.transformer_output_dim = ([patch_embed_dim] + neck_dims)[-1] + self.pretrain_use_cls_token = True + pretrain_img_size = 224 + self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, patch_embed_dim) + # Initialize absolute positional embedding with pretrain image size. 
+ num_patches = (pretrain_img_size // patch_size) * ( + pretrain_img_size // patch_size + ) + num_positions = num_patches + 1 + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, patch_embed_dim)) + self.blocks = nn.ModuleList() + for _ in range(depth): + vit_block = Block(patch_embed_dim, num_heads, mlp_ratio, True) + self.blocks.append(vit_block) + self.neck = nn.Sequential( + nn.Conv2d( + patch_embed_dim, + neck_dims[0], + kernel_size=1, + bias=False, + ), + LayerNorm2d(neck_dims[0]), + nn.Conv2d( + neck_dims[0], + neck_dims[0], + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(neck_dims[0]), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert ( + x.shape[2] == self.img_size and x.shape[3] == self.img_size + ), "input image size must match self.img_size" + x = self.patch_embed(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + x = x + get_abs_pos( + self.pos_embed, self.pretrain_use_cls_token, [x.shape[1], x.shape[2]] + ) + num_patches = x.shape[1] + assert x.shape[2] == num_patches + x = x.reshape(x.shape[0], num_patches * num_patches, x.shape[3]) + for blk in self.blocks: + x = blk(x) + x = x.reshape(x.shape[0], num_patches, num_patches, x.shape[2]) + x = self.neck(x.permute(0, 3, 1, 2)) + return x diff --git a/examples/models/efficient_sam/efficient_sam_core/mlp.py b/examples/models/efficient_sam/efficient_sam_core/mlp.py new file mode 100644 index 00000000000..cd4496b4a93 --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/mlp.py @@ -0,0 +1,31 @@ +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/mlp.py + +from typing import Type + +from torch import nn + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLPBlock(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + act: Type[nn.Module], + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Sequential(nn.Linear(n, k), act()) + for n, k in zip([input_dim] + h, [hidden_dim] * num_layers) + ) + self.fc = nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + for layer in self.layers: + x = layer(x) + return self.fc(x) diff --git a/examples/models/efficient_sam/efficient_sam_core/two_way_transformer.py b/examples/models/efficient_sam/efficient_sam_core/two_way_transformer.py new file mode 100644 index 00000000000..c9073d98e8f --- /dev/null +++ b/examples/models/efficient_sam/efficient_sam_core/two_way_transformer.py @@ -0,0 +1,268 @@ +# Source: https://github.com/yformer/EfficientSAM/blob/main/efficient_sam/two_way_transformer.py + +import math +from typing import Tuple, Type + +import torch +from torch import nn, Tensor + +from .mlp import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. 
Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + curr_layer = TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + normalize_before_activation=normalize_before_activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + self.layers.append(curr_layer) + + self.final_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for _, layer in enumerate(self.layers): + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module], + normalize_before_activation: bool, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. 
+ + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = AttentionForTwoWayAttentionBlock(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock( + embedding_dim, + mlp_dim, + embedding_dim, + 1, + activation, + ) + + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = AttentionForTwoWayAttentionBlock( + embedding_dim, + num_heads, + downsample_rate=attention_downsample_rate, + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if not self.skip_first_layer_pe: + queries = queries + query_pe + attn_out = self.self_attn(q=queries, k=queries, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class AttentionForTwoWayAttentionBlock(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert ( + self.internal_dim % num_heads == 0 + ), "num_heads must divide embedding_dim." 
+ self.c_per_head = self.internal_dim / num_heads + self.inv_sqrt_c_per_head = 1.0 / math.sqrt(self.c_per_head) + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + self._reset_parameters() + + def _reset_parameters(self) -> None: + # The fan_out is incorrect, but matches pytorch's initialization + # for which qkv is a single 3*embedding_dim x embedding_dim matrix + fan_in = self.embedding_dim + fan_out = 3 * self.internal_dim + # Xavier uniform with our custom fan_out + bnd = math.sqrt(6 / (fan_in + fan_out)) + nn.init.uniform_(self.q_proj.weight, -bnd, bnd) + nn.init.uniform_(self.k_proj.weight, -bnd, bnd) + nn.init.uniform_(self.v_proj.weight, -bnd, bnd) + # out_proj.weight is left with default initialization, like pytorch attention + nn.init.zeros_(self.q_proj.bias) + nn.init.zeros_(self.k_proj.bias) + nn.init.zeros_(self.v_proj.bias) + nn.init.zeros_(self.out_proj.bias) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn * self.inv_sqrt_c_per_head + attn = torch.softmax(attn, dim=-1) + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + return out diff --git a/examples/models/efficient_sam/model.py b/examples/models/efficient_sam/model.py new file mode 100644 index 00000000000..d6b2b2ff806 --- /dev/null +++ b/examples/models/efficient_sam/model.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch + +from ..model_base import EagerModelBase + +from .efficient_sam_core.build_efficient_sam import build_efficient_sam_vitt + + +class EfficientSAM(EagerModelBase): + def __init__(self): + pass + + def get_eager_model(self) -> torch.nn.Module: + logging.info("Loading EfficientSAM model") + efficient_sam = build_efficient_sam_vitt() + logging.info("Loaded EfficientSAM model") + return efficient_sam + + def get_example_inputs(self): + B, H, W = 1, 1024, 1024 + num_queries, num_pts = 1, 1 + + batched_images = torch.randn((B, 3, H, W)) + batched_points = torch.rand((B, num_queries, num_pts, 2)) * torch.tensor([H, W]) + batched_point_labels = torch.ones((B, num_queries, num_pts)) + + return (batched_images, batched_points, batched_point_labels) diff --git a/examples/models/test/test_export.py b/examples/models/test/test_export.py index 6a7c793029c..9a4ff7a35ed 100644 --- a/examples/models/test/test_export.py +++ b/examples/models/test/test_export.py @@ -148,3 +148,14 @@ def test_dl3_export_to_executorch(self): eager_model, example_inputs ) self.validate_tensor_allclose(list(eager_output.values()), executorch_output) + + def test_efficient_sam_export_to_executorch(self): + eager_model, example_inputs, _, _ = EagerModelFactory.create_model( + *MODEL_NAME_TO_MODEL["efficient_sam"] + ) + eager_output, executorch_output = self.collect_executorch_and_eager_outputs( + eager_model, example_inputs + ) + self.validate_tensor_allclose( + list(eager_output), executorch_output, rtol=1e-2, atol=1e-2 + )
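For reference, and not part of the diff itself, below is a minimal Python sketch of how the exported program could be sanity-checked against the eager model, along the lines of the `test_efficient_sam_export_to_executorch` test added above. It assumes the ExecuTorch pybindings were built with the XNNPACK backend enabled, that the script is run from the `executorch` repository root, and that the XNNPACK `aot_compiler` wrote `efficient_sam_xnnpack_fp32.pte` — the actual output filename and path may differ.

```python
# Illustrative sketch only: compare the lowered ExecuTorch program with the eager model.
# Assumptions: ExecuTorch Python bindings built with the XNNPACK backend, script run from
# the executorch repo root, and a .pte file named "efficient_sam_xnnpack_fp32.pte"
# (hypothetical path) produced by examples.xnnpack.aot_compiler.
import torch

from examples.models.efficient_sam.efficient_sam_core.build_efficient_sam import (
    build_efficient_sam_vitt,
)
from executorch.extension.pybindings.portable_lib import _load_for_executorch

# Example inputs matching EfficientSAM.get_example_inputs() in this diff:
# one 1024x1024 RGB image, one query with a single prompt point.
B, H, W = 1, 1024, 1024
batched_images = torch.randn(B, 3, H, W)
batched_points = torch.rand(B, 1, 1, 2) * torch.tensor([H, W], dtype=torch.float32)
batched_point_labels = torch.ones(B, 1, 1)
inputs = (batched_images, batched_points, batched_point_labels)

# Eager reference (downloads the ViT-T checkpoint on first use).
eager_model = build_efficient_sam_vitt()
with torch.no_grad():
    eager_masks, eager_iou = eager_model(*inputs)

# Lowered program executed through the ExecuTorch runtime.
et_module = _load_for_executorch("efficient_sam_xnnpack_fp32.pte")
et_masks, et_iou = et_module.forward(inputs)

print("max |mask delta|:", (eager_masks - et_masks).abs().max().item())
print("max |iou delta| :", (eager_iou - et_iou).abs().max().item())
```

As in the unit test, a loose tolerance (on the order of 1e-2) is reasonable here, since delegation and float32 kernel differences can introduce small numerical deviations from the eager outputs.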