@@ -32,7 +32,7 @@ def __init__(
         # This model is continuous only
         self.continuous_action = True

-        # PR: register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
+        # Register the per-dim scale and bias so we can rescale [-1,1]→[low,high].
         action_low = torch.as_tensor(action_low, dtype=torch.float32)
         action_high = torch.as_tensor(action_high, dtype=torch.float32)
         self.register_buffer(
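The register_buffer call above is cut off at the hunk boundary, so the exact buffer contents are not shown here. As a hedged sketch of the usual SAC convention (an assumption, not code from this commit), the per-dimension scale and bias that map tanh output in [-1, 1] onto [action_low, action_high] are the half-range and midpoint:

import torch

# Hypothetical sketch of the standard rescaling buffers; names mirror the diff,
# but the contents of the truncated register_buffer call are an assumption.
action_low = torch.as_tensor([-2.0, -1.0], dtype=torch.float32)   # example bounds
action_high = torch.as_tensor([2.0, 1.0], dtype=torch.float32)

action_scale = (action_high - action_low) / 2.0  # half-range per action dimension
action_bias = (action_high + action_low) / 2.0   # midpoint per action dimension

# tanh output of -1 maps to action_low, +1 maps to action_high:
assert torch.allclose(-1.0 * action_scale + action_bias, action_low)
assert torch.allclose(1.0 * action_scale + action_bias, action_high)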
@@ -67,42 +67,75 @@ def __init__(
         self.hidden_sizes = feature_extractor_kwargs.get("hidden_sizes", [256, 256])
         self.activation = feature_extractor_kwargs.get("activation", "relu")

-        # Shared feature extractor for policy
-        self.feature_extractor, out_dim = make_feature_extractor(
+        # Policy feature extractor and head
+        self.policy_feature_extractor, policy_feat_dim = make_feature_extractor(
             **feature_extractor_kwargs
         )
-
-        # Policy network outputs mean and log_std
-        # CHANGE: Create separate policy network (actor) similar to CleanRL
-        self.policy_net = make_policy_head(
-            in_size=self.obs_size,
+
+        # Policy head: just the final output layer
+        self.policy_head = make_policy_head(
+            in_size=policy_feat_dim,
             out_size=self.action_size * 2,  # mean and log_std
-            **head_kwargs
+            hidden_sizes=[],  # No hidden layers, just the final linear layer
+            activation=head_kwargs["activation"]
         )

-        # Twin Q-networks
-        # — live Q-nets —
-        self.q_net1 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        # Create policy_net for backward compatibility
+        self.policy_net = nn.Sequential(self.policy_feature_extractor, self.policy_head)
+
+        # Q-networks: feature extractors + heads
+        q_feature_extractor_kwargs = feature_extractor_kwargs.copy()
+        q_feature_extractor_kwargs["obs_shape"] = self.obs_size + self.action_size
+
+        # Q-network 1
+        self.q_feature_extractor1, q_feat_dim = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.q_head1 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just the final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.q_net2 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.q_net1 = nn.Sequential(self.q_feature_extractor1, self.q_head1)
+
+        # Q-network 2
+        self.q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.q_head2 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just the final linear layer
+            activation=head_kwargs["activation"]
         )
+        self.q_net2 = nn.Sequential(self.q_feature_extractor2, self.q_head2)

         # Target Q-networks
-        self.target_q_net1 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.target_q_feature_extractor1, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.target_q_head1 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just the final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.target_q_net1.load_state_dict(self.q_net1.state_dict())
-        self.target_q_net2 = make_q_head(
-            in_size=self.obs_size + self.action_size, **head_kwargs
+        self.target_q_net1 = nn.Sequential(self.target_q_feature_extractor1, self.target_q_head1)
+
+        self.target_q_feature_extractor2, _ = make_feature_extractor(**q_feature_extractor_kwargs)
+        self.target_q_head2 = make_q_head(
+            in_size=q_feat_dim,
+            hidden_sizes=[],  # No hidden layers, just the final linear layer
+            activation=head_kwargs["activation"]
         )
-        self.target_q_net2.load_state_dict(self.q_net2.state_dict())
+        self.target_q_net2 = nn.Sequential(self.target_q_feature_extractor2, self.target_q_head2)
+
+        # Copy weights from live to target networks
+        self.target_q_feature_extractor1.load_state_dict(self.q_feature_extractor1.state_dict())
+        self.target_q_head1.load_state_dict(self.q_head1.state_dict())
+        self.target_q_feature_extractor2.load_state_dict(self.q_feature_extractor2.state_dict())
+        self.target_q_head2.load_state_dict(self.q_head2.state_dict())

         # Freeze target networks
-        for p in self.target_q_net1.parameters():
+        for p in self.target_q_feature_extractor1.parameters():
+            p.requires_grad = False
+        for p in self.target_q_head1.parameters():
+            p.requires_grad = False
+        for p in self.target_q_feature_extractor2.parameters():
             p.requires_grad = False
-        for p in self.target_q_net2.parameters():
+        for p in self.target_q_head2.parameters():
             p.requires_grad = False

         # Create a value function wrapper for compatibility
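Since the target extractors and heads are frozen here, the trainer presumably keeps them in sync with a soft (Polyak) update elsewhere. A hedged sketch of that update, which is not part of this commit:

import torch
import torch.nn as nn

# Hypothetical sketch of the soft (Polyak) target update that typically pairs
# with frozen target networks; the actual update logic is not shown in this diff.
@torch.no_grad()
def soft_update(target: nn.Module, source: nn.Module, tau: float = 0.005) -> None:
    for tp, sp in zip(target.parameters(), source.parameters()):
        tp.data.mul_(1.0 - tau).add_(tau * sp.data)

# Usage (names from the diff): because target_q_net1 is an nn.Sequential wrapping
# target_q_feature_extractor1 and target_q_head1, one call updates both modules.
# soft_update(model.target_q_net1, model.q_net1, tau=0.005)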
@@ -133,7 +166,7 @@ def forward(
         Forward pass for policy sampling.

         Returns:
-            action: torch.Tensor in [-1,1]
+            action: torch.Tensor in rescaled range [action_low, action_high]
             z: raw Gaussian sample before tanh
             mean: Gaussian mean
             log_std: Gaussian log std
@@ -155,7 +188,7 @@ def forward(
         # tanh→[-1,1]
         raw_action = torch.tanh(z)

-        # **HERE** we rescale into [low,high]
+        # Rescale into [action_low, action_high]
         action = raw_action * self.action_scale + self.action_bias

         return action, z, mean, log_std
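The forward pass returns z, mean, and log_std alongside the rescaled action, which is exactly what the usual tanh change-of-variables log-probability needs. A hedged sketch of that computation (an assumption, not code from this diff); note that action_scale enters the correction term while action_bias does not:

import torch

# Hypothetical sketch of the squashed-Gaussian log-probability used by
# CleanRL-style SAC policies; the actual log-prob code is not part of this commit.
def squashed_log_prob(z, mean, log_std, action_scale, eps=1e-6):
    normal = torch.distributions.Normal(mean, log_std.exp())
    log_prob = normal.log_prob(z)
    # Change-of-variables correction for the tanh squash and the affine rescale
    # by action_scale (the additive action_bias does not affect the density).
    log_prob -= torch.log(action_scale * (1.0 - torch.tanh(z).pow(2)) + eps)
    return log_prob.sum(dim=-1, keepdim=True)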