2525import numpy as np
2626from recml .core .utils import keras_utils
2727
# Linear (power=1.0) decay from 0.1 down to 0.01 across 100 optimizer steps.
# The schedule is evaluated at the optimizer's iteration count, so tests can
# tell whether that count was restored by reading the current learning rate.
_LEARNING_RATE_SCHEDULE = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=0.1,
    end_learning_rate=0.01,
    decay_steps=100,
    power=1.0,
)
34+
2835
2936def _create_model (input_shapes : Sequence [int ]) -> keras .Model :
3037 model = keras_hub .models .BertMaskedLM (
@@ -39,7 +46,7 @@ def _create_model(input_shapes: Sequence[int]) -> keras.Model:
3946 dropout = 0.1 ,
4047 )
4148 )
42- optimizer = keras .optimizers .Adam (learning_rate = 0.1 )
49+ optimizer = keras .optimizers .Adam (learning_rate = _LEARNING_RATE_SCHEDULE )
4350 loss = keras .losses .SparseCategoricalCrossentropy ()
4451 metrics = [keras .metrics .SparseCategoricalAccuracy ()]
4552 model .compile (optimizer , loss , weighted_metrics = metrics )
@@ -242,6 +249,112 @@ def test_metrics_variables_checkpointing(
242249 )
243250 self .assertSequenceEqual (w1 .dtype , w2 .dtype )
244251
@parameterized.named_parameters(
    {
        "testcase_name": "restore_all_variables",
        "restore_optimizer_vars": True,
        "restore_steps": True,
        "restore_iterations": True,
        "expected_learning_rate": 0.01,
        "expected_iterations": 100,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_without_optimizer_vars",
        "restore_optimizer_vars": False,
        "restore_steps": True,
        "restore_iterations": True,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_without_steps",
        "restore_optimizer_vars": True,
        "restore_steps": False,
        "restore_iterations": True,
        "expected_learning_rate": 0.01,
        "expected_iterations": 100,
        "expected_initial_epoch": None,
    },
    {
        "testcase_name": "restore_without_iterations",
        "restore_optimizer_vars": True,
        "restore_steps": True,
        "restore_iterations": False,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": 2,
    },
    {
        "testcase_name": "restore_only_model_variables",
        "restore_optimizer_vars": False,
        "restore_steps": False,
        "restore_iterations": False,
        "expected_learning_rate": 0.1,
        "expected_iterations": 0,
        "expected_initial_epoch": None,
    },
)
def test_restore_keras_model_with_different_options(
    self,
    restore_optimizer_vars: bool,
    restore_steps: bool,
    restore_iterations: bool,
    expected_learning_rate: float,
    expected_iterations: int,
    expected_initial_epoch: int | None,
):
    """Checks what `restore_keras_model` restores for each flag combination.

    A source model whose optimizer iteration counter is set to 100 is saved
    at checkpoint step 1. A freshly built target model is then restored from
    that checkpoint with the given flags, and its optimizer iteration count,
    current learning rate and `_initial_epoch` are compared with the
    expectations of the test case.

    Args:
      restore_optimizer_vars: Forwarded to `restore_keras_model`.
      restore_steps: Forwarded to `restore_keras_model`.
      restore_iterations: Forwarded to `restore_keras_model`.
      expected_learning_rate: Learning rate expected on the target optimizer
        after the restore.
      expected_iterations: Iteration count expected on the target optimizer
        after the restore.
      expected_initial_epoch: Expected value of the target model's
        `_initial_epoch` attribute after the restore.
    """
    checkpoint_dir = self.create_tempdir().full_path
    checkpointer = keras_utils.KerasOrbaxCheckpointManager(checkpoint_dir)
    epoch = 1
    # Only the shapes of these arrays are consumed below; the values (and the
    # shared PRNG key) are irrelevant to the test.
    dummy_inputs = {
        "token_ids": jax.random.randint(
            jax.random.key(0), (64, 128), minval=0, maxval=50_000
        ),
        "segment_ids": jax.random.randint(
            jax.random.key(0), (64, 128), minval=0, maxval=7
        ),
        "padding_mask": jax.random.uniform(jax.random.key(0), (64, 128)),
        "mask_positions": jax.random.randint(
            jax.random.key(0), (64, 20), minval=0, maxval=128
        ),
    }
    input_shapes = jax.tree.map(jnp.shape, dummy_inputs)

    # Build the source model, advance its optimizer to iteration 100 and
    # checkpoint its full state (model and optimizer variables).
    source_bert_pretrainer = _create_model(input_shapes)
    source_bert_pretrainer.optimizer.iterations.assign(100)
    source_state = source_bert_pretrainer._get_jax_state(  # pylint: disable=protected-access
        trainable_variables=True,
        non_trainable_variables=True,
        optimizer_variables=True,
    )
    checkpointer.save(step=epoch, items=source_state)
    # Make sure the checkpoint is fully written before restoring from it.
    checkpointer.wait_until_finished()

    target_bert_pretrainer = _create_model(input_shapes)
    keras_utils.restore_keras_model(
        target_bert_pretrainer,
        checkpoint_dir,
        restore_optimizer_vars=restore_optimizer_vars,
        restore_steps=restore_steps,
        restore_iterations=restore_iterations,
    )

    self.assertEqual(
        target_bert_pretrainer.optimizer.iterations.value, expected_iterations
    )
    # The learning rate is computed by the schedule in floating point, so
    # compare with a tolerance instead of relying on bit-exact arithmetic.
    self.assertAlmostEqual(
        float(target_bert_pretrainer.optimizer.learning_rate),
        expected_learning_rate,
        delta=1e-6,
    )
    # NOTE(review): the expected value 2 is presumably checkpoint step + 1;
    # confirm against `restore_keras_model`.
    self.assertEqual(
        target_bert_pretrainer._initial_epoch,  # pylint: disable=protected-access
        expected_initial_epoch,
    )
357+
245358
# Run the absl test runner only when this file is executed as a script.
if __name__ == "__main__":
    absltest.main()
0 commit comments