Skip to content

Commit e534520

Browse files
authored
[SAM] make saving to TF saved_model work (#81)
Ensure we only use TF operations in the forward operations for the models to enable saving and loading models.
1 parent 52b64d7 commit e534520

File tree

3 files changed

+37
-16
lines changed

3 files changed

+37
-16
lines changed

tests/models/test_segment_anything.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import tempfile
12
from typing import Tuple, cast
23

34
import numpy as np
@@ -43,7 +44,7 @@
4344
TwoWayTransformer as PTTwoWayTransformer,
4445
)
4546
from tfimm.architectures.segment_anything.transformer import (
46-
Attention as TFAttention,
47+
DownsampleAttention as TFAttention,
4748
TwoWayAttentionBlock as TFTwoWayAttentionBlock,
4849
TwoWayTransformer as TFTwoWayTransformer,
4950
)
@@ -468,3 +469,23 @@ def test_predictor(fixed_input_size):
468469
masks, scores, logits = predictor(points=[[10, 10]], multimask_output=False)
469470

470471
assert masks.shape == (1, *img.shape[:2])
472+
473+
474+
# This test takes longer, because the model is quite complex.
475+
@pytest.mark.timeout(120)
476+
def test_save_load_model():
477+
"""Tests ability to use keras save() and load() functions."""
478+
model = create_model("sam_vit_test_model")
479+
with tempfile.TemporaryDirectory() as tmpdir:
480+
model.save(tmpdir)
481+
loaded_model = tf.keras.models.load_model(tmpdir, compile=False)
482+
483+
assert type(model) is type(loaded_model)
484+
485+
inputs = model.dummy_inputs
486+
m_1, s_1, l_1 = model(inputs)
487+
m_2, s_2, l_2 = loaded_model(inputs)
488+
489+
assert np.sum(m_1.numpy() != m_2.numpy()) == 0
490+
assert (np.max(np.abs(s_1.numpy() - s_2.numpy()))) < 1e-6
491+
assert (np.max(np.abs(l_1.numpy() - l_2.numpy()))) < 1e-6

