77import torch .nn .functional as F
88import numpy as np
99import os
10+ import random
1011from scipy .signal import savgol_filter
1112from diffusers .schedulers .scheduling_ddpm import DDPMScheduler
1213from internnav .model .encoder .navdp_backbone import *
1516from internnav .configs .model .base_encoders import ModelCfg
1617from internnav .configs .trainer .exp import ExpCfg
1718
18-
19-
2019class NavDPModelConfig (PretrainedConfig ):
2120 model_type = 'navdp'
22-
2321 def __init__ (self , ** kwargs ):
2422 super ().__init__ (** kwargs )
2523 # pass in navdp_exp_cfg
2624 self .model_cfg = kwargs .get ('model_cfg' , None )
2725
28-
2926 @classmethod
3027 def from_dict (cls , config_dict ):
3128 if 'model_cfg' in config_dict :
@@ -89,7 +86,15 @@ def __init__(self, config: NavDPModelConfig):
8986 self .scratch = self .config .model_cfg ['il' ]['scratch' ]
9087 self .finetune = self .config .model_cfg ['il' ]['finetune' ]
9188 self .rgbd_encoder = NavDP_RGBD_Backbone (self .image_size ,self .token_dim ,memory_size = self .memory_size ,finetune = self .finetune ,device = self ._device )
89+ self .pixel_encoder = NavDP_PixelGoal_Backbone (self .image_size ,self .token_dim ,device = self ._device )
90+ self .image_encoder = NavDP_ImageGoal_Backbone (self .image_size ,self .token_dim ,device = self ._device )
9291 self .point_encoder = nn .Linear (3 ,self .token_dim )
92+
93+ if not self .finetune :
94+ for p in self .rgbd_encoder .parameters ():
95+ p .requires_grad = False
96+ self .rgbd_encoder .eval ()
97+
9398 decoder_layer = nn .TransformerDecoderLayer (d_model = self .token_dim ,
9499 nhead = self .attention_heads ,
95100 dim_feedforward = 4 * self .token_dim ,
@@ -101,7 +106,8 @@ def __init__(self, config: NavDPModelConfig):
101106 num_layers = self .temporal_depth )
102107 self .input_embed = nn .Linear (3 ,self .token_dim )
103108
104- self .cond_pos_embed = LearnablePositionalEncoding (self .token_dim , self .memory_size * 16 + 2 )
109+
110+ self .cond_pos_embed = LearnablePositionalEncoding (self .token_dim , self .memory_size * 16 + 4 )
105111 self .out_pos_embed = LearnablePositionalEncoding (self .token_dim , self .predict_size )
106112 self .drop = nn .Dropout (self .dropout )
107113 self .time_emb = SinusoidalPosEmb (self .token_dim )
@@ -114,9 +120,13 @@ def __init__(self, config: NavDPModelConfig):
114120 prediction_type = 'epsilon' )
115121 self .tgt_mask = (torch .triu (torch .ones (self .predict_size , self .predict_size )) == 1 ).transpose (0 , 1 )
116122 self .tgt_mask = self .tgt_mask .float ().masked_fill (self .tgt_mask == 0 , float ('-inf' )).masked_fill (self .tgt_mask == 1 , float (0.0 ))
117- self .cond_critic_mask = torch .zeros ((self .predict_size ,2 + self .memory_size * 16 ))
118- self .cond_critic_mask [:,0 :2 ] = float ('-inf' )
119123 self .tgt_mask = self .tgt_mask .to (self ._device )
124+
125+ self .cond_critic_mask = torch .zeros ((self .predict_size ,4 + self .memory_size * 16 ))
126+ self .cond_critic_mask [:,0 :4 ] = float ('-inf' )
127+
128+ self .pixel_aux_head = nn .Linear (self .token_dim ,3 )
129+ self .image_aux_head = nn .Linear (self .token_dim ,3 )
120130
121131 def to (self , device , * args , ** kwargs ):
122132 # first call the to method of the parent class
@@ -131,10 +141,6 @@ def to(self, device, *args, **kwargs):
131141 return self
132142
133143 def sample_noise (self ,action ):
134- # device = next(self.parameters()).device
135- # if device is None:
136- # device = action.device
137- # action = action.to(self._device)
138144 device = action .device
139145 noise = torch .randn (action .shape , device = device )
140146 timesteps = torch .randint (0 , self .noise_scheduler .config .num_train_timesteps ,(action .shape [0 ],), device = device ).long ()
@@ -146,7 +152,7 @@ def sample_noise(self,action):
146152 def predict_noise (self ,last_actions ,timestep ,goal_embed ,rgbd_embed ):
147153 action_embeds = self .input_embed (last_actions )
148154 time_embeds = self .time_emb (timestep .to (self ._device )).unsqueeze (1 )
149- cond_embedding = torch .cat ([time_embeds ,goal_embed ,rgbd_embed ],dim = 1 ) + self .cond_pos_embed (torch .cat ([time_embeds ,goal_embed ,rgbd_embed ],dim = 1 ))
155+ cond_embedding = torch .cat ([time_embeds ,goal_embed ,goal_embed , goal_embed , rgbd_embed ],dim = 1 ) + self .cond_pos_embed (torch .cat ([time_embeds , goal_embed , goal_embed ,goal_embed ,rgbd_embed ],dim = 1 ))
150156 cond_embedding = cond_embedding .repeat (action_embeds .shape [0 ],1 ,1 )
151157 input_embedding = action_embeds + self .out_pos_embed (action_embeds )
152158 output = self .decoder (tgt = input_embedding ,memory = cond_embedding , tgt_mask = self .tgt_mask .to (self ._device ))
@@ -159,13 +165,13 @@ def predict_critic(self,predict_trajectory,rgbd_embed):
159165 nogoal_embed = torch .zeros_like (repeat_rgbd_embed [:,0 :1 ])
160166 action_embeddings = self .input_embed (predict_trajectory )
161167 action_embeddings = action_embeddings + self .out_pos_embed (action_embeddings )
162- cond_embeddings = torch .cat ([nogoal_embed ,nogoal_embed ,repeat_rgbd_embed ],dim = 1 ) + self .cond_pos_embed (torch .cat ([nogoal_embed ,nogoal_embed ,repeat_rgbd_embed ],dim = 1 ))
168+ cond_embeddings = torch .cat ([nogoal_embed ,nogoal_embed ,nogoal_embed , nogoal_embed , repeat_rgbd_embed ],dim = 1 ) + self .cond_pos_embed (torch .cat ([nogoal_embed , nogoal_embed , nogoal_embed ,nogoal_embed ,repeat_rgbd_embed ],dim = 1 ))
163169 critic_output = self .decoder (tgt = action_embeddings , memory = cond_embeddings , memory_mask = self .cond_critic_mask )
164170 critic_output = self .layernorm (critic_output )
165171 critic_output = self .critic_head (critic_output .mean (dim = 1 ))[:,0 ]
166172 return critic_output
167173
168- def forward (self ,goal_point ,goal_image ,input_images ,input_depths ,output_actions ,augment_actions ):
174+ def forward (self ,goal_point ,goal_image ,goal_pixel , input_images ,input_depths ,output_actions ,augment_actions ):
169175 # """get device safely"""
170176 # # get device safely
171177 # try:
@@ -193,53 +199,61 @@ def forward(self,goal_point,goal_image,input_images,input_depths,output_actions,
193199 input_depths = input_depths .to (device )
194200
195201 ng_noise ,ng_time_embed ,ng_noisy_action_embed = self .sample_noise (tensor_label_actions )
196- pg_noise ,pg_time_embed ,pg_noisy_action_embed = self .sample_noise (tensor_label_actions )
197- # ig_noise,ig_time_embed,ig_noisy_action_embed = self.sample_noise(tensor_label_actions)
202+ mg_noise ,mg_time_embed ,mg_noisy_action_embed = self .sample_noise (tensor_label_actions )
198203
199204 rgbd_embed = self .rgbd_encoder (input_images ,input_depths )
200205 pointgoal_embed = self .point_encoder (tensor_point_goal ).unsqueeze (1 )
201206 nogoal_embed = torch .zeros_like (pointgoal_embed )
202- # imagegoal_embed = torch.zeros_like(pointgoal_embed)
207+ imagegoal_embed = self .image_encoder (goal_image ).unsqueeze (1 )
208+ pixelgoal_embed = self .pixel_encoder (goal_pixel ).unsqueeze (1 )
209+
210+ imagegoal_aux_pred = self .image_aux_head (imagegoal_embed [:,0 ])
211+ pixelgoal_aux_pred = self .pixel_aux_head (pixelgoal_embed [:,0 ])
203212
204213 label_embed = self .input_embed (tensor_label_actions ).detach ()
205214 augment_embed = self .input_embed (tensor_augment_actions ).detach ()
206215
207- cond_pos_embed = self .cond_pos_embed (torch .cat ([ng_time_embed ,nogoal_embed ,rgbd_embed ],dim = 1 ))
208- ng_cond_embeddings = self .drop (torch .cat ([ng_time_embed ,nogoal_embed ,rgbd_embed ],dim = 1 ) + cond_pos_embed )
209- pg_cond_embeddings = self .drop (torch .cat ([pg_time_embed ,pointgoal_embed ,rgbd_embed ],dim = 1 ) + cond_pos_embed )
210- # ig_cond_embeddings = self.drop(torch.cat([ig_time_embed,imagegoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
216+ cond_pos_embed = self .cond_pos_embed (torch .cat ([ng_time_embed ,nogoal_embed ,imagegoal_embed ,pixelgoal_embed ,rgbd_embed ],dim = 1 ))
217+ ng_cond_embeddings = self .drop (torch .cat ([ng_time_embed ,nogoal_embed ,nogoal_embed ,nogoal_embed ,rgbd_embed ],dim = 1 ) + cond_pos_embed )
218+
219+ cand_goal_embed = [pointgoal_embed ,imagegoal_embed ,pixelgoal_embed ]
220+ batch_size = pointgoal_embed .shape [0 ]
221+
222+ # Generate deterministic selections for each sample in the batch using vectorized operations
223+ batch_indices = torch .arange (batch_size , device = pointgoal_embed .device )
224+ pattern_indices = batch_indices % 27 # 3^3 = 27 possible combinations
225+ selections_0 = pattern_indices % 3
226+ selections_1 = (pattern_indices // 3 ) % 3
227+ selections_2 = (pattern_indices // 9 ) % 3
228+ goal_embeds = torch .stack (cand_goal_embed , dim = 0 ) # [3, batch_size, 1, token_dim]
229+ selected_goals_0 = goal_embeds [selections_0 , torch .arange (batch_size ), :, :] # [batch_size, 1, token_dim]
230+ selected_goals_1 = goal_embeds [selections_1 , torch .arange (batch_size ), :, :]
231+ selected_goals_2 = goal_embeds [selections_2 , torch .arange (batch_size ), :, :]
232+ mg_cond_embed_tensor = torch .cat ([mg_time_embed , selected_goals_0 , selected_goals_1 , selected_goals_2 , rgbd_embed ], dim = 1 )
233+ mg_cond_embeddings = self .drop (mg_cond_embed_tensor + cond_pos_embed )
211234
212235 out_pos_embed = self .out_pos_embed (ng_noisy_action_embed )
213236 ng_action_embeddings = self .drop (ng_noisy_action_embed + out_pos_embed )
214- pg_action_embeddings = self .drop (pg_noisy_action_embed + out_pos_embed )
215- # ig_action_embeddings = self.drop(ig_noisy_action_embed + out_pos_embed)
237+ mg_action_embeddings = self .drop (mg_noisy_action_embed + out_pos_embed )
216238 label_action_embeddings = self .drop (label_embed + out_pos_embed )
217239 augment_action_embeddings = self .drop (augment_embed + out_pos_embed )
218240
219- # ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
220241 ng_output = self .decoder (tgt = ng_action_embeddings ,memory = ng_cond_embeddings , tgt_mask = self .tgt_mask )
221242 ng_output = self .layernorm (ng_output )
222243 noise_pred_ng = self .action_head (ng_output )
223244
224- pg_output = self .decoder (tgt = pg_action_embeddings ,memory = pg_cond_embeddings , tgt_mask = self .tgt_mask .to (ng_action_embeddings .device ))
225- # pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask)
226- pg_output = self .layernorm (pg_output )
227- noise_pred_pg = self .action_head (pg_output )
228-
229- # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
230- # ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask)
231- # ig_output = self.layernorm(ig_output)
232- # noise_pred_ig = self.action_head(ig_output)
245+ mg_output = self .decoder (tgt = mg_action_embeddings ,memory = mg_cond_embeddings , tgt_mask = self .tgt_mask .to (ng_action_embeddings .device ))
246+ mg_output = self .layernorm (mg_output )
247+ noise_pred_mg = self .action_head (mg_output )
233248
234249 cr_label_output = self .decoder (tgt = label_action_embeddings , memory = ng_cond_embeddings , memory_mask = self .cond_critic_mask .to (self ._device ))
235- # cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask)
236250 cr_label_output = self .layernorm (cr_label_output )
237251 cr_label_pred = self .critic_head (cr_label_output .mean (dim = 1 ))[:,0 ]
238252
239253 cr_augment_output = self .decoder (tgt = augment_action_embeddings , memory = ng_cond_embeddings , memory_mask = self .cond_critic_mask .to (self ._device ))
240254 cr_augment_output = self .layernorm (cr_augment_output )
241255 cr_augment_pred = self .critic_head (cr_augment_output .mean (dim = 1 ))[:,0 ]
242- return noise_pred_ng ,noise_pred_pg ,cr_label_pred ,cr_augment_pred ,[ng_noise ,pg_noise ]
256+ return noise_pred_ng ,noise_pred_mg ,cr_label_pred ,cr_augment_pred ,[ng_noise ,mg_noise ],[ imagegoal_aux_pred , pixelgoal_aux_pred ]
243257
244258 def _get_device (self ):
245259 """Safe get device information"""
0 commit comments