|
16 | 16 |
|
17 | 17 |
|
18 | 18 | def resolve_nn_activation(act_name: str) -> torch.nn.Module: |
19 | | - """Resolves the activation function from the name. |
| 19 | + """Resolve the activation function from the name. |
20 | 20 |
|
21 | 21 | Args: |
22 | 22 | act_name: The name of the activation function. |
@@ -50,7 +50,7 @@ def resolve_nn_activation(act_name: str) -> torch.nn.Module: |
50 | 50 |
|
51 | 51 |
|
52 | 52 | def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer: |
53 | | - """Resolves the optimizer from the name. |
| 53 | + """Resolve the optimizer from the name. |
54 | 54 |
|
55 | 55 | Args: |
56 | 56 | optimizer_name: The name of the optimizer. |
@@ -78,9 +78,9 @@ def resolve_optimizer(optimizer_name: str) -> torch.optim.Optimizer: |
78 | 78 | def split_and_pad_trajectories( |
79 | 79 | tensor: torch.Tensor | TensorDict, dones: torch.Tensor |
80 | 80 | ) -> tuple[torch.Tensor | TensorDict, torch.Tensor]: |
81 | | - """Splits trajectories at done indices. |
| 81 | + """Split trajectories at done indices. |
82 | 82 |
|
83 | | - Splits trajectories, concatenates them and pads with zeros up to the length of the longest trajectory. Returns masks |
| 83 | + Split trajectories, concatenate them and pad with zeros up to the length of the longest trajectory. Return masks |
84 | 84 | corresponding to valid parts of the trajectories. |
85 | 85 |
|
86 | 86 | Example: |
@@ -133,7 +133,7 @@ def split_and_pad_trajectories( |
133 | 133 |
|
134 | 134 |
|
135 | 135 | def unpad_trajectories(trajectories: torch.Tensor | TensorDict, masks: torch.Tensor) -> torch.Tensor | TensorDict: |
136 | | - """Does the inverse operation of `split_and_pad_trajectories()`.""" |
| 136 | + """Do the inverse operation of `split_and_pad_trajectories()`.""" |
137 | 137 | # Need to transpose before and after the masking to have proper reshaping |
138 | 138 | return ( |
139 | 139 | trajectories.transpose(1, 0)[masks.transpose(1, 0)] |
@@ -171,7 +171,7 @@ def store_code_state(logdir: str, repositories: list[str]) -> list[str]: |
171 | 171 |
|
172 | 172 |
|
173 | 173 | def string_to_callable(name: str) -> Callable: |
174 | | - """Resolves the module and function names to return the function. |
| 174 | + """Resolve the module and function names to return the function. |
175 | 175 |
|
176 | 176 | Args: |
177 | 177 | name: The function name. The format should be 'module:attribute_name'. |
@@ -203,7 +203,7 @@ def string_to_callable(name: str) -> Callable: |
203 | 203 | def resolve_obs_groups( |
204 | 204 | obs: TensorDict, obs_groups: dict[str, list[str]], default_sets: list[str] |
205 | 205 | ) -> dict[str, list[str]]: |
206 | | - """Validates the observation configuration and defaults missing observation sets. |
| 206 | + """Validate the observation configuration and defaults missing observation sets. |
207 | 207 |
|
208 | 208 | The input is an observation dictionary `obs` containing observation groups and a configuration dictionary |
209 | 209 | `obs_groups` where the keys are the observation sets and the values are lists of observation groups. |
|
0 commit comments