3434 "https://storage.googleapis.com/keras-cv/models/clip/merges.txt" ,
3535)
3636
37- MODEL_PATH = keras .utils .get_file (
38- None ,
39- "https://storage.googleapis.com/keras-cv/models/clip/clip-vit-base-patch32.weights.h5" , # noqa: E501
40- )
41-
4237
4338class CLIPTest (TestCase ):
4439 @pytest .mark .large
4540 def test_clip_model_golden_values (self ):
46- model = CLIP ()
47- model .load_weights (MODEL_PATH )
41+ model = CLIP .from_preset ("clip-vit-base-patch32" )
4842 processed_image = np .ones (shape = [1 , 224 , 224 , 3 ])
4943 processed_text = np .ones (shape = [3 , 77 ])
5044 attention_mask = np .ones (shape = [3 , 77 ])
5145 image_logits , text_logits = model (
52- processed_image , processed_text , attention_mask
46+ {
47+ "image" : processed_image ,
48+ "text" : processed_text ,
49+ "attention_mask" : attention_mask ,
50+ }
5351 )
54- print (image_logits )
55- self .assertAllClose (image_logits , [[1.896713 , 1.896713 , 1.896713 ]])
52+ self .assertAllClose (image_logits , [[1.896712 , 1.896712 , 1.896712 ]])
5653 self .assertAllClose (
57- text_logits , ops .transpose ([[1.896713 , 1.896713 , 1.896713 ]])
54+ text_logits , ops .transpose ([[1.896712 , 1.896712 , 1.896712 ]])
5855 )
5956
6057 def test_clip_preprocessor (self ):
@@ -83,17 +80,26 @@ def test_presets(self):
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
         image_logits, text_logits = model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
 
     @pytest.mark.large
     def test_image_encoder_golden_values(self):
-        model = CLIP()
-        model.load_weights(MODEL_PATH)
+        model = CLIP.from_preset("clip-vit-base-patch32")
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model(processed_image, processed_text, attention_mask)
+        model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         self.assertAllClose(
             model.image_embeddings[:, :5],
             [[0.023215, 0.026526, 0.008914, -0.091689, 0.021791]],
@@ -105,8 +111,13 @@ def test_text_encoder_golden_values(self):
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model(processed_image, processed_text, attention_mask)
-        print(model.text_embeddings)
+        model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         self.assertAllClose(
             model.text_embeddings[0, :3],
             [0.007531, -0.038361, -0.035686],
@@ -118,7 +129,13 @@ def test_saved_model(self):
         processed_image = np.ones(shape=[1, 224, 224, 3])
         processed_text = np.ones(shape=[3, 77])
         attention_mask = np.ones(shape=[3, 77])
-        model_output, _ = model(processed_image, processed_text, attention_mask)
+        model_output, _ = model(
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
+        )
         save_path = os.path.join(self.get_temp_dir(), "model.keras")
         if keras_3():
             model.save(save_path)
@@ -130,6 +147,10 @@ def test_saved_model(self):
         self.assertIsInstance(restored_model, CLIP)
         # Check that output matches.
         restored_output, _ = restored_model(
-            processed_image, processed_text, attention_mask
+            {
+                "image": processed_image,
+                "text": processed_text,
+                "attention_mask": attention_mask,
+            }
         )
         self.assertAllClose(model_output, restored_output)