55
66import ray
77
8+ from trinity .buffer .writer .file_writer import JSONWriter
89from trinity .buffer .writer .sql_writer import SQLWriter
910from trinity .common .config import BufferConfig , StorageConfig
1011from trinity .common .constants import StorageType
1112
1213
14+ def is_database_url (path : str ) -> bool :
15+ return any (path .startswith (prefix ) for prefix in ["sqlite:///" , "postgresql://" , "mysql://" ])
16+
17+
18+ def is_json_file (path : str ) -> bool :
19+ return path .endswith (".json" ) or path .endswith (".jsonl" )
20+
21+
1322@ray .remote
1423class QueueActor :
1524 """An asyncio.Queue based queue actor."""
@@ -21,12 +30,21 @@ def __init__(self, storage_config: StorageConfig, config: BufferConfig) -> None:
2130 self .capacity = getattr (config , "capacity" , 10000 )
2231 self .queue = asyncio .Queue (self .capacity )
2332 if storage_config .path is not None and len (storage_config .path ) > 0 :
24- sql_config = deepcopy (storage_config )
25- sql_config .storage_type = StorageType .SQL
26- sql_config .wrap_in_ray = False
27- self .sql_writer = SQLWriter (sql_config , self .config )
33+ if is_database_url (storage_config .path ):
34+ storage_config .storage_type = StorageType .SQL
35+ sql_config = deepcopy (storage_config )
36+ sql_config .storage_type = StorageType .SQL
37+ sql_config .wrap_in_ray = False
38+ self .writer = SQLWriter (sql_config , self .config )
39+ elif is_json_file (storage_config .path ):
40+ storage_config .storage_type = StorageType .FILE
41+ json_config = deepcopy (storage_config )
42+ json_config .storage_type = StorageType .FILE
43+ self .writer = JSONWriter (json_config , self .config )
44+ else :
45+ self .writer = None
2846 else :
29- self .sql_writer = None
47+ self .writer = None
3048
3149 def length (self ) -> int :
3250 """The length of the queue."""
@@ -35,8 +53,8 @@ def length(self) -> int:
3553 async def put_batch (self , exp_list : List ) -> None :
3654 """Put batch of experience."""
3755 await self .queue .put (exp_list )
38- if self .sql_writer is not None :
39- self .sql_writer .write (exp_list )
56+ if self .writer is not None :
57+ self .writer .write (exp_list )
4058
4159 async def finish (self ) -> None :
4260 """Stop the queue."""
0 commit comments