Skip to content

Commit 163b53a

Browse files
committed
Add default batch size for SQLReader
1 parent bb0875d commit 163b53a

File tree

2 files changed

+7
-1
lines changed

2 files changed

+7
-1
lines changed

tests/buffer/sql_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from copy import deepcopy
23

34
import ray
45
import torch
@@ -40,7 +41,9 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
4041
)
4142
if enable_replay:
4243
config.replay_buffer = ReplayBufferConfig(enable=True)
43-
sql_writer = SQLWriter(config.to_storage_config())
44+
writer_config = deepcopy(config)
45+
writer_config.batch_size = put_batch_size
46+
sql_writer = SQLWriter(writer_config.to_storage_config())
4447
sql_reader = SQLReader(config.to_storage_config())
4548
exps = [
4649
Experience(

trinity/buffer/reader/sql_reader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@ class SQLReader(BufferReader):
1616
def __init__(self, config: StorageConfig) -> None:
1717
assert config.storage_type == StorageType.SQL.value
1818
self.wrap_in_ray = config.wrap_in_ray
19+
self.read_batch_size = config.batch_size
1920
self.storage = SQLStorage.get_wrapper(config)
2021

2122
def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
23+
batch_size = batch_size or self.read_batch_size
2224
if self.wrap_in_ray:
2325
return ray.get(self.storage.read.remote(batch_size, **kwargs))
2426
else:
2527
return self.storage.read(batch_size, **kwargs)
2628

2729
async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
30+
batch_size = batch_size or self.read_batch_size
2831
if self.wrap_in_ray:
2932
try:
3033
return await self.storage.read.remote(batch_size, **kwargs)

0 commit comments

Comments
 (0)