@@ -180,7 +180,9 @@ def test_load_pretrained(
     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config


 def test_load_pretrained_quantized(
@@ -198,19 +200,22 @@ def test_load_pretrained_quantized(
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.int8
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "int8"

     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16")

     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.float16
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "float16"

     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")
     # Assert that the loaded model has the same properties as the original one
     assert loaded_model.embedding.dtype == np.float32
     assert loaded_model.embedding.shape == mock_vectors.shape
+    assert loaded_model.embedding_dtype == "float32"

     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float64")
@@ -234,15 +239,19 @@ def test_load_pretrained_dim(
     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors[:, :2])
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config

     # Load the model back from the same path
     loaded_model = StaticModel.from_pretrained(save_path, dimensionality=None)

     # Assert that the loaded model has the same properties as the original one
     np.testing.assert_array_equal(loaded_model.embedding, mock_vectors)
     assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
-    assert loaded_model.config == mock_config
+    for k, v in mock_config.items():
+        assert loaded_model.config.get(k) == v
+    assert "embedding_dtype" in loaded_model.config

     # Load the model back from the same path
     with pytest.raises(ValueError):
@@ -267,6 +276,7 @@ def test_load_pretrained_vocabulary_quantized(
     assert loaded_model.weights is not None
     assert loaded_model.weights.shape == (5,)
     assert loaded_model.token_mapping is not None
+    assert loaded_model.vocabulary_quantization == 3
     assert len(loaded_model.token_mapping) == mock_tokenizer.get_vocab_size()
     assert len(loaded_model.token_mapping) == len(loaded_model.weights)
     assert loaded_model.encode("word1 word2").shape == (2,)
0 commit comments