@@ -248,26 +248,44 @@ def test_build_model(self):
248
248
self .assertEqual (model .predict (np .zeros ([5 , 4 ])).shape , (5 , 16 ))
249
249
self .assertEqual (model (np .zeros ([5 , 4 ])).shape , (5 , 16 ))
250
250
251
- def test_save_load (self ):
251
+ @parameterized .named_parameters (
252
+ ("safe_mode" , True ),
253
+ ("unsafe_mode" , False ),
254
+ )
255
+ def test_save_load (self , safe_mode ):
252
256
@keras .saving .register_keras_serializable ()
253
257
class M (keras .Model ):
254
- def __init__ (self , channels = 10 , ** kwargs ):
255
- super ().__init__ ()
256
- self .sequence = torch .nn .Sequential (
257
- torch .nn .Conv2d (1 , channels , kernel_size = (3 , 3 )),
258
- )
258
+ def __init__ (self , module , ** kwargs ):
259
+ super ().__init__ (** kwargs )
260
+ self .module = module
259
261
260
262
def call (self , x ):
261
- return self .sequence (x )
263
+ return self .module (x )
262
264
263
- m = M ()
265
+ def get_config (self ):
266
+ base_config = super ().get_config ()
267
+ config = {"module" : self .module }
268
+ return {** base_config , ** config }
269
+
270
+ @classmethod
271
+ def from_config (cls , config ):
272
+ config ["module" ] = saving .deserialize_keras_object (
273
+ config ["module" ]
274
+ )
275
+ return cls (** config )
276
+
277
+ m = M (torch .nn .Conv2d (1 , 10 , kernel_size = (3 , 3 )))
264
278
device = get_device () # Get the current device (e.g., "cuda" or "cpu")
265
279
x = torch .ones (
266
280
(10 , 1 , 28 , 28 ), device = device
267
281
) # Place input on the correct device
268
- m (x )
282
+ ref_output = m (x )
269
283
temp_filepath = os .path .join (self .get_temp_dir (), "mymodel.keras" )
270
284
m .save (temp_filepath )
271
- new_model = saving .load_model (temp_filepath )
272
- for ref_w , new_w in zip (m .get_weights (), new_model .get_weights ()):
273
- self .assertAllClose (ref_w , new_w , atol = 1e-5 )
285
+
286
+ if safe_mode :
287
+ with self .assertRaisesRegex (ValueError , "arbitrary code execution" ):
288
+ saving .load_model (temp_filepath , safe_mode = safe_mode )
289
+ else :
290
+ new_model = saving .load_model (temp_filepath , safe_mode = safe_mode )
291
+ self .assertAllClose (new_model (x ), ref_output )
0 commit comments