Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tests/buffer/sql_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from copy import deepcopy

import ray
import torch
Expand Down Expand Up @@ -40,7 +41,11 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
)
if enable_replay:
config.replay_buffer = ReplayBufferConfig(enable=True)
sql_writer = SQLWriter(config.to_storage_config())
writer_config = deepcopy(config)
writer_config.batch_size = put_batch_size
# Create the buffer via the writer, so buffer.batch_size is set to put_batch_size.
# This checks whether read_batch_size takes effect independently of it.
sql_writer = SQLWriter(writer_config.to_storage_config())
sql_reader = SQLReader(config.to_storage_config())
exps = [
Experience(
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/reader/queue_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def __init__(self, config: StorageConfig):

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
try:
batch_size = batch_size or self.read_batch_size
batch_size = self.read_batch_size if batch_size is None else batch_size
exps = ray.get(self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs))
if len(exps) != batch_size:
raise TimeoutError(
Expand All @@ -32,7 +32,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
return exps

async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = batch_size or self.read_batch_size
batch_size = self.read_batch_size if batch_size is None else batch_size
exps = await self.queue.get_batch.remote(batch_size, timeout=self.timeout, **kwargs)
if len(exps) != batch_size:
raise TimeoutError(
Expand Down
3 changes: 3 additions & 0 deletions trinity/buffer/reader/sql_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,18 @@ class SQLReader(BufferReader):
def __init__(self, config: StorageConfig) -> None:
    """Build a reader over a SQL-backed experience buffer.

    The config's ``batch_size`` becomes the default number of items
    returned by ``read``/``read_async`` when no explicit batch size is
    given; ``wrap_in_ray`` selects whether storage calls go through Ray.
    """
    # Only SQL-backed storage is supported by this reader.
    assert config.storage_type == StorageType.SQL.value
    self.storage = SQLStorage.get_wrapper(config)
    self.read_batch_size = config.batch_size
    self.wrap_in_ray = config.wrap_in_ray

def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
    """Read one batch from the SQL storage.

    Falls back to the configured ``read_batch_size`` when *batch_size*
    is None. Extra keyword arguments are forwarded to the storage's
    ``read`` method.
    """
    if batch_size is None:
        batch_size = self.read_batch_size
    # When the storage lives behind a Ray actor, block on the remote call.
    if self.wrap_in_ray:
        return ray.get(self.storage.read.remote(batch_size, **kwargs))
    return self.storage.read(batch_size, **kwargs)

async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
batch_size = self.read_batch_size if batch_size is None else batch_size
if self.wrap_in_ray:
try:
return await self.storage.read.remote(batch_size, **kwargs)
Expand Down
4 changes: 2 additions & 2 deletions trinity/buffer/storage/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def read(self, batch_size: Optional[int] = None, **kwargs) -> List[Experience]:
if self.stopped:
raise StopIteration()

batch_size = batch_size or self.batch_size
batch_size = self.batch_size if batch_size is None else batch_size
return self._read_method(batch_size, **kwargs)

@classmethod
Expand Down Expand Up @@ -248,7 +248,7 @@ def read(self, batch_size: Optional[int] = None) -> List[Task]:
raise StopIteration()
if self.offset > self.total_samples:
raise StopIteration()
batch_size = batch_size or self.batch_size
batch_size = self.batch_size if batch_size is None else batch_size
with retry_session(self.session, self.max_retry_times, self.max_retry_interval) as session:
query = (
session.query(self.table_model_cls)
Expand Down