@@ -79,6 +79,7 @@ def __init__(
         ] = "memory",
         storage_uri: str | None = None,
         schema: str = "public",
+        max_worker: int = 4,
     ):
         """
         Initialize the DataCollector with configuration options.
@@ -101,15 +102,17 @@ def __init__(
             URI or path corresponding to the selected storage backend.
         schema: str
             Schema name used for PostgreSQL storage.
-
+        max_worker: int
+            Maximum number of worker threads used for flushing collected data asynchronously.
         """
         super().__init__(
             model=model,
             model_reporters=model_reporters,
             agent_reporters=agent_reporters,
             trigger=trigger,
             reset_memory=reset_memory,
-            storage=storage,  # literal won't work
+            storage=storage,
+            max_workers=max_worker,
         )
         self._writers = {
             "csv": self._write_csv_local,
@@ -120,6 +123,8 @@ def __init__(
         }
         self._storage_uri = storage_uri
         self._schema = schema
+        self._current_model_step = None
+        self._batch_id = None

         self._validate_inputs()

@@ -130,28 +135,42 @@ def _collect(self):
         This method checks for the presence of model and agent reporters
         and calls the appropriate collection routines for each.
         """
+        if (
+            self._current_model_step is None
+            or self._current_model_step != self._model.steps
+        ):
+            self._current_model_step = self._model.steps
+            self._batch_id = 0
+
         if self._model_reporters:
-            self._collect_model_reporters()
+            self._collect_model_reporters(
+                current_model_step=self._current_model_step, batch_id=self._batch_id
+            )

         if self._agent_reporters:
-            self._collect_agent_reporters()
+            self._collect_agent_reporters(
+                current_model_step=self._current_model_step, batch_id=self._batch_id
+            )
+
+        self._batch_id += 1

-    def _collect_model_reporters(self):
+    def _collect_model_reporters(self, current_model_step: int, batch_id: int):
         """
         Collect model-level data using the model_reporters.

         Creates a LazyFrame containing the step, seed, and values
         returned by each model reporter. Appends the LazyFrame to internal storage.
         """
         model_data_dict = {}
-        model_data_dict["step"] = self._model._steps
+        model_data_dict["step"] = current_model_step
         model_data_dict["seed"] = str(self.seed)
+        model_data_dict["batch"] = batch_id
         for column_name, reporter in self._model_reporters.items():
             model_data_dict[column_name] = reporter(self._model)
         model_lazy_frame = pl.LazyFrame([model_data_dict])
-        self._frames.append(("model", str(self._model._steps), model_lazy_frame))
+        self._frames.append(("model", current_model_step, batch_id, model_lazy_frame))

-    def _collect_agent_reporters(self):
+    def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
         """
         Collect agent-level data using the agent_reporters.

@@ -164,15 +183,16 @@ def _collect_agent_reporters(self):
                 for k, v in self._model.agents[reporter].items():
                     agent_data_dict[col_name + "_" + str(k.__class__.__name__)] = v
             else:
-                agent_data_dict[col_name] = reporter(self._model.agents)
+                agent_data_dict[col_name] = reporter(self._model)
         agent_lazy_frame = pl.LazyFrame(agent_data_dict)
         agent_lazy_frame = agent_lazy_frame.with_columns(
             [
-                pl.lit(self._model._steps).alias("step"),
+                pl.lit(current_model_step).alias("step"),
                 pl.lit(str(self.seed)).alias("seed"),
+                pl.lit(batch_id).alias("batch"),
             ]
         )
-        self._frames.append(("agent", str(self._model._steps), agent_lazy_frame))
+        self._frames.append(("agent", current_model_step, batch_id, agent_lazy_frame))

     @property
     def data(self) -> dict[str, pl.DataFrame]:
@@ -185,96 +205,108 @@ def data(self) -> dict[str, pl.DataFrame]:
             A dictionary with keys "model" and "agent" mapping to concatenated DataFrames of collected data.
         """
         model_frames = [
-            lf.collect() for kind, step, lf in self._frames if kind == "model"
+            lf.collect() for kind, step, batch_id, lf in self._frames if kind == "model"
         ]
         agent_frames = [
-            lf.collect() for kind, step, lf in self._frames if kind == "agent"
+            lf.collect() for kind, step, batch_id, lf in self._frames if kind == "agent"
         ]
         return {
             "model": pl.concat(model_frames) if model_frames else pl.DataFrame(),
             "agent": pl.concat(agent_frames) if agent_frames else pl.DataFrame(),
         }

-    def _flush(self):
+    def _flush(self, frames_to_flush: list):
         """
         Flush the collected data to the configured external storage backend.

         Uses the appropriate writer function based on the specified storage option.
         """
-        self._writers[self._storage](self._storage_uri)
+        self._writers[self._storage](
+            uri=self._storage_uri, frames_to_flush=frames_to_flush
+        )

-    def _write_csv_local(self, uri: str):
+    def _write_csv_local(self, uri: str, frames_to_flush: list):
         """
         Write collected data to local CSV files.

         Parameters
         ----------
         uri: str
             Local directory path to write files into.
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         """
-        for kind, step, df in self._frames:
-            df.collect().write_csv(f"{uri}/{kind}_step{step}.csv")
+        for kind, step, batch, df in frames_to_flush:
+            df.collect().write_csv(f"{uri}/{kind}_step{step}_batch{batch}.csv")

-    def _write_parquet_local(self, uri: str):
+    def _write_parquet_local(self, uri: str, frames_to_flush: list):
         """
         Write collected data to local Parquet files.

         Parameters
         ----------
         uri: str
             Local directory path to write files into.
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         """
-        for kind, step, df in self._frames:
-            df.collect().write_parquet(f"{uri}/{kind}_step{step}.parquet")
+        for kind, step, batch, df in frames_to_flush:
+            df.collect().write_parquet(f"{uri}/{kind}_step{step}_batch{batch}.parquet")

-    def _write_csv_s3(self, uri: str):
+    def _write_csv_s3(self, uri: str, frames_to_flush: list):
         """
         Write collected data to AWS S3 in CSV format.

         Parameters
         ----------
         uri: str
             S3 URI (e.g., s3://bucket/path) to upload files to.
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         """
-        self._write_s3(uri, format_="csv")
+        self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="csv")

-    def _write_parquet_s3(self, uri: str):
+    def _write_parquet_s3(self, uri: str, frames_to_flush: list):
         """
         Write collected data to AWS S3 in Parquet format.

         Parameters
         ----------
         uri: str
             S3 URI (e.g., s3://bucket/path) to upload files to.
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         """
-        self._write_s3(uri, format_="parquet")
+        self._write_s3(uri=uri, frames_to_flush=frames_to_flush, format_="parquet")

-    def _write_s3(self, uri: str, format_: str):
+    def _write_s3(self, uri: str, frames_to_flush: list, format_: str):
         """
         Upload collected data to S3 in a specified format.

         Parameters
         ----------
         uri: str
             S3 URI to upload to.
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         format_: str
             Format of the output files ("csv" or "parquet").
         """
         s3 = boto3.client("s3")
         parsed = urlparse(uri)
         bucket = parsed.netloc
         prefix = parsed.path.lstrip("/")
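+        # Stage each frame in a temporary local file, then upload it under a per-batch key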
-        for kind, step, lf in self._frames:
+        for kind, step, batch, lf in frames_to_flush:
             df = lf.collect()
             with tempfile.NamedTemporaryFile(suffix=f".{format_}") as tmp:
                 if format_ == "csv":
                     df.write_csv(tmp.name)
                 elif format_ == "parquet":
                     df.write_parquet(tmp.name)
-                key = f"{prefix}/{kind}_step{step}.{format_}"
+                key = f"{prefix}/{kind}_step{step}_batch{batch}.{format_}"
                 s3.upload_file(tmp.name, bucket, key)

-    def _write_postgres(self, uri: str):
+    def _write_postgres(self, uri: str, frames_to_flush: list):
         """
         Write collected data to a PostgreSQL database.

@@ -285,10 +317,12 @@ def _write_postgres(self, uri: str):
         ----------
         uri: str
             PostgreSQL connection URI in the form postgresql://testuser:testpass@localhost:5432/testdb
+        frames_to_flush: list
+            The collected frames to flush in the current thread.
         """
         conn = self._get_db_connection(uri=uri)
         cur = conn.cursor()
-        for kind, step, lf in self._frames:
+        for kind, step, batch, lf in frames_to_flush:
             df = lf.collect()
             table = f"{kind}_data"
             cols = df.columns
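
For context, here is a minimal usage sketch of the batched flushing flow this change introduces. It is an illustration only: the `mesa_frames` import path, the model object, the reporter lambda, and the public `collect()`/`flush()` entry points (the diff only shows the private `_collect`/`_flush`) are assumptions, not taken from this diff.

```python
# Hypothetical usage sketch -- names marked below are assumptions, not part of this diff.
from mesa_frames import DataCollector  # assumed import path

collector = DataCollector(
    model=model,  # an already-constructed model instance (assumed)
    model_reporters={"n_agents": lambda m: len(m.agents)},  # illustrative reporter
    storage="csv",
    storage_uri="./results",
    max_worker=4,  # worker threads available for asynchronous flushes
)

for _ in range(5):
    model.step()
    collector.collect()  # first collect in a step records batch 0
    collector.collect()  # a second collect in the same step records batch 1

# Each (kind, step, batch) frame lands in its own file,
# e.g. ./results/model_step3_batch1.csv
collector.flush()
```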