Commit c8efcd7

Clip update inputs format (#2377)
* update model input format
* update
* code reformat
* remove input tensors
* Update docstring

---------

Co-authored-by: Divyashree Sreepathihalli <divyashreepathihalli>
1 parent 5faae37 commit c8efcd7

2 files changed: +66 -20 lines changed

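In short, the commit replaces CLIP's positional (image, text, attention_mask) call signature with a single input dictionary. A compact before/after sketch of a call site, grounded in the diffs below; the import path is inferred from the file location in this commit and the placeholder arrays mirror the shapes used in the tests, so treat the details as illustrative:

```python
import numpy as np

# Import path inferred from keras_cv/models/feature_extractor/clip/clip_model.py;
# the public alias may differ.
from keras_cv.models.feature_extractor.clip.clip_model import CLIP

model = CLIP.from_preset("clip-vit-base-patch32")
image = np.ones(shape=[1, 224, 224, 3])  # placeholder pixel values
text = np.ones(shape=[3, 77])            # placeholder token ids
mask = np.ones(shape=[3, 77])            # placeholder attention mask

# Before this commit (positional arguments):
# image_logits, text_logits = model(image, text, mask)

# After this commit (a single dict input):
image_logits, text_logits = model(
    {"image": image, "text": text, "attention_mask": mask}
)
```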

keras_cv/models/feature_extractor/clip/clip_model.py

Lines changed: 26 additions & 1 deletion
@@ -61,6 +61,26 @@ class CLIP(Task):
             transformer-based text encoder.
         transformer_layers (int): The number of layers in the transformer-based
             text encoder.
+    Example:
+    ```python
+    processor = CLIPProcessor(
+        input_resolution=224,
+        "path_to_vocab.json",
+        "path_to_merges.txt"
+    )
+    processed_image = processor.process_images(["cat.jpg"])
+    processed_text, attention_mask = processor.process_texts(
+        ["mountains", "cat on tortoise", "two cats"]
+    )
+    model = CLIP.from_preset("clip-vit-base-patch16")
+    image_logits, text_logits = model(
+        {
+            "image": processed_image,
+            "text": processed_text,
+            "attention_mask": attention_mask,
+        }
+    )
+    ```
     """
 
     def __init__(
@@ -133,7 +153,12 @@ def encode_images(self, image):
     def encode_text(self, text, attention_mask=None):
         return self.text_encoder(text, attention_mask=attention_mask)
 
-    def call(self, image, text, attention_mask=None):
+    def call(self, inputs):
+        image, text = inputs["image"], inputs["text"]
+        if "attention_mask" in inputs:
+            attention_mask = inputs["attention_mask"]
+        else:
+            attention_mask = None
         self.image_embeddings = self.encode_images(image)
         self.text_embeddings = self.encode_text(
             text, attention_mask=attention_mask
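The new call() also makes the "attention_mask" key optional: when it is absent, the method falls back to attention_mask=None before encoding the text. A minimal sketch of that path, with the same illustrative placeholders and inferred import path as above:

```python
import numpy as np
from keras_cv.models.feature_extractor.clip.clip_model import CLIP  # path inferred from the diff

model = CLIP.from_preset("clip-vit-base-patch32")

# "attention_mask" omitted: call() sets attention_mask=None, so the text
# encoder runs without a mask.
image_logits, text_logits = model(
    {
        "image": np.ones(shape=[1, 224, 224, 3]),
        "text": np.ones(shape=[3, 77]),
    }
)
```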

keras_cv/models/feature_extractor/clip/clip_model_test.py

Lines changed: 40 additions & 19 deletions
@@ -34,27 +34,24 @@
     "https://storage.googleapis.com/keras-cv/models/clip/merges.txt",
 )
 
-MODEL_PATH = keras.utils.get_file(
-    None,
-    "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5",  # noqa: E501
-)
-
 
 class CLIPTest(TestCase):
     @pytest.mark.large
     def test_clip_model_golden_values(self):
-        model = CLIP()
-        model.load_weights(MODEL_PATH)
+        model = CLIP.from_preset("clip-vit-base-patch32")
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
         image_logits, text_logits = model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
-        print(image_logits)
-        self.assertAllClose(image_logits, [[1.896713, 1.896713, 1.896713]])
+        self.assertAllClose(image_logits, [[1.896712, 1.896712, 1.896712]])
         self.assertAllClose(
-            text_logits, ops.transpose([[1.896713, 1.896713, 1.896713]])
+            text_logits, ops.transpose([[1.896712, 1.896712, 1.896712]])
         )
 
     def test_clip_preprocessor(self):
8380
processed_text = np.ones(shape=[3, 77])
8481
attention_mask = np.ones(shape=[3, 77])
8582
image_logits, text_logits = model(
86-
processed_image, processed_text, attention_mask
83+
{
84+
"image": processed_image,
85+
"text": processed_text,
86+
"attention_mask": attention_mask,
87+
}
8788
)
8889

8990
@pytest.mark.large
9091
def test_image_encoder_golden_values(self):
91-
model = CLIP()
92-
model.load_weights(MODEL_PATH)
92+
model = CLIP.from_preset("clip-vit-base-patch32")
9393
processed_image = np.ones(shape=[1, 224, 224, 3])
9494
processed_text = np.ones(shape=[3, 77])
9595
attention_mask = np.ones(shape=[3, 77])
96-
model(processed_image, processed_text, attention_mask)
96+
model(
97+
{
98+
"image": processed_image,
99+
"text": processed_text,
100+
"attention_mask": attention_mask,
101+
}
102+
)
97103
self.assertAllClose(
98104
model.image_embeddings[:, :5],
99105
[[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]],
@@ -105,8 +111,13 @@ def test_text_encoder_golden_values(self):
105111
processed_image = np.ones(shape=[1, 224, 224, 3])
106112
processed_text = np.ones(shape=[3, 77])
107113
attention_mask = np.ones(shape=[3, 77])
108-
model(processed_image, processed_text, attention_mask)
109-
print(model.text_embeddings)
114+
model(
115+
{
116+
"image": processed_image,
117+
"text": processed_text,
118+
"attention_mask": attention_mask,
119+
}
120+
)
110121
self.assertAllClose(
111122
model.text_embeddings[0, :3],
112123
[0.007531, -0.038361, -0.035686],
@@ -118,7 +129,13 @@ def test_saved_model(self):
118129
processed_image = np.ones(shape=[1, 224, 224, 3])
119130
processed_text = np.ones(shape=[3, 77])
120131
attention_mask = np.ones(shape=[3, 77])
121-
model_output, _ = model(processed_image, processed_text, attention_mask)
132+
model_output, _ = model(
133+
{
134+
"image": processed_image,
135+
"text": processed_text,
136+
"attention_mask": attention_mask,
137+
}
138+
)
122139
save_path = os.path.join(self.get_temp_dir(), "model.keras")
123140
if keras_3():
124141
model.save(save_path)
@@ -130,6 +147,10 @@ def test_saved_model(self):
130147
self.assertIsInstance(restored_model, CLIP)
131148
# Check that output matches.
132149
restored_output, _ = restored_model(
133-
processed_image, processed_text, attention_mask
150+
{
151+
"image": processed_image,
152+
"text": processed_text,
153+
"attention_mask": attention_mask,
154+
}
134155
)
135156
self.assertAllClose(model_output, restored_output)
