Trinity-RFT v0.2.0 #112
Annotations
4 errors
|
unittest
Process completed with exit code 1.
|
|
Failed Test: tests/trainer/trainer_test.py::TestTrainerDPO::test_trainer
tests/trainer/trainer_test.py::TestTrainerDPO::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerDPO testMethod=test_trainer>
def test_trainer(self):
"""Test DPO."""
# test both mode
self.config.mode = "train"
self.config.algorithm.algorithm_type = "dpo"
self.config.algorithm.policy_loss_fn = "dpo"
self.config.algorithm.policy_loss_fn_args = {}
# self.config.buffer.batch_size = 32
self.config.buffer.trainer_input.experience_buffer = get_unittest_dataset_config("dpo")
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 5e-7
> train(self.config)
tests/trainer/trainer_test.py:205:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
trinity/cli/launcher.py:50: in train
ray.get(trainer.prepare.remote())
/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:22: in auto_init_wrapper
return fn(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:104: in wrapper
return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2849: in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ray._private.worker.Worker object at 0x7f35f81cd4e0>
object_refs = [ObjectRef(c3f3ec298be7daba1beab5fe70e95a59314c75b01200000001000000)]
timeout = None, return_exceptions = False, skip_deserialization = False
def get_objects(
self,
object_refs: list,
timeout: Optional[float] = None,
return_exceptions: bool = False,
skip_deserialization: bool = False,
):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_refs. This
will block until all the values for object_refs have been written to
the local object store.
Args:
object_refs: A list of the object refs
whose values should be retrieved.
timeout: The maximum amount of time in
seconds to wait before returning.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If true, only the buffer will be released and
the object associated with the buffer will not be deserialized.
Returns:
list: List of deserialized objects or None if skip_deserialization is True.
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
f"Attempting to call `get` on the value {object_ref}, "
"which is not an ray.ObjectRef."
)
timeout_ms = (
int(timeout * 1000) if timeout is not None and timeout != -1 else -1
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.get_objects(
object_refs,
timeout_ms,
)
debugger_breakpoint = b""
for data, metadata in data_metadata_pairs:
if metadata:
metadata_fields = metadata.split(b",")
if len(metadata_fields) >= 2 and metadata_fields[1].startswith(
ray_constants.OBJECT_METADATA_DEBUG_PREFIX
):
|
|
Failed Test: tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer
tests/trainer/trainer_test.py::TestTrainerGSM8KWithSFT::test_trainer: The test failed in the call phase due to an exception - self = <tests.trainer.trainer_test.TestTrainerGSM8KWithSFT testMethod=test_trainer>
def test_trainer(self):
"""Test GSM8K With SFT."""
# test both mode
self.config.algorithm.algorithm_type = "grpo"
self.config.algorithm.repeat_times = 4
self.config.algorithm.advantage_fn = "grpo"
self.config.algorithm.advantage_fn_args = {}
self.config.buffer.explorer_input.taskset = get_unittest_dataset_config("gsm8k")
self.config.buffer.trainer_input.sft_warmup_steps = 2
self.config.buffer.trainer_input.sft_warmup_dataset = get_unittest_dataset_config(
"sft_for_gsm8k"
)
self.config.check_and_update()
self.config.trainer.trainer_config.trainer.total_training_steps = 4
self.config.trainer.trainer_config.trainer.max_actor_ckpt_to_keep = 2
self.config.trainer.trainer_config.actor_rollout_ref.actor.optim.lr = 1e-5
> both(self.config)
tests/trainer/trainer_test.py:173:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
trinity/cli/launcher.py:73: in both
ray.get([explorer.__ray_ready__.remote(), trainer.__ray_ready__.remote()])
/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py:22: in auto_init_wrapper
return fn(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py:104: in wrapper
return func(*args, **kwargs)
/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py:2849: in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <ray._private.worker.Worker object at 0x7f35f81cd4e0>
object_refs = [ObjectRef(aacfdf30a54b8250f694b7ad54f9d85af77bc37f1100000001000000), ObjectRef(5a92aba0550c8f639fb2160ae3eaa5bbcd0d27421100000001000000)]
timeout = None, return_exceptions = False, skip_deserialization = False
def get_objects(
self,
object_refs: list,
timeout: Optional[float] = None,
return_exceptions: bool = False,
skip_deserialization: bool = False,
):
"""Get the values in the object store associated with the IDs.
Return the values from the local object store for object_refs. This
will block until all the values for object_refs have been written to
the local object store.
Args:
object_refs: A list of the object refs
whose values should be retrieved.
timeout: The maximum amount of time in
seconds to wait before returning.
return_exceptions: If any of the objects deserialize to an
Exception object, whether to return them as values in the
returned list. If False, then the first found exception will be
raised.
skip_deserialization: If true, only the buffer will be released and
the object associated with the buffer will not be deserialized.
Returns:
list: List of deserialized objects or None if skip_deserialization is True.
bytes: UUID of the debugger breakpoint we should drop
into or b"" if there is no breakpoint.
"""
# Make sure that the values are object refs.
for object_ref in object_refs:
if not isinstance(object_ref, ObjectRef):
raise TypeError(
f"Attempting to call `get` on the value {object_ref}, "
"which is not an ray.ObjectRef."
)
timeout_ms = (
int(timeout * 1000) if timeout is not None and timeout != -1 else -1
)
data_metadata_pairs: List[
Tuple[ray._raylet.Buffer, bytes]
] = self.core_worker.get_objects(
object_refs,
timeout_ms,
)
debugger_breakpoint = b""
fo
|
|
unittest
Process completed with exit code 1.
|