 import math
-from typing import Dict, Tuple

 import gymnasium as gym
 import torch
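Note: the hunks below swap `typing.Tuple` and `typing.Dict` for the builtin generics `tuple` and `dict` (PEP 585, available from Python 3.9 onward), which is why the `from typing import Dict, Tuple` line above is dropped. A minimal, self-contained illustration of the equivalent annotations (the function and dictionary key here are hypothetical, not taken from this file):

# Builtin generics (Python >= 3.9) replace typing.Tuple / typing.Dict:
def summarize(batch: dict[str, list[float]]) -> tuple[float, int]:
    rewards = batch["rewards"]  # hypothetical key, for illustration only
    return sum(rewards) / len(rewards), len(rewards)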
@@ -43,7 +42,7 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
             layer_init(torch.nn.Linear(64, envs.single_action_space.n), std=0.01, ortho_init=ortho_init),
         )

-    def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]:
         logits = self.actor(x)
         distribution = Categorical(logits=logits)
         if action is None:
@@ -58,12 +57,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor:
     def get_value(self, x: Tensor) -> Tensor:
         return self.critic(x)

-    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         action, log_prob, entropy = self.get_action(x, action)
         value = self.get_value(x)
         return action, log_prob, entropy, value

-    def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         return self.get_action_and_value(x, action)

     @torch.no_grad()
@@ -77,7 +76,7 @@ def estimate_returns_and_advantages(
         num_steps: int,
         gamma: float,
         gae_lambda: float,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         next_value = self.get_value(next_obs).reshape(1, -1)
         advantages = torch.zeros_like(rewards)
         lastgaelam = 0
@@ -143,7 +142,7 @@ def __init__(
         self.avg_value_loss = MeanMetric(**torchmetrics_kwargs)
         self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs)

-    def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]:
         logits = self.actor(x)
         distribution = Categorical(logits=logits)
         if action is None:
@@ -158,12 +157,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor:
     def get_value(self, x: Tensor) -> Tensor:
         return self.critic(x)

-    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         action, log_prob, entropy = self.get_action(x, action)
         value = self.get_value(x)
         return action, log_prob, entropy, value

-    def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         return self.get_action_and_value(x, action)

     @torch.no_grad()
@@ -177,7 +176,7 @@ def estimate_returns_and_advantages(
         num_steps: int,
         gamma: float,
         gae_lambda: float,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         next_value = self.get_value(next_obs).reshape(1, -1)
         advantages = torch.zeros_like(rewards)
         lastgaelam = 0
@@ -193,7 +192,7 @@ def estimate_returns_and_advantages(
         returns = advantages + values
         return returns, advantages

-    def training_step(self, batch: Dict[str, Tensor]):
+    def training_step(self, batch: dict[str, Tensor]):
         # Get actions and values given the current observations
         _, newlogprob, entropy, newvalue = self(batch["obs"], batch["actions"].long())
         logratio = newlogprob - batch["logprobs"]
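For context on the last hunk: `logratio` is the log of the PPO importance ratio, which normally feeds a clipped surrogate policy loss. A minimal, self-contained sketch of that computation follows; the function name, the `clip_coef` default, and the mean aggregation are assumptions, not taken from this commit.

import torch
from torch import Tensor

def ppo_policy_loss(newlogprob: Tensor, oldlogprob: Tensor, advantages: Tensor, clip_coef: float = 0.2) -> Tensor:
    # Clipped PPO surrogate objective; clip_coef=0.2 is an assumed default.
    logratio = newlogprob - oldlogprob
    ratio = logratio.exp()
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    return torch.max(pg_loss1, pg_loss2).mean()

# Usage with dummy tensors:
loss = ppo_policy_loss(torch.randn(8), torch.randn(8), torch.randn(8))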
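The `estimate_returns_and_advantages` hunks only show the setup (`next_value`, `advantages`, `lastgaelam = 0`) and the final `returns = advantages + values`. The standard GAE recursion they bracket looks roughly like the sketch below; the `dones`/`next_done` arguments and the termination handling are assumptions, since the loop body is not part of this diff.

import torch
from torch import Tensor

def estimate_gae(
    rewards: Tensor,     # shape (num_steps, num_envs)
    values: Tensor,      # shape (num_steps, num_envs)
    dones: Tensor,       # shape (num_steps, num_envs); assumed argument
    next_value: Tensor,  # shape (1, num_envs)
    next_done: Tensor,   # shape (num_envs,); assumed argument
    num_steps: int,
    gamma: float,
    gae_lambda: float,
) -> tuple[Tensor, Tensor]:
    advantages = torch.zeros_like(rewards)
    lastgaelam = 0.0
    for t in reversed(range(num_steps)):
        # Mask out bootstrapped values after terminal states
        if t == num_steps - 1:
            nextnonterminal = 1.0 - next_done
            nextvalues = next_value
        else:
            nextnonterminal = 1.0 - dones[t + 1]
            nextvalues = values[t + 1]
        # TD error, then the exponentially weighted GAE accumulator
        delta = rewards[t] + gamma * nextvalues * nextnonterminal - values[t]
        advantages[t] = lastgaelam = delta + gamma * gae_lambda * nextnonterminal * lastgaelam
    returns = advantages + values
    return returns, advantages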