Skip to content

Commit f803504

Browse files
authored
fix(pu): fix obs_shape tuple bug in initialize_zeros_batch (#327)
1 parent 06aee46 commit f803504

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

lzero/entry/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import os
2-
from typing import Optional, Callable
2+
from typing import Optional, Callable, Union, List, Tuple
33

44
import psutil
55
import torch
@@ -90,18 +90,18 @@ def calculate_update_per_collect(cfg, new_data):
9090

9191
return update_per_collect
9292

93-
def initialize_zeros_batch(observation_shape, batch_size, device):
93+
def initialize_zeros_batch(observation_shape: Union[int, List[int], Tuple[int]], batch_size: int, device: str) -> torch.Tensor:
9494
"""
9595
Overview:
9696
Initialize a zeros tensor for batch observations based on the shape. This function is used to initialize the UniZero model input.
9797
Arguments:
98-
- observation_shape (:obj:`Union[int, List[int]]`): The shape of the observation tensor.
98+
- observation_shape (:obj:`Union[int, List[int], Tuple[int]]`): The shape of the observation tensor.
9999
- batch_size (:obj:`int`): The batch size.
100100
- device (:obj:`str`): The device to store the tensor.
101101
Returns:
102102
- zeros (:obj:`torch.Tensor`): The zeros tensor.
103103
"""
104-
if isinstance(observation_shape, list):
104+
if isinstance(observation_shape, (list, tuple)):
105105
shape = [batch_size, *observation_shape]
106106
elif isinstance(observation_shape, int):
107107
shape = [batch_size, observation_shape]

0 commit comments

Comments
 (0)