Add KL/Entropy Fn #80
Annotations
4 errors
|
unittest
Process completed with exit code 1.
|
|
Failed Test: tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer
tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerGSM8KWithSFT testMethod=test_trainer>
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
self.config.algorithm.advantage_fn_args = {}
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.trainer_input.sft_warmup_steps = 2
self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
"sft_for_gsm8k"
)
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
> both(self.config)
tests/trainer/trainer_test.py:175:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
trinity/cli/launcher.py:98: in both
train_continue, train_step_num = ray.get(
/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:21: in auto_init_wrapper
return fn(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:103: in wrapper
return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2822: in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ray._private.worker.Worker object at 0x7f2b2b9d79d0>
object_refs = [ObjectRef(688c4e14691ac120b48cacdd09721b0bf13f07c71100000001000000)]
timeout = None, return_exceptions = False, skip_deserialization = False
def get_objects(
self,
object_refs: list,
timeout: Optional[float] = None,
return_exceptions: bool = False,
skip_deserialization: bool = False,
):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_refs. This
will block until all the values for object_refs have been written to
the local object store.
Args:
object_refs: A list of the object refs
whose values should be retrieved.
timeout: The maximum amount of time in
seconds to wait before returning.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If true, only the buffer will be released and
the object associated with the buffer will not be deserialized.
Returns:
list: List of deserialized objects or None if skip_deserialization is True.
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
f"Attempting to call `get` on the value {object_ref}, "
"which is not an ray.ObjectRef."
)
timeout_ms = (
int(timeout * 1000) if timeout is not None and timeout != -1 else -1
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.get_objects(
object_refs,
timeout_ms,
)
debugger_breakpoint = b""
for data, metadata in data_metadata_pairs:
if metadata:
|
|
Failed Test: tests/trainer/trainer_test.py::TestTrainerGSM8K::test_trainer
tests/trainer/trainer_test.py::TestTrainerGSM8K::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerGSM8K testMethod=test_trainer>
def test_trainer(self):
"""Test GSM8K."""
# test both mode
self.config.algorithm.algorithm_type = AlgorithmType.GRPO
self.config.algorithm.repeat_times = 4
# self.config.algorithm.repeat_times = 8 # TODO: used for real testing
self.config.algorithm.advantage_fn_type = "grpo_adv_fn"
self.config.algorithm.advantage_fn_args = {}
# self.config.buffer.batch_size = 96 # TODO: used for real testing
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
> both(self.config)
tests/trainer/trainer_test.py:133:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
trinity/cli/launcher.py:130: in both
raise e
trinity/cli/launcher.py:114: in both
train_continue, train_step_num = ray.get(ref_train)
/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:21: in auto_init_wrapper
return fn(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:103: in wrapper
return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2822: in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ray._private.worker.Worker object at 0x7f2b2b9d79d0>
object_refs = [ObjectRef(e0772ef5daca5c83a1114e9574ebd288727a2d651000000001000000)]
timeout = None, return_exceptions = False, skip_deserialization = False
def get_objects(
self,
object_refs: list,
timeout: Optional[float] = None,
return_exceptions: bool = False,
skip_deserialization: bool = False,
):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_refs. This
will block until all the values for object_refs have been written to
the local object store.
Args:
object_refs: A list of the object refs
whose values should be retrieved.
timeout: The maximum amount of time in
seconds to wait before returning.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If true, only the buffer will be released and
the object associated with the buffer will not be deserialized.
Returns:
list: List of deserialized objects or None if skip_deserialization is True.
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
f"Attempting to call `get` on the value {object_ref}, "
"which is not an ray.ObjectRef."
)
timeout_ms = (
int(timeout * 1000) if timeout is not None and timeout != -1 else -1
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.get_objects(
object_refs,
timeout_ms,
)
debugger_breakpoint = b""
for data, metadata in data_metadata_pairs:
if metadata:
|
|
unittest
Process completed with exit code 1.
|