tfimm/architectures/segment_anything/image_encoder.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def window_unpartition(
5858
x: Unpartitioned tensor of shape (B, H, W, C).
5959
"""
6060
hp, wp = pad_hw
61-
h, w = hw
61+
h, w = hw[0], hw[1]
6262
window_size = tf.shape(windows)[1]
6363
nb_windows = (hp // window_size) * (wp // window_size)
6464
n = tf.shape(windows)[0] // nb_windows
@@ -93,7 +93,7 @@ def get_rel_pos(
9393
Extracted positional embeddings according to relative positions.
9494
"""
9595
m = tf.shape(rel_pos)[0]
96-
max_rel_dist = int(2 * max(q_size, k_size) - 1)
96+
max_rel_dist = tf.cast(2 * tf.math.maximum(q_size, k_size) - 1, tf.int32)
9797

9898
if interpolate_pos:
9999
# Interpolate positional embeddings if needed.
@@ -108,10 +108,10 @@ def get_rel_pos(
108108
q_coords = tf.expand_dims(tf.range(q_size, dtype=tf.float32), axis=-1)
109109
k_coords = tf.expand_dims(tf.range(k_size, dtype=tf.float32), axis=0)
110110
# Scale the coords with short length if shapes for q and k are different.
111-
q_coords = q_coords * tf.cast(max(k_size / q_size, 1.0), tf.float32)
112-
k_coords = k_coords * tf.cast(max(q_size / k_size, 1.0), tf.float32)
111+
q_coords = q_coords * tf.cast(tf.math.maximum(k_size / q_size, 1.0), tf.float32)
112+
k_coords = k_coords * tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)
113113

114-
lambda_ = tf.cast(max(q_size / k_size, 1.0), tf.float32)
114+
lambda_ = tf.cast(tf.math.maximum(q_size / k_size, 1.0), tf.float32)
115115
offset = tf.cast(k_size - 1, tf.float32) * lambda_
116116
relative_coords = (q_coords - k_coords) + offset
117117
relative_coords = tf.cast(relative_coords, tf.int32)
@@ -168,7 +168,7 @@ def add_decomposed_rel_pos(
168168
return attn
169169

170170

171-
class Attention(tf.keras.layers.Layer):
171+
class RelPosAttention(tf.keras.layers.Layer):
172172
"""Multi-head Attention block with relative position embeddings."""
173173

174174
def __init__(
@@ -263,7 +263,7 @@ def call(self, x, training=False):
263263
return x
264264

265265

266-
class Block(tf.keras.layers.Layer):
266+
class ImageEncoderBlock(tf.keras.layers.Layer):
267267
"""
268268
Transformer blocks with support for window attention and residual propagation.
269269
"""
@@ -316,7 +316,7 @@ def __init__(
316316
norm_layer = norm_layer_factory(norm_layer)
317317

318318
self.norm1 = norm_layer(name="norm1")
319-
self.attn = Attention(
319+
self.attn = RelPosAttention(
320320
fixed_input_size=self.fixed_input_size,
321321
embed_dim=self.embed_dim,
322322
nb_heads=self.nb_heads,
@@ -438,7 +438,7 @@ def __init__(
438438
self.pos_embed = None
439439

440440
self.blocks = [
441-
Block(
441+
ImageEncoderBlock(
442442
fixed_input_size=self.fixed_input_size,
443443
embed_dim=self.embed_dim,
444444
nb_heads=self.nb_heads,

tfimm/architectures/segment_anything/transformer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
)
3939
for j in range(self.nb_blocks)
4040
]
41-
self.final_attn_token_to_image = Attention(
41+
self.final_attn_token_to_image = DownsampleAttention(
4242
embed_dim=self.embed_dim,
4343
nb_heads=self.nb_heads,
4444
downsample_rate=self.attention_downsample_rate,
@@ -131,12 +131,12 @@ def __init__(
131131

132132
norm_layer = norm_layer_factory("layer_norm")
133133

134-
self.self_attn = Attention(
134+
self.self_attn = DownsampleAttention(
135135
embed_dim=embed_dim, nb_heads=nb_heads, downsample_rate=1, name="self_attn"
136136
)
137137
self.norm1 = norm_layer(name="norm1")
138138

139-
self.cross_attn_token_to_image = Attention(
139+
self.cross_attn_token_to_image = DownsampleAttention(
140140
embed_dim=embed_dim,
141141
nb_heads=nb_heads,
142142
downsample_rate=attention_downsample_rate,
@@ -153,7 +153,7 @@ def __init__(
153153
)
154154
self.norm3 = norm_layer(name="norm3")
155155

156-
self.cross_attn_image_to_token = Attention(
156+
self.cross_attn_image_to_token = DownsampleAttention(
157157
embed_dim=embed_dim,
158158
nb_heads=nb_heads,
159159
downsample_rate=attention_downsample_rate,
@@ -194,7 +194,7 @@ def call(self, inputs, training=False):
194194
return q, k
195195

196196

197-
class Attention(tf.keras.layers.Layer):
197+
class DownsampleAttention(tf.keras.layers.Layer):
198198
"""
199199
An attention layer that allows for downscaling the size of the embedding after
200200
projection to queries, keys, and values.
@@ -221,7 +221,7 @@ def __init__(
221221
)
222222

223223
def _separate_heads(self, x: tf.Tensor):
224-
b, m, c = tf.shape(x) # (B, M, C)
224+
b, m, c = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2] # (B, M, C)
225225
x = tf.reshape(x, (b, m, self.nb_heads, c // self.nb_heads)) # (B, M, Hd, C/Hd)
226226
x = tf.transpose(x, (0, 2, 1, 3)) # (B, Hd, M, C/Hd)
227227
return x

0 commit comments

Comments
 (0)