66
77import ray
88from datasets import Dataset
9- from sqlalchemy import asc
9+ from sqlalchemy import asc , desc
1010from sqlalchemy .orm import sessionmaker
1111
1212from trinity .buffer .schema import init_engine
@@ -88,29 +88,33 @@ def release(self) -> int:
8888
8989
9090class SQLExperienceStorage (SQLStorage ):
91+ """Used as trainer input."""
92+
9193 def __init__ (self , storage_config : StorageConfig , config : BufferConfig ) -> None :
9294 super ().__init__ (storage_config , config )
9395 self .batch_size = config .train_batch_size
9496 self .max_timeout = storage_config .max_read_timeout
97+ # TODO: optimize the following logic
98+ if storage_config .schema_type == "experience" :
99+ # NOTE: consistent with the old version of experience buffer
100+ self ._read_method = self ._read_priority
101+ else :
102+ # SFT / DPO uses FIFO style
103+ self ._read_method = self ._read_fifo
95104
96105 def write (self , data : List [Experience ]) -> None :
97106 with retry_session (self .session , self .max_retry_times , self .max_retry_interval ) as session :
98107 experience_models = [self .table_model_cls .from_experience (exp ) for exp in data ]
99108 session .add_all (experience_models )
109+ self .logger .info (f"Write { len (experience_models )} experiences to SQL storage." )
100110
101- def read (self , batch_size : Optional [int ] = None ) -> List [Experience ]:
102- if self .stopped :
103- raise StopIteration ()
104-
111+ def _read_fifo (self , batch_size : int ) -> List [Experience ]:
112+ """Read experiences in FIFO order."""
105113 exp_list = []
106- batch_size = batch_size or self .batch_size # type: ignore
107114 start_time = time .time ()
108115 while len (exp_list ) < batch_size :
109116 if self .stopped :
110117 raise StopIteration ()
111- if len (exp_list ):
112- self .logger .info (f"Waiting for { batch_size - len (exp_list )} more experiences..." )
113- time .sleep (1 )
114118 if time .time () - start_time > self .max_timeout :
115119 self .logger .warning (
116120 f"Max read timeout reached ({ self .max_timeout } s), only get { len (exp_list )} experiences, stopping..."
@@ -131,8 +135,61 @@ def read(self, batch_size: Optional[int] = None) -> List[Experience]:
131135 self .offset = experiences [- 1 ].id
132136 start_time = time .time ()
133137 exp_list .extend ([self .table_model_cls .to_experience (exp ) for exp in experiences ])
138+ if len (exp_list ) < batch_size :
139+ self .logger .info (f"Waiting for { batch_size - len (exp_list )} more experiences..." )
140+ time .sleep (1 )
134141 return exp_list
135142
143+ def _read_priority (self , batch_size : int ) -> List [Experience ]:
144+ exp_list = []
145+ start_time = time .time ()
146+ latest_size = 0
147+ while latest_size < batch_size :
148+ if self .stopped :
149+ raise StopIteration ()
150+ if time .time () - start_time > self .max_timeout :
151+ self .logger .warning (
152+ f"Max read timeout reached ({ self .max_timeout } s), only get { latest_size } experiences, stopping..."
153+ )
154+ raise StopIteration ()
155+ with retry_session (
156+ self .session , self .max_retry_times , self .max_retry_interval
157+ ) as session :
158+ experiences = (
159+ session .query (self .table_model_cls )
160+ .order_by (asc (self .table_model_cls .consumed ), desc (self .table_model_cls .id ))
161+ .limit (batch_size )
162+ .with_for_update ()
163+ .all ()
164+ )
165+ if len (experiences ) != batch_size :
166+ if latest_size != len (experiences ):
167+ latest_size = len (experiences )
168+ start_time = time .time ()
169+ else :
170+ ids = [exp .id for exp in experiences ]
171+ session .query (self .table_model_cls ).filter (
172+ self .table_model_cls .id .in_ (ids )
173+ ).update (
174+ {self .table_model_cls .consumed : self .table_model_cls .consumed + 1 },
175+ synchronize_session = False ,
176+ )
177+ exp_list .extend (
178+ [self .table_model_cls .to_experience (exp ) for exp in experiences ]
179+ )
180+ break
181+
182+ self .logger .info (f"Waiting for { batch_size - len (exp_list )} more experiences..." )
183+ time .sleep (1 )
184+ return exp_list
185+
186+ def read (self , batch_size : Optional [int ] = None ) -> List [Experience ]:
187+ if self .stopped :
188+ raise StopIteration ()
189+
190+ batch_size = batch_size or self .batch_size
191+ return self ._read_method (batch_size )
192+
136193 @classmethod
137194 def load_from_dataset (
138195 cls , dataset : Dataset , storage_config : StorageConfig , config : BufferConfig
@@ -158,6 +215,8 @@ def load_from_dataset(
158215
159216
160217class SQLTaskStorage (SQLStorage ):
218+ """Used as explorer input."""
219+
161220 def __init__ (self , storage_config : StorageConfig , config : BufferConfig ) -> None :
162221 super ().__init__ (storage_config , config )
163222 self .batch_size = config .batch_size
0 commit comments