@@ -497,30 +497,17 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
497
497
0 , len (batch [BufferKey .CRITIC_MEMORY ]), self .policy .sequence_length
498
498
)
499
499
]
500
- offset = 1 if self .policy .sequence_length > 1 else 0
501
- next_value_memories_list = [
502
- ModelUtils .list_to_tensor (
503
- batch [BufferKey .CRITIC_MEMORY ][i ]
504
- ) # only pass value part of memory to target network
505
- for i in range (
506
- offset , len (batch [BufferKey .CRITIC_MEMORY ]), self .policy .sequence_length
507
- )
508
- ]
509
500
510
501
if len (memories_list ) > 0 :
511
502
memories = torch .stack (memories_list ).unsqueeze (0 )
512
503
value_memories = torch .stack (value_memories_list ).unsqueeze (0 )
513
- next_value_memories = torch .stack (next_value_memories_list ).unsqueeze (0 )
514
504
else :
515
505
memories = None
516
506
value_memories = None
517
- next_value_memories = None
518
507
519
508
# Q and V network memories are 0'ed out, since we don't have them during inference.
520
509
q_memories = (
521
- torch .zeros_like (next_value_memories )
522
- if next_value_memories is not None
523
- else None
510
+ torch .zeros_like (value_memories ) if value_memories is not None else None
524
511
)
525
512
526
513
# Copy normalizers from policy
@@ -568,6 +555,18 @@ def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
568
555
q1_stream , q2_stream = q1_out , q2_out
569
556
570
557
with torch .no_grad ():
558
+ # Since we didn't record the next value memories, evaluate one step in the critic to
559
+ # get them.
560
+ if value_memories is not None :
561
+ # Get the first observation in each sequence
562
+ just_first_obs = [
563
+ _obs [:: self .policy .sequence_length ] for _obs in current_obs
564
+ ]
565
+ _ , next_value_memories = self ._critic .critic_pass (
566
+ just_first_obs , value_memories , sequence_length = 1
567
+ )
568
+ else :
569
+ next_value_memories = None
571
570
target_values , _ = self .target_network (
572
571
next_obs ,
573
572
memories = next_value_memories ,
0 commit comments