
Commit e34653b

Add files via upload
1 parent 5e198d7 commit e34653b

32 files changed: +15639 additions, 0 deletions
1.29 MB binary file not shown.

sam3/model/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

sam3/model/act_ckpt_utils.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved


import inspect
from functools import wraps
from typing import Callable, TypeVar, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torch.utils._pytree import tree_map_only

# Type variables for better type hinting
T = TypeVar("T")
Module = TypeVar("Module", bound=nn.Module)


def activation_ckpt_wrapper(module: Union[nn.Module, Callable]) -> Callable:
    """
    Wraps a given module to enable or disable activation checkpointing.

    Activation checkpointing (gradient checkpointing) trades compute for memory by
    recomputing intermediate activations during the backward pass instead of storing
    them in memory during the forward pass.

    When activation checkpointing is enabled, the wrapper expects only keyword
    arguments, and it maps these to positional arguments based on the module's
    signature.

    Args:
        module: The module or function to wrap with activation checkpointing

    Returns:
        A wrapped callable that supports activation checkpointing

    Usage:
        The returned wrapper function can be called with the same arguments as the
        original module, with an additional `act_ckpt_enable` keyword argument to
        control activation checkpointing and an optional `use_reentrant` parameter.

    Example:
        ```python
        wrapped_module = activation_ckpt_wrapper(my_module)
        output = wrapped_module(x=input_tensor, y=another_tensor, act_ckpt_enable=True)
        ```
    """

    @wraps(module)
    def act_ckpt_wrapper(
        *args, act_ckpt_enable: bool = True, use_reentrant: bool = False, **kwargs
    ):
        if act_ckpt_enable:
            if len(args) > 0:
                raise ValueError(
                    "This wrapper expects keyword arguments only when `act_ckpt_enable=True`"
                )
            # Get the signature of the target function/module
            callable_fn = module.forward if isinstance(module, nn.Module) else module
            sig = inspect.signature(callable_fn)
            # Create a mapping of parameter names to their default values
            param_defaults = {
                name: param.default for name, param in sig.parameters.items()
            }
            args = []
            for p_name in param_defaults.keys():
                if p_name in kwargs:
                    args.append(kwargs.pop(p_name))
                elif param_defaults[p_name] is not inspect.Parameter.empty:
                    # Set the arg to its default value if it's not in kwargs. Useful
                    # for primitive types or args that default to None.
                    args.append(param_defaults[p_name])
                elif (
                    sig.parameters[p_name].kind is not inspect.Parameter.VAR_KEYWORD
                ):  # Skip **kwargs parameter
                    raise ValueError(f"Missing positional argument: {p_name}")

            # Scan remaining kwargs for torch.Tensor
            remaining_keys = list(kwargs.keys())
            for key in remaining_keys:
                if isinstance(kwargs[key], torch.Tensor):
                    # Remove the tensor from kwargs, assuming it's not required by the
                    # module. If it is required, the module's signature should be
                    # modified to accept it as a positional or keyword argument.
                    kwargs[key] = "_REMOVED_BY_ACT_CKPT_WRAPPER_"

            ret = checkpoint.checkpoint(
                module, *args, use_reentrant=use_reentrant, **kwargs
            )
        else:
            ret = module(*args, **kwargs)

        return ret

    return act_ckpt_wrapper


def clone_output_wrapper(f: Callable[..., T]) -> Callable[..., T]:
    """
    Clone the CUDA output tensors of a function to avoid in-place operations.

    This wrapper is useful when working with torch.compile to prevent errors
    related to in-place operations on tensors.

    Args:
        f: The function whose CUDA tensor outputs should be cloned

    Returns:
        A wrapped function that clones any CUDA tensor outputs
    """

    @wraps(f)
    def wrapped(*args, **kwargs):
        outputs = f(*args, **kwargs)
        return tree_map_only(
            torch.Tensor, lambda t: t.clone() if t.is_cuda else t, outputs
        )

    return wrapped
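
As a usage note (not part of the commit), the sketch below shows how the two wrappers above would be called. The toy `MyBlock` module, its parameters, and the tensor shapes are illustrative assumptions, not code from this repository.

```python
# Hedged usage sketch; `MyBlock` and the shapes below are made-up examples.
import torch
import torch.nn as nn

from sam3.model.act_ckpt_utils import activation_ckpt_wrapper, clone_output_wrapper


class MyBlock(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, scale: float = 1.0):
        return self.proj(x) * scale


wrapped = activation_ckpt_wrapper(MyBlock())
x = torch.randn(2, 4, 8, requires_grad=True)

# With checkpointing on, only keyword arguments are accepted; the wrapper
# reorders them into positional args from `MyBlock.forward`'s signature,
# filling `scale` from its default since it is omitted here.
out = wrapped(x=x, act_ckpt_enable=True)
out.sum().backward()  # activations are recomputed during this backward pass

# With checkpointing off, the call is forwarded unchanged.
out = wrapped(x, 0.5, act_ckpt_enable=False)

# clone_output_wrapper clones CUDA outputs only; on CPU it is a pass-through.
safe_block = clone_output_wrapper(MyBlock())
y = safe_block(torch.randn(2, 8))
```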

sam3/model/box_ops.py

Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""

from typing import Tuple

import torch


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_cxcywh_to_xywh(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (w), (h)]
    return torch.stack(b, dim=-1)


def box_xywh_to_xyxy(x):
    x, y, w, h = x.unbind(-1)
    b = [(x), (y), (x + w), (y + h)]
    return torch.stack(b, dim=-1)


def box_xywh_to_cxcywh(x):
    x, y, w, h = x.unbind(-1)
    b = [(x + 0.5 * w), (y + 0.5 * h), (w), (h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_xywh(x):
    x, y, X, Y = x.unbind(-1)
    b = [(x), (y), (X - x), (Y - y)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


def box_area(boxes):
    """
    Batched version of box area. Boxes should be in [x0, y0, x1, y1] format.

    Inputs:
    - boxes: Tensor of shape (..., 4)

    Returns:
    - areas: Tensor of shape (...,)
    """
    x0, y0, x1, y1 = boxes.unbind(-1)
    return (x1 - x0) * (y1 - y0)


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks

    The masks should be in format [N, H, W] where N is the number of masks and
    (H, W) are the spatial dimensions.

    Returns an [N, 4] tensor, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float, device=masks.device)
    x = torch.arange(0, w, dtype=torch.float, device=masks.device)
    y, x = torch.meshgrid(y, x)

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0] + 1
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0] + 1
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    boxes = torch.stack([x_min, y_min, x_max, y_max], 1)
    # Invalidate boxes corresponding to empty masks.
    boxes = boxes * masks.flatten(-2).any(-1)
    return boxes


def box_iou(boxes1, boxes2):
    """
    Batched version of box_iou. Boxes should be in [x0, y0, x1, y1] format.

    Inputs:
    - boxes1: Tensor of shape (..., N, 4)
    - boxes2: Tensor of shape (..., M, 4)

    Returns:
    - iou, union: Tensors of shape (..., N, M)
    """
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # boxes1: (..., N, 4) -> (..., N, 1, 2)
    # boxes2: (..., M, 4) -> (..., 1, M, 2)
    lt = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
    rb = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])

    wh = (rb - lt).clamp(min=0)  # (..., N, M, 2)
    inter = wh[..., 0] * wh[..., 1]  # (..., N, M)

    union = area1[..., None] + area2[..., None, :] - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Batched version of Generalized IoU from https://giou.stanford.edu/

    Boxes should be in [x0, y0, x1, y1] format

    Inputs:
    - boxes1: Tensor of shape (..., N, 4)
    - boxes2: Tensor of shape (..., M, 4)

    Returns:
    - giou: Tensor of shape (..., N, M)
    """
    iou, union = box_iou(boxes1, boxes2)

    # boxes1: (..., N, 4) -> (..., N, 1, 2)
    # boxes2: (..., M, 4) -> (..., 1, M, 2)
    lt = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
    rb = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])

    wh = (rb - lt).clamp(min=0)  # (..., N, M, 2)
    area = wh[..., 0] * wh[..., 1]  # (..., N, M)

    return iou - (area - union) / area


@torch.jit.script
def fast_diag_generalized_box_iou(boxes1, boxes2):
    assert len(boxes1) == len(boxes2)
    box1_xy = boxes1[:, 2:]
    box1_XY = boxes1[:, :2]
    box2_xy = boxes2[:, 2:]
    box2_XY = boxes2[:, :2]
    # assert (box1_xy >= box1_XY).all()
    # assert (box2_xy >= box2_XY).all()
    area1 = (box1_xy - box1_XY).prod(-1)
    area2 = (box2_xy - box2_XY).prod(-1)

    lt = torch.max(box1_XY, box2_XY)  # [N,2]
    lt2 = torch.min(box1_XY, box2_XY)
    rb = torch.min(box1_xy, box2_xy)  # [N,2]
    rb2 = torch.max(box1_xy, box2_xy)

    inter = (rb - lt).clamp(min=0).prod(-1)
    tot_area = (rb2 - lt2).clamp(min=0).prod(-1)

    union = area1 + area2 - inter

    iou = inter / union

    return iou - (tot_area - union) / tot_area


@torch.jit.script
def fast_diag_box_iou(boxes1, boxes2):
    assert len(boxes1) == len(boxes2)
    box1_xy = boxes1[:, 2:]
    box1_XY = boxes1[:, :2]
    box2_xy = boxes2[:, 2:]
    box2_XY = boxes2[:, :2]
    # assert (box1_xy >= box1_XY).all()
    # assert (box2_xy >= box2_XY).all()
    area1 = (box1_xy - box1_XY).prod(-1)
    area2 = (box2_xy - box2_XY).prod(-1)

    lt = torch.max(box1_XY, box2_XY)  # [N,2]
    rb = torch.min(box1_xy, box2_xy)  # [N,2]

    inter = (rb - lt).clamp(min=0).prod(-1)

    union = area1 + area2 - inter

    iou = inter / union

    return iou


def box_xywh_inter_union(
    boxes1: torch.Tensor, boxes2: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Assumes boxes in xywh format
    assert boxes1.size(-1) == 4 and boxes2.size(-1) == 4
    boxes1 = box_xywh_to_xyxy(boxes1)
    boxes2 = box_xywh_to_xyxy(boxes2)
    box1_tl_xy = boxes1[..., :2]
    box1_br_xy = boxes1[..., 2:]
    box2_tl_xy = boxes2[..., :2]
    box2_br_xy = boxes2[..., 2:]
    area1 = (box1_br_xy - box1_tl_xy).prod(-1)
    area2 = (box2_br_xy - box2_tl_xy).prod(-1)

    assert (area1 >= 0).all() and (area2 >= 0).all()
    tl = torch.max(box1_tl_xy, box2_tl_xy)
    br = torch.min(box1_br_xy, box2_br_xy)

    inter = (br - tl).clamp(min=0).prod(-1)
    union = area1 + area2 - inter

    return inter, union
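
As a sanity check (not part of the commit), the sketch below exercises the format conversions, the pairwise and paired GIoU variants, and `masks_to_boxes` from this file. All boxes and the mask are synthetic values chosen for illustration.

```python
# Hedged sketch; the boxes and mask below are synthetic examples.
import torch

from sam3.model.box_ops import (
    box_cxcywh_to_xyxy,
    box_xyxy_to_cxcywh,
    fast_diag_generalized_box_iou,
    generalized_box_iou,
    masks_to_boxes,
)

# Round-trip between center (cx, cy, w, h) and corner (x0, y0, x1, y1) formats.
cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.2]])
xyxy = box_cxcywh_to_xyxy(cxcywh)  # [[0.3, 0.4, 0.7, 0.6]]
assert torch.allclose(box_xyxy_to_cxcywh(xyxy), cxcywh)

# Pairwise GIoU: (N, 4) x (M, 4) -> (N, M), values in (-1, 1].
boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]])
boxes2 = torch.tensor([[0.0, 0.0, 2.0, 2.0], [4.0, 4.0, 5.0, 5.0]])
giou = generalized_box_iou(boxes1, boxes2)  # giou[0, 0] == 1.0 (identical boxes)

# Paired ("diagonal") GIoU: matches boxes1[i] with boxes2[i] -> shape (N,).
pair_giou = fast_diag_generalized_box_iou(boxes1, boxes2)

# Tight xyxy boxes around binary masks: (N, H, W) -> (N, 4).
mask = torch.zeros(1, 8, 8)
mask[0, 2:5, 3:7] = 1.0
print(masks_to_boxes(mask))  # tensor([[3., 2., 7., 5.]])
```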
