tests/buffer/sql_test.py (4 additions, 1 deletion)
@@ -1,4 +1,5 @@
 import os
+from copy import deepcopy
 
 import ray
 import torch
@@ -40,7 +41,9 @@ async def test_sql_exp_buffer_read_write(self, enable_replay: bool) -> None:
         )
         if enable_replay:
             config.replay_buffer = ReplayBufferConfig(enable=True)
-        sql_writer = SQLWriter(config.to_storage_config())
+        writer_config = deepcopy(config)
+        writer_config.batch_size = put_batch_size
+        sql_writer = SQLWriter(writer_config.to_storage_config())
         sql_reader = SQLReader(config.to_storage_config())
         exps = [
             Experience(
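The test change gives the writer its own copy of the config before overriding `batch_size`, so the reader built from the original `config` keeps its own batch size. A minimal sketch of this copy-then-override pattern, using a hypothetical dataclass config rather than the project's real config classes:

```python
from copy import deepcopy
from dataclasses import dataclass


@dataclass
class BufferConfig:
    # Hypothetical stand-in for the test's buffer config object.
    batch_size: int = 32
    db_url: str = "sqlite:///buffer.db"


reader_config = BufferConfig(batch_size=32)

# Copy before mutating so the override is local to the writer.
writer_config = deepcopy(reader_config)
writer_config.batch_size = 8  # e.g. a smaller put batch size for writes

assert reader_config.batch_size == 32
assert writer_config.batch_size == 8
```

Without the `deepcopy`, setting `writer_config.batch_size` would mutate the shared config object and silently change the reader's batch size as well.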
trinity/buffer/reader/sql_reader.py (3 additions, 0 deletions)
@@ -16,15 +16,18 @@ class SQLReader(BufferReader):
     def __init__(self, config: StorageConfig) -> None:
         assert config.storage_type == StorageType.SQL.value
         self.wrap_in_ray = config.wrap_in_ray
+        self.read_batch_size = config.batch_size
         self.storage = SQLStorage.get_wrapper(config)
 
     def read(self, batch_size: Optional[int] = None, **kwargs) -> List:
+        batch_size = batch_size or self.read_batch_size
         if self.wrap_in_ray:
             return ray.get(self.storage.read.remote(batch_size, **kwargs))
         else:
             return self.storage.read(batch_size, **kwargs)
 
     async def read_async(self, batch_size: Optional[int] = None, **kwargs) -> List:
+        batch_size = batch_size or self.read_batch_size
         if self.wrap_in_ray:
             try:
                 return await self.storage.read.remote(batch_size, **kwargs)
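With these additions, `read()` and `read_async()` fall back to the batch size taken from the storage config whenever the caller passes no explicit `batch_size`. A minimal sketch of the same fallback pattern in isolation; the `FakeStorage` and `Reader` names below are illustrative stand-ins, not the project's actual classes:

```python
from typing import List, Optional


class FakeStorage:
    """Illustrative stand-in for the SQL-backed storage wrapper."""

    def __init__(self, items: List[int]) -> None:
        self._items = items

    def read(self, batch_size: int) -> List[int]:
        # Pop and return the next `batch_size` items.
        batch, self._items = self._items[:batch_size], self._items[batch_size:]
        return batch


class Reader:
    """Sketch of the default-batch-size fallback added to SQLReader."""

    def __init__(self, storage: FakeStorage, default_batch_size: int) -> None:
        self.storage = storage
        self.read_batch_size = default_batch_size

    def read(self, batch_size: Optional[int] = None) -> List[int]:
        # Fall back to the configured batch size when none is given.
        batch_size = batch_size or self.read_batch_size
        return self.storage.read(batch_size)


reader = Reader(FakeStorage(list(range(10))), default_batch_size=4)
print(reader.read())   # [0, 1, 2, 3]  (configured default)
print(reader.read(2))  # [4, 5]        (explicit batch_size wins)
```

One property of the `batch_size or self.read_batch_size` idiom, shared by the sketch above, is that an explicit `batch_size=0` is treated the same as passing nothing and falls back to the configured default.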