Skip to content

Add KL/Entropy Fn

Add KL/Entropy Fn #80

Triggered via issue June 4, 2025 14:05
@pan-x-c
commented on #64 fefbbee
Status Failure
Total duration 19m 58s
Artifacts

unittest.yaml

on: issue_comment
Fit to window
Zoom out
Zoom in

Annotations

4 errors
unittest
Process completed with exit code 1.
Failed Test: tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer
tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerGSM8KWithSFT testMethod=test_trainer> def test_trainer(self): """Test GSM8K With SFT.""" # test both mode self.config.algorithm.algorithm_type = AlgorithmType.GRPO self.config.algorithm.repeat_times = 4 self.config.algorithm.advantage_fn_type = "grpo_adv_fn" self.config.algorithm.advantage_fn_args = {} self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.buffer.trainer_input.sft_warmup_steps = 2 self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config( "sft_for_gsm8k" ) self.config.check_and_update() self.config.trainer.trainer_config.trainer.total_training_steps = 4 self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 > both(self.config) tests/trainer/trainer_test.py:175: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ trinity/cli/launcher.py:98: in both train_continue, train_step_num = ray.get( /usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:21: in auto_init_wrapper return fn(*args, **kwargs) /usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:103: in wrapper return func(*args, **kwargs) /usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2822: in get values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ self = <ray._private.worker.Worker object at 0x7f2b2b9d79d0> object_refs = [ObjectRef(688c4e14691ac120b48cacdd09721b0bf13f07c71100000001000000)] timeout = None, return_exceptions = False, skip_deserialization = False def get_objects( self, object_refs: list, timeout: Optional[float] = None, return_exceptions: bool = False, skip_deserialization: bool = 
False, ): """Get the values in the object store associated with the IDs. Return the values from the local object store for object_refs. This will block until all the values for object_refs have been written to the local object store. Args: object_refs: A list of the object refs whose values should be retrieved. timeout: The maximum amount of time in seconds to wait before returning. return_exceptions: If any of the objects deserialize to an Exception object, whether to return them as values in the returned list. If False, then the first found exception will be raised. skip_deserialization: If true, only the buffer will be released and the object associated with the buffer will not be deserialized. Returns: list: List of deserialized objects or None if skip_deserialization is True. bytes: UUID of the debugger breakpoint we should drop into or b"" if there is no breakpoint. """ # Make sure that the values are object refs. for object_ref in object_refs: if not isinstance(object_ref, ObjectRef): raise TypeError( f"Attempting to call `get` on the value {object_ref}, " "which is not an ray.ObjectRef." ) timeout_ms = ( int(timeout * 1000) if timeout is not None and timeout != -1 else -1 ) data_metadata_pairs: List[ Tuple[ray._raylet.Buffer, bytes] ] = self.core_worker.get_objects( object_refs, timeout_ms, ) debugger_breakpoint = b"" for data, metadata in data_metadata_pairs: if metadata:
Failed Test: tests/trainer/trainer_test.py::TestTrainerGSM8K::test_trainer
tests/trainer/trainer_test.py::TestTrainerGSM8K::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerGSM8K testMethod=test_trainer> def test_trainer(self): """Test GSM8K.""" # test both mode self.config.algorithm.algorithm_type = AlgorithmType.GRPO self.config.algorithm.repeat_times = 4 # self.config.algorithm.repeat_times = 8 # TODO: used for real testing self.config.algorithm.advantage_fn_type = "grpo_adv_fn" self.config.algorithm.advantage_fn_args = {} # self.config.buffer.batch_size = 96 # TODO: used for real testing self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k") self.config.check_and_update() self.config.trainer.trainer_config.trainer.total_training_steps = 4 self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2 self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5 > both(self.config) tests/trainer/trainer_test.py:133: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ trinity/cli/launcher.py:130: in both raise e trinity/cli/launcher.py:114: in both train_continue, train_step_num = ray.get(ref_train) /usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:21: in auto_init_wrapper return fn(*args, **kwargs) /usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:103: in wrapper return func(*args, **kwargs) /usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2822: in get values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ self = <ray._private.worker.Worker object at 0x7f2b2b9d79d0> object_refs = [ObjectRef(e0772ef5daca5c83a1114e9574ebd288727a2d651000000001000000)] timeout = None, return_exceptions = False, skip_deserialization = False def get_objects( self, object_refs: list, timeout: Optional[float] = None, return_exceptions: bool = False, 
skip_deserialization: bool = False, ): """Get the values in the object store associated with the IDs. Return the values from the local object store for object_refs. This will block until all the values for object_refs have been written to the local object store. Args: object_refs: A list of the object refs whose values should be retrieved. timeout: The maximum amount of time in seconds to wait before returning. return_exceptions: If any of the objects deserialize to an Exception object, whether to return them as values in the returned list. If False, then the first found exception will be raised. skip_deserialization: If true, only the buffer will be released and the object associated with the buffer will not be deserialized. Returns: list: List of deserialized objects or None if skip_deserialization is True. bytes: UUID of the debugger breakpoint we should drop into or b"" if there is no breakpoint. """ # Make sure that the values are object refs. for object_ref in object_refs: if not isinstance(object_ref, ObjectRef): raise TypeError( f"Attempting to call `get` on the value {object_ref}, " "which is not an ray.ObjectRef." ) timeout_ms = ( int(timeout * 1000) if timeout is not None and timeout != -1 else -1 ) data_metadata_pairs: List[ Tuple[ray._raylet.Buffer, bytes] ] = self.core_worker.get_objects( object_refs, timeout_ms, ) debugger_breakpoint = b"" for data, metadata in data_metadata_pairs: if metadata:
unittest
Process completed with exit code 1.