Skip to content

Commit 638f7e6

Browse files
committed
apply suggestions from gemini
1 parent 8aedc04 commit 638f7e6

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

tests/buffer/sql_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
4343
config.replay_buffer = ReplayBufferConfig(enable=True)
4444
writer_config = deepcopy(config)
4545
writer_config.batch_size = put_batch_size
46+
# Create buffer by writer, so buffer.batch_size will be set to put_batch_size
47+
# This will check whether read_batch_size tasks effect
4648
sql_writer = SQLWriter(writer_config.to_storage_config())
4749
sql_reader = SQLReader(config.to_storage_config())
4850
exps = [

trinity/buffer/storage/sql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
197197
if self.stopped:
198198
raise StopIteration()
199199

200-
batch_size = batch_size or self.batch_size
200+
batch_size = self.batch_size if batch_size is None else batch_size
201201
return self._read_method(batch_size, **kwargs)
202202

203203
@classmethod
@@ -248,7 +248,7 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]:
248248
raise StopIteration()
249249
if self.offset > self.total_samples:
250250
raise StopIteration()
251-
batch_size = batch_size or self.batch_size
251+
batch_size = self.batch_size if batch_size is None else batch_size
252252
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
253253
query = (
254254
session.query(self.table_model_cls)

0 commit comments

Comments
 (0)