diff --git a/tests/common/vllm_test.py b/tests/common/vllm_test.py index ef3a5e7f7a..db7156cd4d 100644 --- a/tests/common/vllm_test.py +++ b/tests/common/vllm_test.py @@ -1,7 +1,7 @@ +import asyncio import os import unittest -import ray import torch from openai import BadRequestError from parameterized import parameterized_class @@ -13,7 +13,6 @@ get_model_path, get_template_config, ) -from trinity.common.constants import ROLLOUT_WEIGHT_SYNC_GROUP_NAME from trinity.common.models import create_inference_models from trinity.common.models.model import ModelWrapper from trinity.common.models.utils import ( @@ -29,6 +28,16 @@ def print_debug(*args): print(*args) +async def prepare_engines(engines, auxiliary_engines): + prepare_model_refs = [] + for engine in engines: + prepare_model_refs.append(engine.prepare.remote()) + for engines in auxiliary_engines: + for engine in engines: + prepare_model_refs.append(engine.prepare.remote()) + await asyncio.gather(*prepare_model_refs) + + # Qwen2.5 chat template with {% generation %} mark CHAT_TEMPLATE = r""" {%- if tools %} @@ -127,6 +136,7 @@ def setUp(self): async def test_generate( self, ): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() self.assertEqual(self.model_wrapper.model_path, self.config.model.model_path) self.assertEqual(await self.model_wrapper.model_path_async, self.config.model.model_path) @@ -244,6 +254,7 @@ def setUp(self): self.tokenizer = AutoTokenizer.from_pretrained(self.config.model.model_path) async def test_model_len(self): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() messages = [ {"role": "system", "content": "You are a helpful assistant."}, @@ -311,6 +322,7 @@ def setUp(self): self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) async def test_model_len(self): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() messages = [ {"role": "user", "content": "How are you?"}, @@ -362,6 +374,7 @@ def setUp(self): ) async def test_api(self): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() await self.model_wrapper_no_history.prepare() openai_client = self.model_wrapper.get_openai_client() @@ -435,12 +448,42 @@ async def test_api(self): self.assertEqual(len(self.model_wrapper_no_history.history), 0) -class DummySynchronizer: - def __init__(self): - pass +SYSTEM_PROMPT = """ +You are Qwen, created by Alibaba Cloud. You are a helpful assistant. You are walking on a frozen lake. + +FrozenLake Quick Guide +Goal: Reach the goal (G). Player (P) and Goal (G) must overlap. + +Symbols: +_ Frozen | O Hole | G Goal | P Player + +Rules: +1. Avoid falling into holes (O). +2. Frozen tiles are slippery, you may move perpendicular to your intended direction. + +Valid Action (separated by | ): +Up | Down | Left | Right + +Rewards: +Fall into hole: 0 +Reach goal: +1.0 + +You will be provided the current observation, please decide on the next Action. +You should show your thought process and then input the final action in ``` ```. +You should only output the NEXT ACTION at each interation in the ``` ```. For example, if you want to move up, you should output ```Up```. +You should plan ahead and need to achieve it in minimum number of steps. +You should be aware that frozen tiles can be slippery, but the chance is small and you should not overthink it. + +Please show your thinking process and put the final action in ``` ```. 
In every turn, the final action MUST be one of Up, Down, Left, Right. +""" - def do_nothing(self): - pass +USER_PROMPT = """Current Observation (0): + _ G _ + _ _ _ + P O O +You have not achieved the goal, P has not reached G yet. Please give the next action. +The maximum number of steps remaining is 10. +""" class TestLogprobs(RayUnittestBaseAysnc): @@ -458,31 +501,76 @@ def setUp(self): self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) - async def test_logprobs(self): - # use init process group to apply patches - sync = ( - ray.remote(DummySynchronizer) - .options(name="synchronizer", namespace=self.config.ray_namespace) - .remote() - ) - await sync.__ray_ready__.remote() + async def test_logprobs_api(self): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() - master_address, master_port = await self.engines[0].get_available_address.remote() - await self.engines[0].init_process_group.remote( - master_address, - master_port, - world_size=1, - rank_offset=0, - group_name=ROLLOUT_WEIGHT_SYNC_GROUP_NAME, - explorer_name=self.config.explorer.name, - timeout=20, - ) messages = [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "What is your name?"}, + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": USER_PROMPT}, ] - response_1 = self.model_wrapper.chat(messages, n=1, temperature=1.0, logprobs=True)[0] - response_2 = self.model_wrapper.chat(messages, n=1, temperature=0.8, logprobs=True)[0] + + # Test openai api logprobs with different temperature + + self.model_client = self.model_wrapper.get_openai_async_client() + _ = await self.model_client.chat.completions.create( + model=self.model_client.model_path, + messages=messages, + n=1, + temperature=1.0, + logprobs=True, + max_tokens=15, + ) + response_1 = self.model_wrapper.extract_experience_from_history()[0] + _ = await self.model_client.chat.completions.create( + model=self.model_client.model_path, + messages=messages, + n=1, + temperature=0.8, + logprobs=True, + max_tokens=15, + ) + response_2 = self.model_wrapper.extract_experience_from_history()[0] + self.assertTrue(response_1.logprobs is not None) + self.assertTrue(len(response_1.logprobs) > 0) + self.assertTrue(response_2.logprobs is not None) + self.assertTrue(len(response_2.logprobs) > 0) + logprobs_1 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=1.0) + logprobs_2 = self.model_wrapper.logprobs(response_1.tokens.tolist(), temperature=0.8) + logprobs_3 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=1.0) + logprobs_4 = self.model_wrapper.logprobs(response_2.tokens.tolist(), temperature=0.8) + self.assertEqual(logprobs_1.shape, logprobs_2.shape) + self.assertEqual(logprobs_3.shape, logprobs_4.shape) + self.assertFalse(torch.allclose(logprobs_1, logprobs_2, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_3, logprobs_4, atol=0.4)) + logprobs_1_prompt = logprobs_1[: response_1.prompt_length - 1] + logprobs_2_prompt = logprobs_2[: response_1.prompt_length - 1] + logprobs_3_prompt = logprobs_3[: response_2.prompt_length - 1] + logprobs_4_prompt = logprobs_4[: response_2.prompt_length - 1] + self.assertEqual(logprobs_1_prompt.shape, logprobs_2_prompt.shape) + self.assertFalse(torch.allclose(logprobs_1_prompt, logprobs_2_prompt, rtol=0.4)) + self.assertFalse(torch.allclose(logprobs_3_prompt, logprobs_4_prompt, 
rtol=0.4)) + self.assertTrue(torch.allclose(logprobs_1_prompt, logprobs_3_prompt, rtol=0.4)) + self.assertTrue(torch.allclose(logprobs_2_prompt, logprobs_4_prompt, rtol=0.4)) + logprobs_1_response = logprobs_1[response_1.prompt_length - 1 :] + logprobs_2_response = logprobs_2[response_1.prompt_length - 1 :] + logprobs_3_response = logprobs_3[response_2.prompt_length - 1 :] + logprobs_4_response = logprobs_4[response_2.prompt_length - 1 :] + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) + self.assertEqual(logprobs_3_response.shape, logprobs_4_response.shape) + self.assertEqual(logprobs_1_response.shape, logprobs_2_response.shape) + self.assertEqual(response_1.logprobs.shape, logprobs_1_response.shape) + self.assertTrue(torch.allclose(response_1.logprobs, logprobs_1_response, rtol=0.5)) + self.assertFalse(torch.allclose(response_1.logprobs, logprobs_2_response, rtol=0.5)) + self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) + self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + + # test vllm engine logprobs with different temperature + response_1 = self.model_wrapper.chat( + messages, n=1, temperature=1.0, logprobs=True, max_tokens=15 + )[0] + response_2 = self.model_wrapper.chat( + messages, n=1, temperature=0.8, logprobs=True, max_tokens=15 + )[0] self.assertTrue(response_1.logprobs is not None) self.assertTrue(len(response_1.logprobs) > 0) self.assertTrue(response_2.logprobs is not None) @@ -517,6 +605,56 @@ async def test_logprobs(self): self.assertTrue(torch.allclose(response_2.logprobs, logprobs_4_response, rtol=0.8)) self.assertFalse(torch.allclose(response_2.logprobs, logprobs_3_response, rtol=0.8)) + # test openai api and vllm engine logprobs consistency + await self.model_wrapper.clean_workflow_state() + _ = await self.model_client.chat.completions.create( + model=self.model_client.model_path, + messages=messages, + n=1, + temperature=1.0, + logprobs=0, + max_tokens=1, + ) + response_openai_1 = self.model_wrapper.extract_experience_from_history()[0] + _ = await self.model_client.chat.completions.create( + model=self.model_client.model_path, + messages=messages, + n=1, + temperature=0.8, + logprobs=0, + max_tokens=1, + ) + response_openai_2 = self.model_wrapper.extract_experience_from_history()[0] + response_vllm_1 = self.model_wrapper.chat( + messages, + n=1, + temperature=1.0, + logprobs=0, + max_tokens=1, + )[0] + response_vllm_2 = self.model_wrapper.chat( + messages, + n=1, + temperature=0.8, + logprobs=0, + max_tokens=1, + )[0] + self.assertEqual(len(response_openai_1.tokens), len(response_vllm_1.tokens)) + self.assertTrue( + torch.allclose( + response_openai_1.logprobs, + response_vllm_1.logprobs, + rtol=0.1, + ) + ) + self.assertTrue( + torch.allclose( + response_openai_2.logprobs, + response_vllm_2.logprobs, + rtol=0.1, + ) + ) + class TestAsyncAPIServer(RayUnittestBaseAysnc): def setUp(self): @@ -537,6 +675,7 @@ def setUp(self): ) async def test_api_async(self): + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() await self.model_wrapper_no_history.prepare() openai_client = self.model_wrapper.get_openai_async_client() @@ -758,6 +897,7 @@ async def test_api_tool_calls(self): import json import time + await prepare_engines(self.engines, self.auxiliary_engines) await self.model_wrapper.prepare() await self.model_wrapper_no_history.prepare() tokenizer = AutoTokenizer.from_pretrained(get_api_model_path()) diff --git 
a/tests/explorer/scheduler_test.py b/tests/explorer/scheduler_test.py index 60c9cf9971..aa8706c27d 100644 --- a/tests/explorer/scheduler_test.py +++ b/tests/explorer/scheduler_test.py @@ -175,6 +175,9 @@ class DummyModel(InferenceModel): def sync_model(self, model_version, update_weight_args_list): return True + async def prepare(self): + return + def get_model_version(self): return 0 @@ -190,7 +193,7 @@ def init_process_group( ) -> None: pass - async def get_api_server_url(self) -> Optional[str]: + def get_api_server_url(self) -> Optional[str]: return None @@ -214,7 +217,7 @@ def init_process_group( ) -> None: pass - async def get_api_server_url(self) -> str: + def get_api_server_url(self) -> str: return "http://localhost:12345" diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index 31b232e148..882786213d 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -475,7 +475,7 @@ def test_workflow_repeatable(self, workflow_cls) -> None: (DummyAsyncMultiTurnWorkflow,), ], ) -class MultiTurnWorkflowTest(unittest.TestCase): +class MultiTurnWorkflowTest(unittest.IsolatedAsyncioTestCase): def setUp(self): # configure the model self.config = get_template_config() @@ -490,7 +490,8 @@ def setUp(self): self.engines, self.auxiliary_engines = create_inference_models(self.config) self.model_wrapper = ModelWrapper(self.engines[0], engine_type="vllm", enable_history=True) - def test_multi_turn_workflow(self): + async def test_multi_turn_workflow(self): + await asyncio.gather(*[engine.prepare.remote() for engine in self.engines]) task = Task( workflow=self.workflow_cls, repeat_times=3, @@ -500,7 +501,7 @@ def test_multi_turn_workflow(self): workflow = task.to_workflow(self.model_wrapper) workflow.set_repeat_times(2, run_id_base=0) if workflow.asynchronous: - answer = asyncio.run(workflow.run_async()) + answer = await workflow.run_async() else: answer = workflow.run() self.assertEqual(len(answer), 2) @@ -727,7 +728,7 @@ async def test_workflow_with_openai(self): config.explorer.rollout_model.enable_history = True config.check_and_update() engines, auxiliary_engines = create_inference_models(config) - + await asyncio.gather(*[engine.prepare.remote() for engine in engines]) runner = WorkflowRunner( config, model=engines[0], diff --git a/trinity/common/config.py b/trinity/common/config.py index 904aac0581..2e878395bb 100644 --- a/trinity/common/config.py +++ b/trinity/common/config.py @@ -456,10 +456,12 @@ class ModelConfig: # the maximum number of tokens for the response max_response_tokens: Optional[int] = None # the minimum number of tokens for the response - min_response_tokens: int = 1 + min_response_tokens: int = 0 # whether to truncate the prompt; if set to True, the prompt will be truncated to `max_prompt_tokens` tokens; # not applicable for OpenAI API enable_prompt_truncation: bool = True + # repetition penalty for response generation + repetition_penalty: float = 1.0 # lora config lora_configs: Optional[List[LoRAConfig]] = None @@ -474,7 +476,7 @@ class ModelConfig: @dataclass class InferenceModelConfig: # ! 
DO NOT SET in explorer.rollout_model, automatically set from config.model.model_path - model_path: str = "" + model_path: Optional[str] = None engine_type: str = "vllm" engine_num: int = 1 @@ -503,6 +505,8 @@ class InferenceModelConfig: min_response_tokens: Optional[int] = None # if not set, use `model.enable_prompt_truncation` enable_prompt_truncation: Optional[bool] = None + # If not set, use `model.repetition_penalty` + repetition_penalty: Optional[float] = None # used for testing very long response generation, do not set it unless you know what you are doing ignore_eos: bool = False @@ -517,6 +521,7 @@ class InferenceModelConfig: # For OpenAI API enable_openai_api: bool = False + enable_log_requests: bool = False # whether to enable request logging in vLLM API server # For tool calls in OpenAI API enable_auto_tool_choice: bool = False @@ -1275,7 +1280,7 @@ def check_and_update(self) -> Config: # noqa: C901 # check explorer if self.explorer is not None: - rollout_args = ["temperature", "top_p", "top_k", "logprobs"] + rollout_args = ["temperature", "top_p", "top_k", "logprobs", "repetition_penalty"] length_args = [ "max_model_len", "max_prompt_tokens", @@ -1286,7 +1291,7 @@ def check_and_update(self) -> Config: # noqa: C901 rope_args = ["rope_scaling", "rope_theta"] model_args = rollout_args + length_args + rope_args for args in ["model_path"] + model_args: - setattr(self.explorer.rollout_model, args, getattr(self.model, args)) + set_if_none(self.explorer.rollout_model, args, getattr(self.model, args)) if ( self.explorer.rollout_model.chat_template is None and self.model.custom_chat_template is not None diff --git a/trinity/common/models/__init__.py b/trinity/common/models/__init__.py index ffdc98c070..684f06ec21 100644 --- a/trinity/common/models/__init__.py +++ b/trinity/common/models/__init__.py @@ -152,11 +152,10 @@ def create_debug_inference_model(config: Config) -> None: model.engine_num = 1 rollout_models, auxiliary_models = create_inference_models(config) # make sure models are started - for m in rollout_models: - ray.get(m.run_api_server.remote()) - for models in auxiliary_models: - for m in models: - ray.get(m.run_api_server.remote()) + prepare_refs = [] + prepare_refs = [m.prepare.remote() for m in rollout_models] + prepare_refs.extend(m.prepare.remote() for models in auxiliary_models for m in models) + ray.get(prepare_refs) logger.info( "----------------------------------------------------\n" "Inference models started successfully for debugging.\n" diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index 483915b893..e08062f401 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -44,6 +44,10 @@ async def convert_messages_to_experience( """Convert a list of messages into an experience in async.""" raise NotImplementedError + async def prepare(self) -> None: + """Prepare the model before inference.""" + pass + @abstractmethod def get_model_version(self) -> int: """Get the checkpoint version.""" @@ -56,7 +60,7 @@ def get_available_address(self) -> Tuple[str, int]: port = s.getsockname()[1] return address, port - async def get_api_server_url(self) -> Optional[str]: + def get_api_server_url(self) -> Optional[str]: """Get the API server URL if available.""" return None diff --git a/trinity/common/models/vllm_model.py b/trinity/common/models/vllm_model.py index ba6ccf0f5f..092cafcee3 100644 --- a/trinity/common/models/vllm_model.py +++ b/trinity/common/models/vllm_model.py @@ -106,6 +106,7 @@ def __init__( "top_p": 
config.top_p, "top_k": config.top_k, "max_new_tokens": config.max_response_tokens, + "repetition_penalty": config.repetition_penalty, }, disable_log_stats=True, enable_lora=config.enable_lora, @@ -114,7 +115,7 @@ def __init__( **config.lora_kwargs, ) if get_vllm_version() > parse_version("0.10.0"): - engine_args.enable_log_requests = False + engine_args.enable_log_requests = True else: engine_args.disable_log_requests = True if get_vllm_version() >= parse_version("0.11.0"): @@ -131,6 +132,7 @@ def __init__( self.api_server_host = None self.api_server_port = None self.api_server = None + self._prepared = False self.async_lock = asyncio.Lock() async def _initialize_tokenizer(self): @@ -144,6 +146,17 @@ def _initialize_processor(self): ) self.tokenizer = self.processor.tokenizer + async def prepare( + self, + ) -> None: + """Prepare the model for inference.""" + async with self.async_lock: + if self._prepared: + return + await self._collective_rpc("apply_patches") + await self.run_api_server() + self._prepared = True + async def chat( self, messages: List[Dict], lora_request: LoRARequest = None, **kwargs ) -> Sequence[Experience]: @@ -540,40 +553,45 @@ async def run_api_server(self) -> bool: Returns: success (bool): Whether the API server is started successfully. """ - async with self.async_lock: - if not self.config.enable_openai_api: - return False # Not enabled + if not self.config.enable_openai_api: + self.logger.info("OpenAI API server is not enabled. Skipping...") + return False # Not enabled - if self.api_server_host is not None and self.api_server_port is not None: - return True # already running + if self.api_server_host is not None and self.api_server_port is not None: + self.logger.info("OpenAI API server is already running. Skipping...") + return True # already running - from trinity.common.models.vllm_patch.api_patch import ( - run_api_server_in_ray_actor, - ) + from trinity.common.models.vllm_patch.api_patch import ( + run_api_server_in_ray_actor, + ) - api_server_host, api_server_port = self.get_available_address() - self.api_server = asyncio.create_task( - run_api_server_in_ray_actor( - self.async_llm, - api_server_host, - api_server_port, - self.config.model_path, - self.config.enable_auto_tool_choice, - self.config.tool_call_parser, - self.config.reasoning_parser, - ) + api_server_host, api_server_port = self.get_available_address() + self.api_server = asyncio.create_task( + run_api_server_in_ray_actor( + self.async_llm, + api_server_host, + api_server_port, + self.config.model_path, # type: ignore [arg-type] + self.config.enable_auto_tool_choice, + self.config.tool_call_parser, + self.config.reasoning_parser, + self.config.enable_log_requests, ) - self.api_server_host = api_server_host - self.api_server_port = api_server_port - return True + ) + self.api_server_host = api_server_host + self.api_server_port = api_server_port + return True - async def get_api_server_url(self) -> Optional[str]: + def get_api_server_url(self) -> Optional[str]: """Get the URL of the OpenAI API server. Returns: api_url (str): The URL of the OpenAI API server. """ - if not await self.run_api_server(): + if not self._prepared: + raise RuntimeError("Model is not prepared. 
Please call `prepare()` first.") + if self.api_server_host is None or self.api_server_port is None: + # openai api is not enabled return None return f"http://{self.api_server_host}:{self.api_server_port}" @@ -584,7 +602,7 @@ def get_model_version(self) -> int: return self.model_version def get_model_path(self) -> str: - return self.config.model_path + return self.config.model_path # type: ignore [return-value] def get_lora_request(self, lora_path: Optional[str] = None) -> LoRARequest: assert self.config.lora_modules is not None diff --git a/trinity/common/models/vllm_patch/api_patch.py b/trinity/common/models/vllm_patch/api_patch.py index 71b61f1c5f..0036b0956c 100644 --- a/trinity/common/models/vllm_patch/api_patch.py +++ b/trinity/common/models/vllm_patch/api_patch.py @@ -335,6 +335,7 @@ async def run_api_server_in_ray_actor( enable_auto_tool_choice: bool = False, tool_call_parser: Optional[str] = None, reasoning_parser: Optional[str] = None, + enable_log_requests: bool = False, ): vllm_version = get_vllm_version() if vllm_version < parse_version("0.8.5") or vllm_version > parse_version("0.11.0"): @@ -354,6 +355,8 @@ async def run_api_server_in_ray_actor( model_path, "--enable-server-load-tracking", # enable tracking for load balancing ] + if enable_log_requests: + cli_args.append("--enable-log-requests") if enable_auto_tool_choice: cli_args.append("--enable-auto-tool-choice") if tool_call_parser: diff --git a/trinity/common/models/vllm_worker.py b/trinity/common/models/vllm_worker.py index 1cd9a95a2c..810b0d0d55 100644 --- a/trinity/common/models/vllm_worker.py +++ b/trinity/common/models/vllm_worker.py @@ -12,6 +12,11 @@ class WorkerExtension: + def apply_patches(self): + """Apply necessary patches to vLLM.""" + patch_vllm_moe_model_weight_loader(self.model_runner.model) + patch_vllm_prompt_logprobs(self.model_runner) + def init_process_group( self, master_address: str, @@ -56,8 +61,6 @@ def init_process_group( self._namespace = namespace self.synchronizer = Synchronizer.get_actor(namespace=self._namespace) self._checkpoint_converter = None - patch_vllm_moe_model_weight_loader(self.model_runner.model) - patch_vllm_prompt_logprobs(self.model_runner) def update_weight(self): """Broadcast weight to all vllm workers from source rank 0 (actor model)""" diff --git a/trinity/explorer/explorer.py b/trinity/explorer/explorer.py index cba0aa96c8..458c1ba626 100644 --- a/trinity/explorer/explorer.py +++ b/trinity/explorer/explorer.py @@ -166,11 +166,9 @@ async def prepare(self) -> None: await self.experience_pipeline.prepare.remote() self.logger.info("Experience pipeline is ready.") # make sure all rollout models are ready - run_api_ref = [model.run_api_server.remote() for model in self.models] + run_api_ref = [model.prepare.remote() for model in self.models] run_api_ref.extend( - model.run_api_server.remote() - for models in self.auxiliary_models - for model in models + model.prepare.remote() for models in self.auxiliary_models for model in models ) await asyncio.gather(*run_api_ref) self.logger.info("All models are ready.") diff --git a/trinity/explorer/workflow_runner.py b/trinity/explorer/workflow_runner.py index c5c8b01eb1..dddff55719 100644 --- a/trinity/explorer/workflow_runner.py +++ b/trinity/explorer/workflow_runner.py @@ -9,7 +9,7 @@ from typing import Dict, List, Optional, Tuple from trinity.buffer import get_buffer_reader, get_buffer_writer -from trinity.common.config import Config, ExperienceBufferConfig +from trinity.common.config import Config, StorageConfig from 
trinity.common.experience import Experience from trinity.common.models import get_debug_inference_model from trinity.common.models.model import InferenceModel, ModelWrapper @@ -247,12 +247,13 @@ def __init__( "experiences.db", ) self.sqlite_writer = get_buffer_writer( - ExperienceBufferConfig( + StorageConfig( name="debug_buffer", schema_type="experience", path=self.output_sqlite_file, storage_type="sql", batch_size=1, + wrap_in_ray=False, ) )
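
A minimal sketch of the engine preparation flow this patch introduces, assuming a config built via get_template_config() with enable_openai_api turned on; everything except create_inference_models, ModelWrapper, and prepare() is illustrative:

    import asyncio

    from trinity.common.models import create_inference_models
    from trinity.common.models.model import ModelWrapper


    async def start_models(config):  # illustrative helper, not part of the patch
        # Engines are Ray actors: prepare() applies the vLLM worker patches
        # ("apply_patches") and starts the OpenAI API server, so it must run
        # before any request is issued.
        engines, auxiliary_engines = create_inference_models(config)
        refs = [engine.prepare.remote() for engine in engines]
        refs.extend(
            engine.prepare.remote()
            for group in auxiliary_engines  # auxiliary_engines is a list of engine groups
            for engine in group
        )
        await asyncio.gather(*refs)

        wrapper = ModelWrapper(engines[0], engine_type="vllm", enable_history=True)
        await wrapper.prepare()
        # get_api_server_url() is synchronous after this patch and raises if
        # prepare() has not been called on the underlying engine.
        return wrapper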