1+ import json
2+ import os
13import time
24from typing import List , Optional
35
810from sqlalchemy .pool import NullPool
911
1012from trinity .buffer .schema import Base , create_dynamic_table
11- from trinity .buffer .utils import retry_session
13+ from trinity .buffer .utils import default_storage_path , retry_session
1214from trinity .common .config import BufferConfig , StorageConfig
1315from trinity .common .constants import ReadStrategy
16+ from trinity .common .experience import Experience
17+ from trinity .common .workflows import Task
1418from trinity .utils .log import get_logger
1519
1620
@@ -27,6 +31,8 @@ class DBWrapper:
2731
2832 def __init__ (self , storage_config : StorageConfig , config : BufferConfig ) -> None :
2933 self .logger = get_logger (__name__ )
34+ if storage_config .path is None :
35+ storage_config .path = default_storage_path (storage_config , config )
3036 self .engine = create_engine (storage_config .path , poolclass = NullPool )
3137 self .table_model_cls = create_dynamic_table (
3238 storage_config .algorithm_type , storage_config .name
@@ -61,7 +67,9 @@ def write(self, data: list) -> None:
6167 experience_models = [self .table_model_cls .from_experience (exp ) for exp in data ]
6268 session .add_all (experience_models )
6369
64- def read (self , strategy : Optional [ReadStrategy ] = None ) -> List :
70+ def read (
71+ self , batch_size : Optional [int ] = None , strategy : Optional [ReadStrategy ] = None
72+ ) -> List :
6573 if strategy is None :
6674 strategy = ReadStrategy .LFU
6775
@@ -78,7 +86,8 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
7886 raise NotImplementedError (f"Unsupported strategy { strategy } by SQLStorage" )
7987
8088 exp_list = []
81- while len (exp_list ) < self .batch_size :
89+ batch_size = batch_size or self .batch_size
90+ while len (exp_list ) < batch_size :
8291 if len (exp_list ):
8392 self .logger .info ("waiting for experiences..." )
8493 time .sleep (1 )
@@ -90,7 +99,7 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
9099 session .query (self .table_model_cls )
91100 .filter (self .table_model_cls .reward .isnot (None ))
92101 .order_by (* sortOrder ) # TODO: very slow
93- .limit (self . batch_size - len (exp_list ))
102+ .limit (batch_size - len (exp_list ))
94103 .with_for_update ()
95104 .all ()
96105 )
@@ -103,3 +112,63 @@ def read(self, strategy: Optional[ReadStrategy] = None) -> List:
103112 self .logger .info (f"first prompt_text = { exp_list [0 ].prompt_text } " )
104113 self .logger .info (f"first response_text = { exp_list [0 ].response_text } " )
105114 return exp_list
115+
116+
117+ class _Encoder (json .JSONEncoder ):
118+ def default (self , o ):
119+ if isinstance (o , Experience ):
120+ return o .to_dict ()
121+ if isinstance (o , Task ):
122+ return o .to_dict ()
123+ return super ().default (o )
124+
125+
126+ class FileWrapper :
127+ """
128+ A wrapper of a local jsonl file.
129+
130+ If `wrap_in_ray` in `StorageConfig` is `True`, this class will be run as
131+ a Ray Actor, and provide a remote interface to the local file.
132+
133+ This wrapper is only for writing, if you want to read from the file, use
134+ StorageType.QUEUE instead.
135+ """
136+
137+ def __init__ (self , storage_config : StorageConfig , config : BufferConfig ) -> None :
138+ if storage_config .path is None :
139+ storage_config .path = default_storage_path (storage_config , config )
140+ ext = os .path .splitext (storage_config .path )[- 1 ]
141+ if ext != ".jsonl" and ext != ".json" :
142+ raise ValueError (
143+ f"File path must end with '.json' or '.jsonl', got { storage_config .path } "
144+ )
145+ self .file = open (storage_config .path , "a" , encoding = "utf-8" )
146+ self .encoder = _Encoder (ensure_ascii = False )
147+
148+ @classmethod
149+ def get_wrapper (cls , storage_config : StorageConfig , config : BufferConfig ):
150+ if storage_config .wrap_in_ray :
151+ return (
152+ ray .remote (cls )
153+ .options (
154+ name = f"json-{ storage_config .name } " ,
155+ get_if_exists = True ,
156+ )
157+ .remote (storage_config , config )
158+ )
159+ else :
160+ return cls (storage_config , config )
161+
162+ def write (self , data : List ) -> None :
163+ for item in data :
164+ json_str = self .encoder .encode (item )
165+ self .file .write (json_str + "\n " )
166+ self .file .flush ()
167+
168+ def read (self ) -> List :
169+ raise NotImplementedError (
170+ "read() is not implemented for FileWrapper, please use QUEUE instead"
171+ )
172+
173+ def finish (self ) -> None :
174+ self .file .close ()
0 commit comments