Skip to content

Commit 10dcadf

Browse files
committed
fix unittest
1 parent 0bc0398 commit 10dcadf

File tree

4 files changed

+12
-13
lines changed

4 files changed

+12
-13
lines changed

tests/buffer/sample_strategy_test.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def _init_buffer_writer_and_sample_strategy(self):
5858
async def _verify_model_version(self, step, expected_versions):
5959
batch, metrics, _ = await self.sample_strategy.sample(step=step)
6060
self.assertEqual(
61-
batch.rewards.tolist(), expected_versions, f"Model versions mismatch at step {step}"
61+
[exp.reward for exp in batch],
62+
expected_versions,
63+
f"Model versions mismatch at step {step}",
6264
)
6365
self.assertEqual(
6466
metrics["sample/model_version/min"],

tests/common/vllm_test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -125,17 +125,13 @@ def setUp(self):
125125
self.config.algorithm.repeat_times = self.repeat_times
126126
self.config.explorer.rollout_model.enable_history = self.enable_history
127127
self.config.check_and_update()
128-
from pprint import pprint
129128

130-
pprint(self.config)
131129
self.engines, self.auxiliary_engines = create_inference_models(self.config)
132130
self.model_wrapper = ModelWrapper(
133131
self.engines[0], engine_type="vllm", enable_history=self.enable_history
134132
)
135133

136-
async def test_generate(
137-
self,
138-
):
134+
async def test_generate(self):
139135
await prepare_engines(self.engines, self.auxiliary_engines)
140136
await self.model_wrapper.prepare()
141137
self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path)

tests/trainer/trainer_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1325,7 +1325,7 @@ def tearDown(self):
13251325

13261326

13271327
class TestTinkerTrainer(BaseTrainerCase):
1328-
# @unittest.skip("Require tinker API key")
1328+
@unittest.skip("Require tinker API key")
13291329
def test_trainer(self):
13301330
"""Test GSM8K on tinker."""
13311331
# test both mode

trinity/trainer/tinker_trainer.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ def _init_algorithm(self):
4545
self.kl_fn = KL_FN.get(algorithm_config.kl_penalty_fn)(
4646
**algorithm_config.kl_penalty_fn_args
4747
)
48+
# TODO
49+
raise NotImplementedError(
50+
"`compute_advantage_in_trainer` is not implemented yet in tinker"
51+
)
4852
self.loss_agg_mode = algorithm_config.loss_agg_mode
4953
self.policy_loss_fn = POLICY_LOSS_FN.get(algorithm_config.policy_loss_fn)(
5054
backend="tinker", **algorithm_config.policy_loss_fn_args
@@ -227,12 +231,9 @@ async def train_step(self, batch_exps: List[Experience]) -> Dict:
227231

228232
if self.algorithm.compute_advantage_in_trainer:
229233
# TODO: following is verl format, which is not compatible with tinker
230-
with marked_timer("adv", timing_raw):
231-
# compute kl penalty
232-
batch, kl_metrics = self.kl_fn.apply_kl_penalty_to_reward(batch)
233-
metrics.update(prefix_metrics(kl_metrics, prefix="critic"))
234-
# compute advantages, executed on the driver process
235-
batch, _ = self.advantage_fn(batch)
234+
raise NotImplementedError(
235+
"`compute_advantage_in_trainer` is not implemented yet in tinker"
236+
)
236237
else:
237238
# skip token_level_scores for sft/dpo
238239
for model_inputs in model_inputs_list:

0 commit comments

Comments (0)