Commit a348d96

Add default batch size for SQLReader (#467)

1 parent: 07302f2

4 files changed: 13 additions, 5 deletions


tests/buffer/sql_test.py (6 additions, 1 deletion)

```diff
@@ -1,4 +1,5 @@
 import os
+from copy import deepcopy
 
 import ray
 import torch
@@ -40,7 +41,11 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
         )
         if enable_replay:
             config.replay_buffer = ReplayBufferConfig(enable=True)
-        sql_writer = SQLWriter(config.to_storage_config())
+        writer_config = deepcopy(config)
+        writer_config.batch_size = put_batch_size
+        # Create the buffer via the writer, so buffer.batch_size is set to put_batch_size.
+        # This checks whether read_batch_size takes effect.
+        sql_writer = SQLWriter(writer_config.to_storage_config())
         sql_reader = SQLReader(config.to_storage_config())
         exps = [
             Experience(
```

trinity/buffer/reader/queue_reader.py (2 additions, 2 deletions)

```diff
@@ -21,7 +21,7 @@ def __init__(self, config: StorageConfig):
 
     def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
         try:
-            batch_size = batch_size or self.read_batch_size
+            batch_size = self.read_batch_size if batch_size is None else batch_size
             exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs))
             if len(exps) != batch_size:
                 raise TimeoutError(
@@ -32,7 +32,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
         return exps
 
     async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
-        batch_size = batch_size or self.read_batch_size
+        batch_size = self.read_batch_size if batch_size is None else batch_size
         exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)
         if len(exps) != batch_size:
             raise TimeoutError(
```
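Both hunks swap the truthiness-based fallback `batch_size or self.read_batch_size` for an explicit `is None` check. The difference matters for falsy values: `or` overwrites an explicit `batch_size=0` with the default, while the `is None` form preserves it. A minimal standalone sketch of the distinction (`DEFAULT_BATCH_SIZE` and the helper functions are illustrative, not part of the codebase):

```python
DEFAULT_BATCH_SIZE = 32

def with_or(batch_size=None):
    # Falls back to the default for ANY falsy value, including 0.
    return batch_size or DEFAULT_BATCH_SIZE

def with_is_none(batch_size=None):
    # Falls back to the default only when no value was given.
    return DEFAULT_BATCH_SIZE if batch_size is None else batch_size

assert with_or(None) == 32        # default applied, as intended
assert with_or(0) == 32           # bug: an explicit 0 is silently overwritten
assert with_is_none(None) == 32   # default applied, as intended
assert with_is_none(0) == 0       # explicit 0 is preserved
```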

trinity/buffer/reader/sql_reader.py (3 additions, 0 deletions)

```diff
@@ -16,15 +16,18 @@ class SQLReader(BufferReader):
     def __init__(self, config: StorageConfig) -> None:
         assert config.storage_type == StorageType.SQL.value
         self.wrap_in_ray = config.wrap_in_ray
+        self.read_batch_size = config.batch_size
         self.storage = SQLStorage.get_wrapper(config)
 
     def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
+        batch_size = self.read_batch_size if batch_size is None else batch_size
         if self.wrap_in_ray:
             return ray.get(self.storage.read.remote(batch_size, **kwargs))
         else:
             return self.storage.read(batch_size, **kwargs)
 
     async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
+        batch_size = self.read_batch_size if batch_size is None else batch_size
        if self.wrap_in_ray:
            try:
                return await self.storage.read.remote(batch_size, **kwargs)
```
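This is the change the commit title refers to: `SQLReader` now records `config.batch_size` at construction and applies it whenever `read()` or `read_async()` is called without an explicit size. A hedged usage sketch, where `storage_config` is a placeholder for a `StorageConfig` built elsewhere with `storage_type=StorageType.SQL.value` and `batch_size=64`:

```python
# Hypothetical usage; config values are placeholders for illustration.
reader = SQLReader(storage_config)

# Before this commit, read() forwarded batch_size=None to the storage layer
# and relied on the storage's own default. Now the reader itself falls back
# to the configured batch size:
exps = reader.read()               # reads config.batch_size (here, 64) experiences
small = reader.read(batch_size=8)  # an explicit value still wins
```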

trinity/buffer/storage/sql.py (2 additions, 2 deletions)

```diff
@@ -197,7 +197,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
         if self.stopped:
             raise StopIteration()
 
-        batch_size = batch_size or self.batch_size
+        batch_size = self.batch_size if batch_size is None else batch_size
         return self._read_method(batch_size, **kwargs)
 
     @classmethod
@@ -248,7 +248,7 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]:
             raise StopIteration()
         if self.offset > self.total_samples:
             raise StopIteration()
-        batch_size = batch_size or self.batch_size
+        batch_size = self.batch_size if batch_size is None else batch_size
         with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
             query = (
                 session.query(self.table_model_cls)
```
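Applying the same `is None` check in both storage-level `read()` methods keeps the fallback semantics consistent across the queue reader, the SQL reader, and the SQL storage itself: `None` means "use the configured default", while any explicit integer, including 0, is honored as given.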
