
Commit b4c152f

PaParaZz1, sailxjx, and hiha3456 authored
feature(nyz&xjx): add new middleware distributed demo (#321)
* demo(nyz): add naive dp demo
* demo(nyz): add naive ddp demo
* feature(nyz): add naive tb_logger in new evaluator
* Add singleton log writer
* Use get_instance on writer
* feature(nyz): add general logger middleware
* feature(nyz): add soft update in DQN target network
* fix(nyz): fix termination env_step bug and eval task.finish broadcast bug
* Support distributed dqn
* Add more desc (ci skip)
* Support distributed dqn Add more desc (ci skip) Add timeout on model exchanger
* feature(nyz): add online logger freq
* fix(nyz): fix policy set device bug
* add offline rl logger
* change a bit
* add else in checking ctx type
* add test_logger.py
* add mock of offline_logger
* add mock of online writer
* reformat
* reformat
* feature(nyz): polish atari ddp demo and add dist demo
* fix(nyz): fix mq listen bug when stop
* demo(nyz): add atari ppo(sm+ddp) demo
* demo(nyz): add ppo ddp avgsplit demo
* demo(nyz): add ditask + pytorch ddp demo
* fix(nyz): fix dict-type obs bugs
* fix(nyz): fix get_shape0 bug when nested structure
* Route finish event to all processes in the cluster
* demo(nyz): add naive dp demo
* demo(nyz): add naive ddp demo
* feature(nyz): add naive tb_logger in new evaluator
* feature(nyz): add soft update in DQN target network
* fix(nyz): fix termination env_step bug and eval task.finish broadcast bug
* Add singleton log writer
* Use get_instance on writer
* feature(nyz): add general logger middleware
* Support distributed dqn
* Add more desc (ci skip)
* Support distributed dqn Add more desc (ci skip) Add timeout on model exchanger
* feature(nyz): add online logger freq
* fix(nyz): fix policy set device bug
* add offline rl logger
* change a bit
* add else in checking ctx type
* add test_logger.py
* add mock of offline_logger
* add mock of online writer
* reformat
* reformat
* feature(nyz): polish atari ddp demo and add dist demo
* fix(nyz): fix mq listen bug when stop
* demo(nyz): add atari ppo(sm+ddp) demo
* demo(nyz): add ppo ddp avgsplit demo
* demo(nyz): add ditask + pytorch ddp demo
* fix(nyz): fix dict-type obs bugs
* fix(nyz): fix get_shape0 bug when nested structure
* Route finish event to all processes in the cluster
* refactor(nyz): split dist ddp demo implementation
* feature(nyz): add rdma test demo(ci skip)
* feature(xjx): new style dist version, add storage loader and model loader (#425)
* Add singleton log writer
* Use get_instance on writer
* feature(nyz): polish atari ddp demo and add dist demo
* Refactor dist version
* Wrap class based middleware
* Change if condition in wrapper
* Only run enhancer on learner
* Support new parallel mode on slurm cluster
* Temp data loader
* Stash commit
* Init data serializer
* Update dump part of code
* Test StorageLoader
* Turn data serializer into storage loader, add storage loader in context exchanger
* Add local id and startup interval
* Fix storage loader
* Support treetensor
* Add role on event name in context exchanger, use share_memory function on tensor
* Double size buffer
* Copy tensor to cpu, skip wait for context on collector and evaluator
* Remove data loader middleware
* Upgrade k8s parser
* Add epoch timer
* Dont use lb
* Change tensor to numpy
* Remove files when stop storage loader
* Discard shared object
* Ensure correct load shm memory
* Add model loader
* Rename model_exchanger to ModelExchanger
* Add model loader benchmark
* Shutdown loaders when task finish
* Upgrade supervisor
* Dont cleanup files when shutting down
* Fix async cleanup in model loader
* Check model loader on dqn
* Dont use loader in dqn example
* Fix style check
* Fix dp
* Fix github tests
* Skip github ci
* Fix bug in event loop
* Fix enhancer tests, move router from start to __init__
* Change default ttl
* Add comments
Co-authored-by: niuyazhe <[email protected]>
* style(nyz): correct yapf style
* fix(nyz): fix ctx and logger compatibility bugs
* polish(nyz): update demo from cartpole v0 to v1
* fix(nyz): fix evaluator condition bug
* style(nyz): correct flake8 style
* demo(nyz): move back to CartPole-v0
* fix(nyz): fix context manager env step merge bug(ci skip)
* fix(nyz): fix context manager env step merge bug(ci skip)
* fix(nyz): fix flake8 style
Co-authored-by: Xu Jingxin <[email protected]>
Co-authored-by: zhumengshen <[email protected]>
1 parent dd2b3a5 commit b4c152f

73 files changed: +2645 -367 lines changed

ding/config/config.py

Lines changed: 1 addition & 1 deletion
@@ -315,7 +315,7 @@ def save_project_state(exp_name: str) -> None:
     def _fn(cmd: str):
         return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")

-    if subprocess.run("git status", shell=True, stderr=subprocess.PIPE).returncode == 0:
+    if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
         short_sha = _fn("git describe --always")
         log = _fn("git log --stat -n 5")
         diff = _fn("git diff")
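Note on the hunk above: adding `stdout=subprocess.PIPE` means the `git status` probe no longer prints to the console, and only its return code is inspected. A minimal sketch of that pattern, assuming a hypothetical helper name `command_succeeds` that is not part of the commit:

```python
import subprocess


def command_succeeds(cmd: str) -> bool:
    """Run a shell command quietly and report whether it exited with status 0.

    Hypothetical helper for illustration; the commit inlines this check.
    """
    # Capturing both streams keeps the probe's output off the console.
    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    return result.returncode == 0


if __name__ == "__main__":
    # "git status" exits non-zero outside a repository or when git is missing.
    print(command_succeeds("git status"))
```

Capturing the output instead of discarding it also leaves `result.stdout` and `result.stderr` available should the caller want to log them later.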

ding/data/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -1,3 +1,7 @@
 from torch.utils.data import Dataset, DataLoader
 from ding.utils.data import create_dataset, offline_data_save_type  # for compatibility
 from .buffer import *
+from .storage import *
+from .storage_loader import StorageLoader, FileStorageLoader
+from .shm_buffer import ShmBufferContainer, ShmBuffer
+from .model_loader import ModelLoader, FileModelLoader
File renamed without changes.

ding/data/model_loader.py

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+from abc import ABC, abstractmethod
+import logging
+from os import path
+import os
+from threading import Thread
+from time import sleep, time
+from typing import Callable, Optional
+import uuid
+import torch.multiprocessing as mp
+
+import torch
+from ding.data.storage.file import FileModelStorage
+from ding.data.storage.storage import Storage
+from ding.framework import Supervisor
+from ding.framework.supervisor import ChildType, SendPayload
+
+
+class ModelWorker():
+
+    def __init__(self, model: torch.nn.Module) -> None:
+        self._model = model
+
+    def save(self, storage: Storage) -> Storage:
+        storage.save(self._model.state_dict())
+        return storage
+
+
+class ModelLoader(Supervisor, ABC):
+
+    def __init__(self, model: torch.nn.Module) -> None:
+        """
+        Overview:
+            Save and send models asynchronously and load them synchronously.
+        Arguments:
+            - model (:obj:`torch.nn.Module`): Torch module.
+        """
+        if next(model.parameters()).is_cuda:
+            super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
+        else:
+            super().__init__(type_=ChildType.PROCESS)
+        self._model = model
+        self._send_callback_loop = None
+        self._send_callbacks = {}
+        self._model_worker = ModelWorker(self._model)
+
+    def start(self):
+        if not self._running:
+            self._model.share_memory()
+            self.register(self._model_worker)
+            self.start_link()
+            self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
+            self._send_callback_loop.start()
+
+    def shutdown(self, timeout: Optional[float] = None) -> None:
+        super().shutdown(timeout)
+        self._send_callback_loop = None
+        self._send_callbacks = {}
+
+    def _loop_send_callback(self):
+        while True:
+            payload = self.recv(ignore_err=True)
+            if payload.err:
+                logging.warning("Got error when loading data: {}".format(payload.err))
+                if payload.req_id in self._send_callbacks:
+                    del self._send_callbacks[payload.req_id]
+            else:
+                if payload.req_id in self._send_callbacks:
+                    callback = self._send_callbacks.pop(payload.req_id)
+                    callback(payload.data)
+
+    def load(self, storage: Storage) -> object:
+        """
+        Overview:
+            Load model synchronously.
+        Arguments:
+            - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
+        Returns:
+            - object (:obj:`object`): The loaded model.
+        """
+        return storage.load()
+
+    @abstractmethod
+    def save(self, callback: Callable) -> Storage:
+        """
+        Overview:
+            Save model asynchronously.
+        Arguments:
+            - callback (:obj:`Callable`): The callback function after saving model.
+        Returns:
+            - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
+        """
+        raise NotImplementedError
+
+
+class FileModelLoader(ModelLoader):
+
+    def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
+        """
+        Overview:
+            Model loader using files as storage media.
+        Arguments:
+            - model (:obj:`torch.nn.Module`): Torch module.
+            - dirname (:obj:`str`): The directory for saving files.
+            - ttl (:obj:`int`): Files will be automatically cleaned after ttl. Note that \
+                files that do not time out when the process is stopped are not cleaned up \
+                (to avoid errors when other processes read the file), so you may need to \
+                clean up the remaining files manually
+        """
+        super().__init__(model)
+        self._dirname = dirname
+        self._ttl = ttl
+        self._files = []
+        self._cleanup_thread = None
+
+    def _start_cleanup(self):
+        """
+        Overview:
+            Start a cleanup thread that removes files which have stayed on disk longer than ttl seconds.
+        """
+        if self._cleanup_thread is None:
+            self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
+            self._cleanup_thread.start()
+
+    def shutdown(self, timeout: Optional[float] = None) -> None:
+        super().shutdown(timeout)
+        self._cleanup_thread = None
+
+    def _loop_cleanup(self):
+        while True:
+            if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
+                sleep(1)
+                continue
+            _, file_path = self._files.pop(0)
+            if path.exists(file_path):
+                os.remove(file_path)
+
+    def save(self, callback: Callable) -> FileModelStorage:
+        if not self._running:
+            logging.warning("Please start model loader before saving model.")
+            return
+        if not path.exists(self._dirname):
+            os.mkdir(self._dirname)
+        file_path = "model_{}.pth.tar".format(uuid.uuid1())
+        file_path = path.join(self._dirname, file_path)
+        model_storage = FileModelStorage(file_path)
+        payload = SendPayload(proc_id=0, method="save", args=[model_storage])
+        self.send(payload)
+
+        def clean_callback(storage: Storage):
+            self._files.append([time(), file_path])
+            callback(storage)
+
+        self._send_callbacks[payload.req_id] = clean_callback
+        self._start_cleanup()
+        return model_storage
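Note on `FileModelLoader._loop_cleanup`: saved files are queued as `[timestamp, path]` pairs in chronological order, and the loop deletes the oldest entry once its age reaches `ttl` seconds. The sketch below isolates one pass of that logic so it can be run without a thread. The function name `remove_expired` and the explicit `now` argument are assumptions made for illustration and do not appear in the commit.

```python
import os
import tempfile
from time import time
from typing import List, Tuple


def remove_expired(files: List[Tuple[float, str]], ttl: float, now: float) -> List[str]:
    """Pop entries whose age is at least ttl from the front of the queue and delete them.

    Hypothetical helper: the commit performs the same check inside a sleeping loop.
    """
    removed = []
    # Entries are appended in chronological order, so the oldest one is at index 0.
    while files and now - files[0][0] >= ttl:
        _, file_path = files.pop(0)
        if os.path.exists(file_path):
            os.remove(file_path)
        removed.append(file_path)
    return removed


if __name__ == "__main__":
    dirname = tempfile.mkdtemp()
    file_path = os.path.join(dirname, "model_example.pth.tar")
    open(file_path, "wb").close()
    # Pretend the file was saved 30 seconds ago and use the default ttl of 20.
    queue = [(time() - 30, file_path)]
    print(remove_expired(queue, ttl=20, now=time()))
```

Passing `now` explicitly keeps the function deterministic, which makes the expiry rule easy to test separately from the background thread.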

ding/data/shm_buffer.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from typing import Any, Optional, Union, Tuple, Dict
+from multiprocessing import Array
+import ctypes
+import numpy as np
+import torch
+
+_NTYPE_TO_CTYPE = {
+    np.bool_: ctypes.c_bool,
+    np.uint8: ctypes.c_uint8,
+    np.uint16: ctypes.c_uint16,
+    np.uint32: ctypes.c_uint32,
+    np.uint64: ctypes.c_uint64,
+    np.int8: ctypes.c_int8,
+    np.int16: ctypes.c_int16,
+    np.int32: ctypes.c_int32,
+    np.int64: ctypes.c_int64,
+    np.float32: ctypes.c_float,
+    np.float64: ctypes.c_double,
+}
+
+
+class ShmBuffer():
+    """
+    Overview:
+        Shared memory buffer to store numpy array.
+    """
+
+    def __init__(
+            self,
+            dtype: Union[type, np.dtype],
+            shape: Tuple[int],
+            copy_on_get: bool = True,
+            ctype: Optional[type] = None
+    ) -> None:
+        """
+        Overview:
+            Initialize the buffer.
+        Arguments:
+            - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
+            - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
+            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
+            - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
+        """
+        if isinstance(dtype, np.dtype):  # it is type of gym.spaces.dtype
+            dtype = dtype.type
+        self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
+        self.dtype = dtype
+        self.shape = shape
+        self.copy_on_get = copy_on_get
+        self.ctype = ctype
+
+    def fill(self, src_arr: np.ndarray) -> None:
+        """
+        Overview:
+            Fill the shared memory buffer with a numpy array. (Replace the original one.)
+        Arguments:
+            - src_arr (:obj:`np.ndarray`): array to fill the buffer.
+        """
+        assert isinstance(src_arr, np.ndarray), type(src_arr)
+        # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
+        # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
+        # so we reshape dst_arr rather than flatten src_arr
+        dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
+        np.copyto(dst_arr, src_arr)
+
+    def get(self) -> np.ndarray:
+        """
+        Overview:
+            Get the array stored in the buffer.
+        Returns:
+            - data (:obj:`np.ndarray`): The data stored in the buffer; a copy if ``copy_on_get`` is True.
+        """
+        data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
+        if self.copy_on_get:
+            data = data.copy()  # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
+        if self.ctype is torch.Tensor:
+            data = torch.from_numpy(data)
+        return data
+
+
+class ShmBufferContainer(object):
+    """
+    Overview:
+        Support multiple shared memory buffers. Each key-value is name-buffer.
+    """
+
+    def __init__(
+            self,
+            dtype: Union[Dict[Any, type], type, np.dtype],
+            shape: Union[Dict[Any, tuple], tuple],
+            copy_on_get: bool = True
+    ) -> None:
+        """
+        Overview:
+            Initialize the buffer container.
+        Arguments:
+            - dtype (:obj:`Union[Dict[Any, type], type, np.dtype]`): The dtype of the data to limit the size \
+                of the buffer.
+            - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
+                multiple buffers; If `tuple`, use single buffer.
+            - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
+        """
+        if isinstance(shape, dict):
+            self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
+        elif isinstance(shape, (tuple, list)):
+            self._data = ShmBuffer(dtype, shape, copy_on_get)
+        else:
+            raise RuntimeError("not support shape: {}".format(shape))
+        self._shape = shape
+
+    def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
+        """
+        Overview:
+            Fill one or more shared memory buffers.
+        Arguments:
+            - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
+        """
+        if isinstance(self._shape, dict):
+            for k in self._shape.keys():
+                self._data[k].fill(src_arr[k])
+        elif isinstance(self._shape, (tuple, list)):
+            self._data.fill(src_arr)
+
+    def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
+        """
+        Overview:
+            Get the array or arrays stored in the buffer.
+        Returns:
+            - data (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): The array(s) stored in the buffer.
+        """
+        if isinstance(self._shape, dict):
+            return {k: self._data[k].get() for k in self._shape.keys()}
+        elif isinstance(self._shape, (tuple, list)):
+            return self._data.get()
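Note on `ShmBuffer`: the class wraps a `multiprocessing.Array` with `np.frombuffer`, which gives a zero-copy NumPy view of the shared block, and `copy_on_get` decides whether `get` hands out that view or an independent copy. The sketch below reproduces the mechanism with plain NumPy calls rather than the class itself; the variable names (`shared`, `view`, `snapshot`, `alias`) are illustrative only.

```python
import ctypes
from multiprocessing import Array

import numpy as np

# Allocate shared memory large enough for a (2, 3) float32 array.
shape = (2, 3)
shared = Array(ctypes.c_float, int(np.prod(shape)))

# np.frombuffer wraps the shared block without copying it.
view = np.frombuffer(shared.get_obj(), dtype=np.float32).reshape(shape)
np.copyto(view, np.arange(6, dtype=np.float32).reshape(shape))

# Corresponds to copy_on_get=True: later writes do not affect the snapshot.
snapshot = view.copy()
# Corresponds to copy_on_get=False: a second view onto the same memory.
alias = np.frombuffer(shared.get_obj(), dtype=np.float32).reshape(shape)

# Overwrite one element through the first view.
view[0, 0] = 42.0
print(snapshot[0, 0], alias[0, 0])
```

Returning a copy is the safer default when another process may refill the buffer while the caller is still reading; the view avoids the copy cost when the caller can guarantee that does not happen.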

ding/data/storage/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+from .storage import Storage
+from .file import FileStorage, FileModelStorage

ding/data/storage/file.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from typing import Any
+from ding.data.storage import Storage
+import pickle
+
+from ding.utils.file_helper import read_file, save_file
+
+
+class FileStorage(Storage):
+
+    def save(self, data: Any) -> None:
+        with open(self.path, "wb") as f:
+            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    def load(self) -> Any:
+        with open(self.path, "rb") as f:
+            return pickle.load(f)
+
+
+class FileModelStorage(Storage):
+
+    def save(self, state_dict: object) -> None:
+        save_file(self.path, state_dict)
+
+    def load(self) -> object:
+        return read_file(self.path)
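Note on `FileStorage`: it is a thin wrapper around a pickle round trip to a file path. A standalone sketch of the same round trip follows; the function names `save_pickle` and `load_pickle` are hypothetical and stand in for the `save` and `load` methods.

```python
import os
import pickle
import tempfile
from typing import Any


def save_pickle(file_path: str, data: Any) -> None:
    """Serialize data to file_path using the highest available pickle protocol."""
    with open(file_path, "wb") as f:
        pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)


def load_pickle(file_path: str) -> Any:
    """Read back an object written by save_pickle."""
    with open(file_path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    file_path = os.path.join(tempfile.mkdtemp(), "transition.pkl")
    save_pickle(file_path, {"obs": [0.1, 0.2], "reward": 1.0})
    print(load_pickle(file_path))
```

Because `pickle.load` can execute arbitrary code, this kind of storage is only appropriate for files written by trusted processes, such as the collectors and learners of the same job.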

ding/data/storage/storage.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class Storage(ABC):
+
+    def __init__(self, path: str) -> None:
+        self.path = path
+
+    @abstractmethod
+    def save(self, data: Any) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    def load(self) -> Any:
+        raise NotImplementedError
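Note on `Storage`: because both methods are abstract, the base class cannot be instantiated, and every backend must supply `save` and `load`. The sketch below copies the interface and adds a hypothetical in-memory backend, `MemoryStorage`, which is not part of the commit and exists only to show how a subclass plugs in.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class Storage(ABC):
    """Same interface as in the diff: a path plus abstract save and load."""

    def __init__(self, path: str) -> None:
        self.path = path

    @abstractmethod
    def save(self, data: Any) -> None:
        raise NotImplementedError

    @abstractmethod
    def load(self) -> Any:
        raise NotImplementedError


class MemoryStorage(Storage):
    """Hypothetical backend that keeps objects in a process-local dict keyed by path."""

    _blobs: Dict[str, Any] = {}

    def save(self, data: Any) -> None:
        MemoryStorage._blobs[self.path] = data

    def load(self) -> Any:
        return MemoryStorage._blobs[self.path]


if __name__ == "__main__":
    MemoryStorage("model/latest").save({"step": 1})
    print(MemoryStorage("model/latest").load())
```

Unlike the file-backed classes, such a backend is not visible to other processes, so it would only suit single-process tests.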
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
+import tempfile
+import pytest
+import os
+from os import path
+from ding.data.storage import FileStorage
+
+
+@pytest.mark.unittest
+def test_file_storage():
+    path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
+    try:
+        storage = FileStorage(path=path_)
+        storage.save("test")
+        content = storage.load()
+        assert content == "test"
+    finally:
+        if path.exists(path_):
+            os.remove(path_)
