@@ -138,6 +138,26 @@ def test_bcmodule_update(mock_env, trainer_config):
138138 env .close ()
139139
140140
# Test with constant pretraining learning rate
@pytest.mark.parametrize(
    "trainer_config", [ppo_dummy_config(), sac_dummy_config()], ids=["ppo", "sac"]
)
@mock.patch("mlagents.envs.environment.UnityEnvironment")
def test_bcmodule_constant_lr_update(mock_env, trainer_config):
    """Check that the BC module's learning rate is unchanged by an update.

    The test sets ``trainer_config["pretraining"]["steps"]`` to 0, builds a
    policy through ``create_policy_with_bc_mock``, and runs two BC updates.
    It asserts that every stat returned by the first update is an
    ``np.float32`` and that ``bc_module.current_lr`` is the same before and
    after the second update.

    Args:
        mock_env: Mock injected by ``mock.patch`` in place of
            ``UnityEnvironment``.
        trainer_config: Trainer configuration supplied by the ``ppo`` / ``sac``
            parametrization; mutated in place by this test.
    """
    mock_brain = mb.create_mock_3dball_brain()
    # NOTE(review): a step count of 0 presumably disables learning-rate
    # annealing in the BC module; confirm against the BC module itself.
    trainer_config["pretraining"]["steps"] = 0
    env, policy = create_policy_with_bc_mock(
        mock_env, mock_brain, trainer_config, False, "test.demo"
    )
    try:
        stats = policy.bc_module.update()
        for item in stats.values():
            assert isinstance(item, np.float32)
        old_learning_rate = policy.bc_module.current_lr

        # Only the side effect on current_lr matters for the second update,
        # so its returned stats are intentionally discarded.
        policy.bc_module.update()
        assert old_learning_rate == policy.bc_module.current_lr
    finally:
        # Close the environment even when an assertion above fails, matching
        # the sibling tests that end with env.close().
        env.close()
159+
160+
141161# Test with RNN
142162@pytest .mark .parametrize (
143163 "trainer_config" , [ppo_dummy_config (), sac_dummy_config ()], ids = ["ppo" , "sac" ]
0 commit comments