Skip to content

Commit aac66b8

Browse files
committed
fix unittest
1 parent e99d23a commit aac66b8

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tests/explorer/workflow_test.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,12 +731,16 @@ async def mock_get_api_server_url_remote():
731731
async def mock_get_model_version_remote():
732732
return 1
733733

734+
async def mock_get_api_key_remote():
735+
return "dummy_api_key"
736+
734737
async def mock_get_model_config_remote():
735738
return InferenceModelConfig(model_path="dummy_model")
736739

737740
model = MagicMock()
738741
model.get_api_server_url.remote = MagicMock(side_effect=mock_get_api_server_url_remote)
739742
model.get_model_version.remote = MagicMock(side_effect=mock_get_model_version_remote)
743+
model.get_api_key.remote = MagicMock(side_effect=mock_get_api_key_remote)
740744
model.get_model_config.remote = MagicMock(side_effect=mock_get_model_config_remote)
741745

742746
runner = WorkflowRunner(

trinity/trainer/verl/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ def to_data_proto(
100100
)
101101
else:
102102
raise ValueError("Custom fields are not consistent across experiences.")
103-
meta_info = {"model_versions": np.array([exp.info.get("model_version", 0) for exp in experiences])}
103+
meta_info = {
104+
"model_versions": np.array([exp.info.get("model_version", 0) for exp in experiences])
105+
}
104106
return DataProto.from_single_dict(batch_dict, meta_info=meta_info)
105107

106108

0 commit comments

Comments
 (0)