9 | 9 | import unittest |
10 | 10 | from copy import deepcopy |
11 | 11 | from datetime import datetime |
| 12 | +from typing import Dict |
12 | 13 | from unittest import mock |
13 | 14 |
14 | 15 | import ray |
@@ -1134,10 +1135,10 @@ async def test_serve_with_trainer(self): # noqa: C901 |
1134 | 1135 | + metrics["rollout/model_1/total_request_count"], |
1135 | 1136 | metrics["rollout/total_experience_count"], |
1136 | 1137 | ) |
1137 | | - # at least updated to version 2 |
| 1138 | + # at least updated to version 1 |
1138 | 1139 | await asyncio.sleep(5) # wait for model version update |
1139 | | - self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 2) |
1140 | | - self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 2) |
| 1140 | + self.assertGreaterEqual(metrics["rollout/model_0/model_version"], 1) |
| 1141 | + self.assertGreaterEqual(metrics["rollout/model_1/model_version"], 1) |
1141 | 1142 | # check final checkpoint |
1142 | 1143 | _, step_num = get_checkpoint_dir_with_step_num( |
1143 | 1144 | checkpoint_root_path=serve_config.checkpoint_job_dir, |
@@ -1433,3 +1434,149 @@ def test_trainer(self): |
1433 | 1434 | def tearDown(self): |
1434 | 1435 | # remove dir only when the test passed |
1435 | 1436 | shutil.rmtree(self.config.checkpoint_job_dir, ignore_errors=True) |
| 1437 | + |
| 1438 | + |
| 1439 | +@unittest.skip("Requires agentscope >= 1.0.12") |
| 1440 | +class AgentScopeTunerTest(unittest.IsolatedAsyncioTestCase): |
| 1441 | + def setUp(self) -> None: |
| 1442 | + ray.init(ignore_reinit_error=True) |
| 1443 | + |
| 1444 | + def tearDown(self) -> None: |
| 1445 | + ray.shutdown(_exiting_interpreter=True) |
| 1446 | + |
| 1447 | + def test_agentscope_tuner(self): |
| 1448 | + try: |
| 1449 | + from agentscope.agent import ReActAgent |
| 1450 | + from agentscope.formatter import OpenAIChatFormatter |
| 1451 | + from agentscope.message import Msg |
| 1452 | + from agentscope.model import ChatModelBase |
| 1453 | + from agentscope.tuner import ( |
| 1454 | + Algorithm, |
| 1455 | + Dataset, |
| 1456 | + JudgeOutput, |
| 1457 | + TunerChatModel, |
| 1458 | + WorkflowOutput, |
| 1459 | + tune, |
| 1460 | + ) |
| 1461 | + except ImportError: |
| 1462 | + self.skipTest("agentscope >= 1.0.12 is not installed") |
| 1463 | + |
| 1464 | + async def workflow_func( |
| 1465 | + task: Dict, |
| 1466 | + model: ChatModelBase, |
| 1467 | + auxiliary_models: Dict[str, ChatModelBase], |
| 1468 | + ) -> WorkflowOutput: |
| 1469 | + assert isinstance(model, ChatModelBase) |
| 1470 | + assert "judge_model" in auxiliary_models |
| 1471 | + assert isinstance(auxiliary_models["judge_model"], ChatModelBase) |
| 1472 | + agent = ReActAgent( |
| 1473 | + name="test_agent", |
| 1474 | + model=model, |
| 1475 | + sys_prompt="You are a helpful assistant.", |
| 1476 | + formatter=OpenAIChatFormatter(), |
| 1477 | + ) |
| 1478 | + st = time.time() |
| 1479 | + response = await agent.reply(Msg("user", task["question"], role="user")) |
| 1480 | + et = time.time() |
| 1481 | + return WorkflowOutput(response=response, metrics={"workflow_time": et - st}) |
| 1482 | + |
| 1483 | + async def judge_func( |
| 1484 | + task: Dict, response: Msg, auxiliary_models: Dict[str, ChatModelBase] |
| 1485 | + ) -> JudgeOutput: |
| 1486 | + assert "judge_model" in auxiliary_models |
| 1487 | + judge_model = auxiliary_models["judge_model"] |
| 1488 | + assert isinstance(judge_model, ChatModelBase) |
| 1489 | + agent = ReActAgent( |
| 1490 | + name="judge_agent", |
| 1491 | + model=judge_model, |
| 1492 | + sys_prompt="You are a judge to evaluate the correctness of answers.", |
| 1493 | + formatter=OpenAIChatFormatter(), |
| 1494 | + ) |
| 1495 | + workflow_text_response = response.get_text_content() |
| 1496 | + st = time.time() |
| 1497 | + judge_response = await agent.reply( |
| 1498 | + Msg( |
| 1499 | + "user", |
| 1500 | + f"Question: {task['question']}\nAnswer: {workflow_text_response}\nIs the answer correct? Reply with 'Yes' or 'No'.", |
| 1501 | + role="user", |
| 1502 | + ) |
| 1503 | + ) |
| 1504 | + et = time.time() |
| 1505 | + judge_text = judge_response.get_text_content() |
| 1506 | + if judge_text is not None and "yes" in judge_text.lower(): |
| 1507 | + is_correct = True |
| 1508 | + else: |
| 1509 | + is_correct = False |
| 1510 | + return JudgeOutput( |
| 1511 | + reward=float(is_correct), |
| 1512 | + metrics={"judge_time": et - st}, |
| 1513 | + ) |
| 1514 | + |
| 1515 | + gsm8k_dataset = get_unittest_dataset_config("gsm8k") |
| 1516 | + |
| 1517 | + dataset = Dataset( |
| 1518 | + path=gsm8k_dataset.path, |
| 1519 | + split="train", |
| 1520 | + total_steps=2, |
| 1521 | + ) |
| 1522 | + eval_dataset = Dataset( |
| 1523 | + path=gsm8k_dataset.path, |
| 1524 | + split="test", |
| 1525 | + ) |
| 1526 | + |
| 1527 | + model = TunerChatModel( |
| 1528 | + model_path=get_model_path(), |
| 1529 | + max_model_len=4096, |
| 1530 | + max_tokens=2048, |
| 1531 | + inference_engine_num=2, |
| 1532 | + ) |
| 1533 | + |
| 1534 | + auxiliary_models = { |
| 1535 | + "judge_model": TunerChatModel( |
| 1536 | + model_path=get_model_path(), |
| 1537 | + max_model_len=8192, |
| 1538 | + max_tokens=2048, |
| 1539 | + inference_engine_num=2, |
| 1540 | + ) |
| 1541 | + } |
| 1542 | + |
| 1543 | + algorithm = Algorithm( |
| 1544 | + algorithm_type="multi_step_grpo", |
| 1545 | + batch_size=4, |
| 1546 | + group_size=4, |
| 1547 | + eval_interval_steps=2, |
| 1548 | + save_interval_steps=2, |
| 1549 | + ) |
| 1550 | + |
| 1551 | + tune( |
| 1552 | + workflow_func=workflow_func, |
| 1553 | + judge_func=judge_func, |
| 1554 | + train_dataset=dataset, |
| 1555 | + eval_dataset=eval_dataset, |
| 1556 | + model=model, |
| 1557 | + auxiliary_models=auxiliary_models, |
| 1558 | + algorithm=algorithm, |
| 1559 | + ) |
| 1560 | + # check checkpoint dir in `./checkpoints/AgentScope/Experiment-<timestamp>` |
| 1561 | + self.assertTrue(os.path.exists("./checkpoints/AgentScope")) |
| 1562 | + exp_dirs = os.listdir("./checkpoints/AgentScope") |
| 1563 | + self.assertGreaterEqual(len(exp_dirs), 1) |
| 1564 | + latest_exp_dir = sorted(exp_dirs)[-1] |
| 1565 | + exp_dir_path = os.path.join("./checkpoints/AgentScope", latest_exp_dir) |
| 1566 | + _, step_num = get_checkpoint_dir_with_step_num( |
| 1567 | + checkpoint_root_path=exp_dir_path, |
| 1568 | + trainer_type="verl", |
| 1569 | + ) |
| 1570 | + self.assertEqual(step_num, 2) |
| 1571 | + # check tensorboard |
| 1572 | + parser = TensorBoardParser(os.path.join(exp_dir_path, "monitor", "tensorboard")) |
| 1573 | + rollout_metrics = parser.metric_list("rollout") |
| 1574 | + self.assertIn("rollout/workflow_time/mean", rollout_metrics) |
| 1575 | + self.assertIn("rollout/judge_time/mean", rollout_metrics) |
| 1576 | + self.assertEqual(parser.metric_max_step(rollout_metrics[0]), 2) |
| 1577 | + eval_metrics = parser.metric_list("eval") |
| 1578 | + self.assertGreater(len(eval_metrics), 0) |
| 1579 | + self.assertEqual(parser.metric_max_step(eval_metrics[0]), 2) |
| 1580 | + actor_metrics = parser.metric_list("actor") |
| 1581 | + self.assertGreater(len(actor_metrics), 0) |
| 1582 | + self.assertEqual(parser.metric_max_step(actor_metrics[0]), 2) |
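
A note on the verdict parsing in judge_func above: the substring check 'judge_text is not None and "yes" in judge_text.lower()' also accepts replies such as "Yesterday..." that merely contain the letters "yes". If that ever matters for scoring, here is a minimal sketch of a stricter parser (parse_yes_no_verdict is a hypothetical helper, not part of this PR) that anchors on the first word of the judge's reply:

import re
from typing import Optional

def parse_yes_no_verdict(reply: Optional[str]) -> bool:
    # True only when the judge's reply begins with an explicit "Yes";
    # an empty or missing reply counts as incorrect.
    if not reply:
        return False
    match = re.match(r"(yes|no)\b", reply.strip(), flags=re.IGNORECASE)
    return bool(match) and match.group(1).lower() == "yes"

# The plain substring check would mis-score the second example as correct.
assert parse_yes_no_verdict("Yes, the answer is correct.")
assert not parse_yes_no_verdict("Yesterday I would have said yes.")
assert not parse_yes_no_verdict(None)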