
Commit 0829b88

Add files via upload
1 parent e34653b commit 0829b88

17 files changed (+2290, -0 lines)

sam3/perflib/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import os

is_enabled = False
if os.getenv("USE_PERFLIB", "1") == "1":
    # print("Enabled the use of perflib.\n", end="")
    is_enabled = True
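
This flag gives downstream modules a single switch for the optimized code paths, controlled by the USE_PERFLIB environment variable. A minimal sketch of the dispatch pattern this enables; dispatch_mask_iou and naive_mask_iou are illustrative helpers, not part of this commit:

import torch

import sam3.perflib as perflib


def naive_mask_iou(a, b):
    # O(N*M) pairwise IoU over boolean masks; illustrative slow path only
    a = a.flatten(1).float()  # (N, H*W)
    b = b.flatten(1).float()  # (M, H*W)
    inter = a @ b.T
    union = a.sum(1)[:, None] + b.sum(1)[None, :] - inter
    return inter / union.clamp(min=1)


def dispatch_mask_iou(a, b):
    # Take the optimized kernel only when perflib is enabled (USE_PERFLIB=1)
    if perflib.is_enabled:
        from sam3.perflib.masks_ops import mask_iou

        return mask_iou(a, b)
    return naive_mask_iou(a, b)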

sam3/perflib/associate_det_trk.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from collections import defaultdict

import torch
import torch.nn.functional as F
from sam3.perflib.masks_ops import mask_iou
from scipy.optimize import linear_sum_assignment


def associate_det_trk(
    det_masks,
    track_masks,
    iou_threshold=0.5,
    iou_threshold_trk=0.5,
    det_scores=None,
    new_det_thresh=0.0,
):
    """
    Optimized implementation of detection <-> track association that minimizes DtoH syncs.

    Args:
        det_masks: (N, H, W) tensor of predicted masks
        track_masks: (M, H, W) tensor of track masks
        iou_threshold: min IoU for a detection to count as matching a track
        iou_threshold_trk: min IoU for a Hungarian match to mark a track as matched
        det_scores: optional (N,) tensor of detection scores
        new_det_thresh: min score for an unmatched detection to be considered 'new'

    Returns:
        new_det_indices: list of indices in det_masks considered 'new'
        unmatched_trk_indices: list of indices in track_masks considered 'unmatched'
        det_to_matched_trk: dict mapping each detection index to the track indices it matched
        matched_det_scores: dict mapping each matched track index to [det_score, det_score * iou]
    """
    with torch.autograd.profiler.record_function("perflib: associate_det_trk"):
        assert isinstance(det_masks, torch.Tensor), "det_masks should be a tensor"
        assert isinstance(track_masks, torch.Tensor), "track_masks should be a tensor"
        if det_masks.size(0) == 0 or track_masks.size(0) == 0:
            return list(range(det_masks.size(0))), [], {}, {}  # all detections are new

        if list(det_masks.shape[-2:]) != list(track_masks.shape[-2:]):
            # resize to the smaller spatial size to save GPU memory
            if det_masks.shape[-2:].numel() < track_masks.shape[-2:].numel():
                track_masks = (
                    F.interpolate(
                        track_masks.unsqueeze(1).float(),
                        size=det_masks.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
                    > 0
                )
            else:
                # resize detections to track size
                det_masks = (
                    F.interpolate(
                        det_masks.unsqueeze(1).float(),
                        size=track_masks.shape[-2:],
                        mode="bilinear",
                        align_corners=False,
                    ).squeeze(1)
                    > 0
                )

        det_masks = det_masks > 0
        track_masks = track_masks > 0

        iou = mask_iou(det_masks, track_masks)  # (N, M)
        igeit = iou >= iou_threshold
        igeit_any_dim_1 = igeit.any(dim=1)
        igeit_trk = iou >= iou_threshold_trk

        # One batched DtoH transfer for everything the CPU-side loops need
        iou_list = iou.cpu().numpy().tolist()
        igeit_list = igeit.cpu().numpy().tolist()
        igeit_any_dim_1_list = igeit_any_dim_1.cpu().numpy().tolist()
        igeit_trk_list = igeit_trk.cpu().numpy().tolist()

        det_scores_list = (
            det_scores
            if det_scores is None
            else det_scores.cpu().float().numpy().tolist()
        )

        # Hungarian matching for tracks (one-to-one: each track matches at most one detection)
        # For detections: allow many tracks to match to the same detection (many-to-one)

        # Hungarian matching: maximize IoU for tracks
        cost_matrix = 1 - iou.cpu().numpy()  # Hungarian solves for minimum cost
        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        def branchy_hungarian_better_uses_the_cpu(
            cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
        ):
            matched_trk = set()
            matched_det = set()
            # track index -> [det_score, det_score * iou] of the matched detection mask
            matched_det_scores = {}
            for d, t in zip(row_ind, col_ind):
                if det_scores_list is not None:
                    matched_det_scores[t] = [
                        det_scores_list[d],
                        det_scores_list[d] * iou_list[d][t],
                    ]
                if igeit_trk_list[d][t]:
                    matched_trk.add(t)
                    matched_det.add(d)

            # Tracks not matched by Hungarian assignment above threshold are unmatched
            unmatched_trk_indices = [
                t for t in range(track_masks.size(0)) if t not in matched_trk
            ]

            # For detections: allow many tracks to match to the same detection (many-to-one)
            # So, a detection is 'new' if it does not match any track above threshold
            assert track_masks.size(0) == igeit.size(
                1
            )  # Needed for loop optimization below
            new_det_indices = []
            for d in range(det_masks.size(0)):
                if not igeit_any_dim_1_list[d]:
                    if det_scores_list is not None and det_scores_list[d] >= new_det_thresh:
                        new_det_indices.append(d)

            # for each detection, which tracks it matched to (above threshold)
            det_to_matched_trk = defaultdict(list)
            for d in range(det_masks.size(0)):
                for t in range(track_masks.size(0)):
                    if igeit_list[d][t]:
                        det_to_matched_trk[d].append(t)

            return (
                new_det_indices,
                unmatched_trk_indices,
                det_to_matched_trk,
                matched_det_scores,
            )

        return branchy_hungarian_better_uses_the_cpu(
            cost_matrix, row_ind, col_ind, iou_list, det_masks, track_masks
        )
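
A quick usage sketch with random boolean masks; shapes, scores, and the threshold are arbitrary and only meant to show the four-tuple return:

import torch

from sam3.perflib.associate_det_trk import associate_det_trk

det_masks = torch.rand(4, 64, 64) > 0.5    # N=4 candidate detections
track_masks = torch.rand(3, 64, 64) > 0.5  # M=3 existing tracks
det_scores = torch.rand(4)

new_dets, unmatched_trks, det_to_trk, matched_scores = associate_det_trk(
    det_masks, track_masks, det_scores=det_scores, new_det_thresh=0.1
)
print("new detections:", new_dets)          # detections with no track above IoU threshold
print("unmatched tracks:", unmatched_trks)  # tracks the Hungarian step left unmatched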

sam3/perflib/compile.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch


def recursive_fn_factory(fn):
    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Yes, writing out an explicit whitelist of trivial types is tedious,
        # but so are the bugs that come from silently not applying fn where it
        # was expected to be applied.
        if b is None:
            return b
        trivial_types = [bool, int]
        for t in trivial_types:
            if isinstance(b, t):
                return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn


recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)


def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    def compiled_fn_wrapper(*args, **kwargs):
        with torch.autograd.profiler.record_function(
            f"compiled {fn}" if name is None else name
        ):
            cont_args = recursive_contiguous(args)
            cont_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*cont_args, **cont_kwargs)
            cloned_result = recursive_clone(result)
            return cloned_result

    return compiled_fn_wrapper


def shape_logging_wrapper(fn, keep_kwargs, enable_logging=False):
    """
    Wraps a function and prints the shapes of all tensor inputs.
    Only prints when a new combination of shapes is seen.
    Thread-safe.

    Args:
        fn: Function to wrap
        keep_kwargs: Names of keyword arguments whose shapes should be tracked
        enable_logging: Boolean flag to enable/disable logging
    """
    seen_shapes = set()

    def get_shape(obj):
        if isinstance(obj, torch.Tensor):
            return obj.shape
        elif isinstance(obj, (list, tuple)):
            if len(obj) == 1:
                return get_shape(obj[0])
            return tuple(get_shape(x) for x in obj)
        elif isinstance(obj, dict):
            return tuple(sorted((k, get_shape(v)) for k, v in obj.items()))
        else:
            return type(obj).__name__

    def wrapper(*args, **kwargs):
        shapes = tuple(get_shape(arg) for arg in args) + tuple(
            (k, get_shape(v))
            for k, v in kwargs.items()
            if isinstance(v, (torch.Tensor, list)) and k in keep_kwargs
        )
        if shapes not in seen_shapes:
            seen_shapes.add(shapes)
            if enable_logging:
                print(f"[ShapeLogger] New input shapes for {fn.__qualname__}: {shapes}")
        return fn(*args, **kwargs)

    # Allow toggling the flag at runtime
    wrapper.enable_logging = enable_logging

    def set_logging(enabled=False):
        nonlocal enable_logging
        enable_logging = enabled
        wrapper.enable_logging = enable_logging

    wrapper.set_logging = set_logging
    return wrapper
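
A sketch of how the two wrappers might be combined, assuming a CUDA device; fused_bias_gelu is an illustrative stand-in for a real workload. Logging is kept outside the compiled region so the print does not break fullgraph compilation:

import torch

from sam3.perflib.compile import compile_wrapper, shape_logging_wrapper


def fused_bias_gelu(x, bias):
    return torch.nn.functional.gelu(x + bias)


fast = compile_wrapper(fused_bias_gelu, name="fused_bias_gelu")
logged = shape_logging_wrapper(fast, keep_kwargs=(), enable_logging=True)

x = torch.randn(8, 256, device="cuda")
bias = torch.randn(256, device="cuda")
out = logged(x, bias)      # first call: logs the new shape combo, triggers compilation
out = logged(x, bias)      # same shapes again: silent, reuses the compiled graph
logged.set_logging(False)  # logging can be toggled at runtime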
sam3/perflib/connected_components.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import logging

import torch

try:
    from cc_torch import get_connected_components

    HAS_CC_TORCH = True
except ImportError:
    logging.debug(
        "cc_torch not found. Consider installing it for better performance:"
        " pip install git+https://github.com/ronghanghu/cc_torch.git"
    )
    HAS_CC_TORCH = False


def connected_components_cpu_single(values: torch.Tensor):
    assert values.dim() == 2
    from skimage.measure import label

    labels, num = label(values.cpu().numpy(), return_num=True)
    labels = torch.from_numpy(labels)
    counts = torch.zeros_like(labels)
    for i in range(1, num + 1):
        cur_mask = labels == i
        cur_count = cur_mask.sum()
        counts[cur_mask] = cur_count
    return labels, counts


def connected_components_cpu(input_tensor: torch.Tensor):
    out_shape = input_tensor.shape
    if input_tensor.dim() == 4 and input_tensor.shape[1] == 1:
        input_tensor = input_tensor.squeeze(1)
    else:
        assert (
            input_tensor.dim() == 3
        ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    batch_size = input_tensor.shape[0]
    labels_list = []
    counts_list = []
    for b in range(batch_size):
        labels, counts = connected_components_cpu_single(input_tensor[b])
        labels_list.append(labels)
        counts_list.append(counts)
    labels_tensor = torch.stack(labels_list, dim=0).to(input_tensor.device)
    counts_tensor = torch.stack(counts_list, dim=0).to(input_tensor.device)
    return labels_tensor.view(out_shape), counts_tensor.view(out_shape)


def connected_components(input_tensor: torch.Tensor):
    """
    Computes connected-components labeling on a batch of 2D tensors, using the best available backend.

    Args:
        input_tensor (torch.Tensor): A (B, H, W) or (B, 1, H, W) integer tensor. Non-zero values
            are considered foreground. Bool tensors are also accepted.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Both tensors have the same shape as input_tensor.
            - A tensor with dense labels. Background is 0.
            - A tensor with the size of the connected component for each pixel.
    """
    if input_tensor.dim() == 3:
        input_tensor = input_tensor.unsqueeze(1)

    assert (
        input_tensor.dim() == 4 and input_tensor.shape[1] == 1
    ), "Input tensor must be (B, H, W) or (B, 1, H, W)."

    if input_tensor.is_cuda:
        if HAS_CC_TORCH:
            return get_connected_components(input_tensor.to(torch.uint8))
        else:
            # triton fallback
            from sam3.perflib.triton.connected_components import (
                connected_components_triton,
            )

            return connected_components_triton(input_tensor)

    # CPU fallback
    return connected_components_cpu(input_tensor)
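
Example call on CPU tensors; the import path assumes this file lands at sam3/perflib/connected_components.py as inferred above:

import torch

from sam3.perflib.connected_components import connected_components

# Two small binary masks, (B, 1, H, W); non-zero pixels are foreground
masks = torch.zeros(2, 1, 8, 8, dtype=torch.uint8)
masks[0, 0, 1:3, 1:3] = 1  # a 2x2 blob -> component of size 4
masks[0, 0, 5:7, 5:8] = 1  # a separate 2x3 blob -> component of size 6
masks[1, 0, 0, :] = 1      # a single-row component of size 8

labels, counts = connected_components(masks)
print(labels[0, 0])  # dense component ids; background stays 0
print(counts[0, 0])  # per-pixel size of the containing component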

sam3/perflib/fa3.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

import torch


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func_op(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
) -> torch.Tensor:
    from flash_attn_interface import flash_attn_func as fa3

    # fa3 returns (output, softmax_lse); this op exposes only the attention output
    return fa3(q, k, v)[0]


def flash_attn_func(q, k, v):
    dtype = torch.float8_e4m3fn
    return flash_attn_func_op(q.to(dtype), k.to(dtype), v.to(dtype)).to(q.dtype)


@flash_attn_func_op.register_fake
def _(q, k, v, **kwargs):
    # The underlying kernel has two outputs:
    # 1. output: (batch, seq_len, num_heads, head_dim)
    # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
    # Only the output is returned here, and it needs to be bfloat16, not float8!
    meta_q = torch.empty_like(q, dtype=torch.bfloat16).contiguous()
    return meta_q
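
A sketch of calling the FP8 wrapper; it assumes a Hopper-class GPU with the flash-attn 3 flash_attn_interface installed, and uses the (batch, seq_len, num_heads, head_dim) layout noted in the fake registration:

import torch

from sam3.perflib.fa3 import flash_attn_func

q = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16)

# Inputs are cast to float8_e4m3fn inside flash_attn_func; the output is cast back to q.dtype
out = flash_attn_func(q, k, v)
print(out.shape, out.dtype)  # torch.Size([2, 1024, 8, 128]) torch.bfloat16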
