@@ -734,14 +734,25 @@ class HindsightExperienceTransformer(DataTransformer):
     of the current timestep.
     The exact field names can be provided via arguments to the class ``__init__``.
 
+    NOTE: The HindsightExperienceTransformer must be applied before any transformer
+    that changes the reward or achieved_goal fields, e.g. an observation normalizer
+    or a reward clipper.
+    See `documentation <../../docs/notes/knowledge_base.rst#datatransformers>`_ for details.
+
     To use this class, add it to any existing data transformers, e.g. use this config if
     ``ObservationNormalizer`` is an existing data transformer:
 
     .. code-block:: python
 
-        ReplayBuffer.keep_episodic_info=True
-        HindsightExperienceTransformer.her_proportion=0.8
-        TrainerConfig.data_transformer_ctor=[@HindsightExperienceTransformer, @ObservationNormalizer]
+        alf.config('ReplayBuffer', keep_episodic_info=True)
+        alf.config(
+            'HindsightExperienceTransformer',
+            her_proportion=0.8
+        )
+        alf.config(
+            'TrainerConfig',
+            data_transformer_ctor=[
+                HindsightExperienceTransformer, ObservationNormalizer
+            ])
 
     See unit test for more details on behavior.
     """
@@ -818,9 +829,10 @@ def transform_experience(self, experience: Experience):
         # relabel only these sampled indices
         her_cond = torch.rand(batch_size) < her_proportion
         (her_indices, ) = torch.where(her_cond)
+        has_her = torch.any(her_cond)
 
-        last_step_pos = start_pos[her_indices] + batch_length - 1
-        last_env_ids = env_ids[her_indices]
+        last_step_pos = start_pos + batch_length - 1
+        last_env_ids = env_ids
         # Get x, y indices of LAST steps
         dist = buffer.steps_to_episode_end(last_step_pos, last_env_ids)
         if alf.summary.should_record_summaries():
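The change above drops the ``her_indices`` subsetting so episode-end distances are computed for every sampled trajectory, not just the HER-selected ones. A standalone sketch of the selection logic with stand-in tensors (no ALF dependency):

.. code-block:: python

    import torch

    batch_size, batch_length, her_proportion = 6, 4, 0.8
    start_pos = torch.randint(0, 100, (batch_size, ))  # stand-in buffer positions

    # Bernoulli(her_proportion) mask over the sampled batch; only these rows
    # get their goals relabeled.
    her_cond = torch.rand(batch_size) < her_proportion
    (her_indices, ) = torch.where(her_cond)
    has_her = torch.any(her_cond)

    # Episode-end positions are now computed for the *whole* batch, so
    # downstream tensors like ``future_dist`` are defined for every row and
    # can be summarized or stored as batch info.
    last_step_pos = start_pos + batch_length - 1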
@@ -829,22 +841,24 @@ def transform_experience(self, experience: Experience):
                 torch.mean(dist.type(torch.float32)))
 
         # get random future state
-        future_idx = last_step_pos + (torch.rand(*dist.shape) *
-                                      (dist + 1)).to(torch.int64)
+        future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(
+            torch.int64)
+        future_idx = last_step_pos + future_dist
         future_ag = buffer.get_field(self._achieved_goal_field,
                                      last_env_ids, future_idx).unsqueeze(1)
 
         # relabel desired goal
         result_desired_goal = alf.nest.get_field(result,
                                                  self._desired_goal_field)
-        relabed_goal = result_desired_goal.clone()
+        relabeled_goal = result_desired_goal.clone()
         her_batch_index_tuple = (her_indices.unsqueeze(1),
                                  torch.arange(batch_length).unsqueeze(0))
-        relabed_goal[her_batch_index_tuple] = future_ag
+        if has_her:
+            relabeled_goal[her_batch_index_tuple] = future_ag[her_indices]
 
         # recompute rewards
         result_ag = alf.nest.get_field(result, self._achieved_goal_field)
-        relabeled_rewards = self._reward_fn(result_ag, relabed_goal)
+        relabeled_rewards = self._reward_fn(result_ag, relabeled_goal)
 
         non_her_or_fst = ~her_cond.unsqueeze(1) & (result.step_type !=
                                                    StepType.FIRST)
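The refactor above names the sampled offset ``future_dist`` so it can be reused for summaries and batch info. Numerically, ``floor(U * (dist + 1))`` with ``U ~ Uniform[0, 1)`` is uniform over ``{0, ..., dist}``, so ``future_idx`` is a uniformly random step between the segment's last step and the episode end. A self-contained check with stand-in values:

.. code-block:: python

    import torch

    # `dist` is the number of steps from a segment's last step to its
    # episode end; `last_step_pos` are stand-in buffer positions.
    dist = torch.tensor([0, 3, 7])
    last_step_pos = torch.tensor([9, 19, 29])

    # floor(U * (dist + 1)) lands uniformly in {0, 1, ..., dist}.
    future_dist = (torch.rand(*dist.shape) * (dist + 1)).to(torch.int64)
    future_idx = last_step_pos + future_dist
    assert torch.all((future_idx >= last_step_pos)
                     & (future_idx <= last_step_pos + dist))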
@@ -874,21 +888,28 @@ def transform_experience(self, experience: Experience):
             alf.summary.scalar(
                 "replayer/" + buffer._name + ".reward_mean_before_relabel",
                 torch.mean(result.reward[her_indices][:-1]))
-            alf.summary.scalar(
-                "replayer/" + buffer._name + ".reward_mean_after_relabel",
-                torch.mean(relabeled_rewards[her_indices][:-1]))
+            if has_her:
+                alf.summary.scalar(
+                    "replayer/" + buffer._name + ".reward_mean_after_relabel",
+                    torch.mean(relabeled_rewards[her_indices][:-1]))
+            alf.summary.scalar("replayer/" + buffer._name + ".future_distance",
+                               torch.mean(future_dist.float()))
 
         result = alf.nest.transform_nest(
-            result, self._desired_goal_field, lambda _: relabed_goal)
-
+            result, self._desired_goal_field, lambda _: relabeled_goal)
         result = result.update_time_step_field('reward', relabeled_rewards)
-
+        info = info._replace(her=her_cond, future_distance=future_dist)
         if alf.get_default_device() != buffer.device:
             for f in accessed_fields:
                 result = alf.nest.transform_nest(
                     result, f, lambda t: convert_device(t))
-        result = alf.nest.transform_nest(
-            result, "batch_info.replay_buffer", lambda _: buffer)
+            info = convert_device(info)
+        info = info._replace(
+            her=info.her.unsqueeze(1).expand(result.reward.shape[:2]),
+            future_distance=info.future_distance.unsqueeze(1).expand(
+                result.reward.shape[:2]),
+            replay_buffer=buffer)
+        result = alf.data_structures.add_batch_info(result, info)
         return result
 
 
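The new bookkeeping stores per-sample ``her`` and ``future_distance`` fields on the batch info, expanding them from shape ``[B]`` to the ``[B, T]`` shape of the batched experience before reattaching the replay buffer. A plain-torch sketch of that expansion step (shapes are illustrative):

.. code-block:: python

    import torch

    # Per-sample scalars of shape [B] are broadcast to [B, T] so every
    # timestep of a trajectory carries its sample's HER flag and sampled
    # future distance.
    B, T = 4, 3
    her_cond = torch.rand(B) < 0.8
    future_dist = torch.randint(0, 10, (B, ))
    her_bt = her_cond.unsqueeze(1).expand(B, T)        # [B] -> [B, T]
    future_bt = future_dist.unsqueeze(1).expand(B, T)  # [B] -> [B, T]
    assert her_bt.shape == (B, T) and future_bt.shape == (B, T)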