@@ -28,45 +28,35 @@ def to_hashable(x):
2828 raise TypeError (f"Unsupported type: { type (x )} " )
2929
3030
def compute_step_discounted_returns(batch, gamma: float):
    """Compute per-step discounted returns for each trajectory in a batch.

    Returns are accumulated backwards within each trajectory. A step whose
    ``extract_match`` flag is False breaks the chain: that step keeps only
    its own reward, and no future reward is propagated past it to earlier
    steps.

    Args:
        batch: DataProto-like object. Its ``non_tensor_batch`` must hold
            'rewards' (per-step float array), 'traj_uid' (per-step
            trajectory ids) and 'extract_match' (per-step booleans); its
            ``batch['input_ids']`` fixes the device of the output.
        gamma: Discount factor applied to future rewards.

    Returns:
        torch.FloatTensor with one return per step, in the original batch
        order, on the same device as ``batch.batch['input_ids']``.
    """
    rewards = batch.non_tensor_batch['rewards'].astype(np.float32)
    traj_uids = batch.non_tensor_batch['traj_uid']
    extract_matches = batch.non_tensor_batch['extract_match']

    all_returns = np.zeros_like(rewards)
    for uid in np.unique(traj_uids):
        # Steps of this trajectory, in their original batch positions
        # (assumes batch order == step order within a trajectory —
        # same assumption as the original implementation).
        traj_indices = np.where(traj_uids == uid)[0]
        traj_rewards = rewards[traj_indices]
        traj_matches = extract_matches[traj_indices]

        traj_returns = np.zeros_like(traj_rewards)
        running_return = 0.0
        # Walk back to front; a False extract_match breaks the interval,
        # so no future reward is accumulated across it.
        for t in reversed(range(len(traj_rewards))):
            if traj_matches[t]:
                running_return = traj_rewards[t] + gamma * running_return
                traj_returns[t] = running_return
            else:
                running_return = 0.0
                traj_returns[t] = traj_rewards[t]

        # Scatter straight back to the original batch positions. This is
        # O(n) overall, replacing the old per-step re-search
        # (np.where inside a loop over every batch index) that was O(n^2).
        all_returns[traj_indices] = traj_returns

    return torch.tensor(all_returns, dtype=torch.float32,
                        device=batch.batch['input_ids'].device)
7262
0 commit comments