|
| 1 | +from typing import Any, Optional, Union, Tuple, Dict |
| 2 | +from multiprocessing import Array |
| 3 | +import ctypes |
| 4 | +import numpy as np |
| 5 | +import torch |
| 6 | + |
| 7 | +_NTYPE_TO_CTYPE = { |
| 8 | + np.bool_: ctypes.c_bool, |
| 9 | + np.uint8: ctypes.c_uint8, |
| 10 | + np.uint16: ctypes.c_uint16, |
| 11 | + np.uint32: ctypes.c_uint32, |
| 12 | + np.uint64: ctypes.c_uint64, |
| 13 | + np.int8: ctypes.c_int8, |
| 14 | + np.int16: ctypes.c_int16, |
| 15 | + np.int32: ctypes.c_int32, |
| 16 | + np.int64: ctypes.c_int64, |
| 17 | + np.float32: ctypes.c_float, |
| 18 | + np.float64: ctypes.c_double, |
| 19 | +} |
| 20 | + |
| 21 | + |
| 22 | +class ShmBuffer(): |
| 23 | + """ |
| 24 | + Overview: |
| 25 | + Shared memory buffer to store numpy array. |
| 26 | + """ |
| 27 | + |
| 28 | + def __init__( |
| 29 | + self, |
| 30 | + dtype: Union[type, np.dtype], |
| 31 | + shape: Tuple[int], |
| 32 | + copy_on_get: bool = True, |
| 33 | + ctype: Optional[type] = None |
| 34 | + ) -> None: |
| 35 | + """ |
| 36 | + Overview: |
| 37 | + Initialize the buffer. |
| 38 | + Arguments: |
| 39 | + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. |
| 40 | + - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer. |
| 41 | + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. |
| 42 | + - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor. |
| 43 | + """ |
| 44 | + if isinstance(dtype, np.dtype): # it is type of gym.spaces.dtype |
| 45 | + dtype = dtype.type |
| 46 | + self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape))) |
| 47 | + self.dtype = dtype |
| 48 | + self.shape = shape |
| 49 | + self.copy_on_get = copy_on_get |
| 50 | + self.ctype = ctype |
| 51 | + |
| 52 | + def fill(self, src_arr: np.ndarray) -> None: |
| 53 | + """ |
| 54 | + Overview: |
| 55 | + Fill the shared memory buffer with a numpy array. (Replace the original one.) |
| 56 | + Arguments: |
| 57 | + - src_arr (:obj:`np.ndarray`): array to fill the buffer. |
| 58 | + """ |
| 59 | + assert isinstance(src_arr, np.ndarray), type(src_arr) |
| 60 | + # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten |
| 61 | + # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten |
| 62 | + # so we reshape dst_arr rather than flatten src_arr |
| 63 | + dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) |
| 64 | + np.copyto(dst_arr, src_arr) |
| 65 | + |
| 66 | + def get(self) -> np.ndarray: |
| 67 | + """ |
| 68 | + Overview: |
| 69 | + Get the array stored in the buffer. |
| 70 | + Return: |
| 71 | + - data (:obj:`np.ndarray`): A copy of the data stored in the buffer. |
| 72 | + """ |
| 73 | + data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape) |
| 74 | + if self.copy_on_get: |
| 75 | + data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory |
| 76 | + if self.ctype is torch.Tensor: |
| 77 | + data = torch.from_numpy(data) |
| 78 | + return data |
| 79 | + |
| 80 | + |
| 81 | +class ShmBufferContainer(object): |
| 82 | + """ |
| 83 | + Overview: |
| 84 | + Support multiple shared memory buffers. Each key-value is name-buffer. |
| 85 | + """ |
| 86 | + |
| 87 | + def __init__( |
| 88 | + self, |
| 89 | + dtype: Union[Dict[Any, type], type, np.dtype], |
| 90 | + shape: Union[Dict[Any, tuple], tuple], |
| 91 | + copy_on_get: bool = True |
| 92 | + ) -> None: |
| 93 | + """ |
| 94 | + Overview: |
| 95 | + Initialize the buffer container. |
| 96 | + Arguments: |
| 97 | + - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer. |
| 98 | + - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \ |
| 99 | + multiple buffers; If `tuple`, use single buffer. |
| 100 | + - copy_on_get (:obj:`bool`): Whether to copy data when calling get method. |
| 101 | + """ |
| 102 | + if isinstance(shape, dict): |
| 103 | + self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()} |
| 104 | + elif isinstance(shape, (tuple, list)): |
| 105 | + self._data = ShmBuffer(dtype, shape, copy_on_get) |
| 106 | + else: |
| 107 | + raise RuntimeError("not support shape: {}".format(shape)) |
| 108 | + self._shape = shape |
| 109 | + |
| 110 | + def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None: |
| 111 | + """ |
| 112 | + Overview: |
| 113 | + Fill the one or many shared memory buffer. |
| 114 | + Arguments: |
| 115 | + - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer. |
| 116 | + """ |
| 117 | + if isinstance(self._shape, dict): |
| 118 | + for k in self._shape.keys(): |
| 119 | + self._data[k].fill(src_arr[k]) |
| 120 | + elif isinstance(self._shape, (tuple, list)): |
| 121 | + self._data.fill(src_arr) |
| 122 | + |
| 123 | + def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]: |
| 124 | + """ |
| 125 | + Overview: |
| 126 | + Get the one or many arrays stored in the buffer. |
| 127 | + Return: |
| 128 | + - data (:obj:`np.ndarray`): The array(s) stored in the buffer. |
| 129 | + """ |
| 130 | + if isinstance(self._shape, dict): |
| 131 | + return {k: self._data[k].get() for k in self._shape.keys()} |
| 132 | + elif isinstance(self._shape, (tuple, list)): |
| 133 | + return self._data.get() |
0 commit comments