|
1 | 1 | import os |
2 | | -from typing import Optional, Callable |
| 2 | +from typing import Optional, Callable, Union, List, Tuple |
3 | 3 |
|
4 | 4 | import psutil |
5 | 5 | import torch |
@@ -90,18 +90,18 @@ def calculate_update_per_collect(cfg, new_data): |
90 | 90 |
|
91 | 91 | return update_per_collect |
92 | 92 |
|
93 | | -def initialize_zeros_batch(observation_shape, batch_size, device): |
| 93 | +def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str) -> torch.Tensor: |
94 | 94 | """ |
95 | 95 | Overview: |
96 | 96 | Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input. |
97 | 97 | Arguments: |
98 | | - - observation_shape (:obj:`Union[int, List[int]]`): The shape of the observation tensor. |
| 98 | + - observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor. |
99 | 99 | - batch_size (:obj:`int`): The batch size. |
100 | 100 | - device (:obj:`str`): The device to store the tensor. |
101 | 101 | Returns: |
102 | 102 | - zeros (:obj:`torch.Tensor`): The zeros tensor. |
103 | 103 | """ |
104 | | - if isinstance(observation_shape, list): |
| 104 | + if isinstance(observation_shape, (list, tuple)): |
105 | 105 | shape = [batch_size, *observation_shape] |
106 | 106 | elif isinstance(observation_shape, int): |
107 | 107 | shape = [batch_size, observation_shape] |
|
0 commit comments