Skip to content

Commit 93afb00

Browse files
authored
Update core_uis1.py
1 parent 93f413c commit 93afb00

File tree

1 file changed

+11
-21
lines changed

1 file changed

+11
-21
lines changed

UI-S1/uis1/core_uis1.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,45 +28,35 @@ def to_hashable(x):
2828
raise TypeError(f"Unsupported type: {type(x)}")
2929

3030

31-
def compute_step_discounted_returns(batch, gamma):
    """Compute per-step discounted returns for each trajectory in *batch*.

    Returns are accumulated backwards within each trajectory. Whenever a
    step's ``extract_match`` flag is falsy, the discount chain is broken:
    that step's return is just its own reward, and no future reward
    propagates past it to earlier steps.

    Args:
        batch: batch object exposing ``non_tensor_batch`` (a mapping with
            'rewards', 'traj_uid' and 'extract_match' numpy arrays, all of
            equal length) and ``batch`` (a mapping with an 'input_ids'
            tensor, used here only for its device).
            NOTE(review): presumably a verl ``DataProto`` — the type hint
            was dropped in this commit; confirm against callers.
        gamma: discount factor applied to future rewards.

    Returns:
        ``torch.FloatTensor`` of shape ``(batch_size,)`` holding the
        per-step returns in the original batch order, placed on the same
        device as ``input_ids``.
    """
    rewards = batch.non_tensor_batch['rewards'].astype(np.float32)
    traj_uids = batch.non_tensor_batch['traj_uid']
    extract_matches = batch.non_tensor_batch['extract_match']

    all_returns = np.zeros_like(rewards)

    for uid in np.unique(traj_uids):
        traj_indices = np.where(traj_uids == uid)[0]
        traj_rewards = rewards[traj_indices]
        traj_matches = extract_matches[traj_indices]

        traj_returns = np.zeros_like(traj_rewards)
        running_return = 0.0
        # Walk from the last step to the first; a False extract_match
        # resets the running return so earlier steps do not accumulate
        # any reward from beyond the break point.
        for t in reversed(range(len(traj_rewards))):
            if traj_matches[t]:
                running_return = traj_rewards[t] + gamma * running_return
                traj_returns[t] = running_return
            else:
                running_return = 0.0
                traj_returns[t] = traj_rewards[t]

        # Scatter straight back into batch order. traj_indices already
        # holds each step's batch position, so this replaces the previous
        # per-element np.where lookup (O(n^2) overall) with one O(n) pass.
        all_returns[traj_indices] = traj_returns

    return torch.tensor(all_returns, dtype=torch.float32,
                        device=batch.batch['input_ids'].device)

0 commit comments

Comments
 (0)