Skip to content

Commit 031e3d8

Browse files
committed
[FIX] Update the code for multi-GPU training and checkpoint saving
1 parent 4753396 commit 031e3d8

File tree

3 files changed

+157
-51
lines changed

3 files changed

+157
-51
lines changed

internnav/model/basemodel/navdp/navdp_policy.py

Lines changed: 49 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch.nn.functional as F
88
import numpy as np
99
import os
10+
import random
1011
from scipy.signal import savgol_filter
1112
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
1213
from internnav.model.encoder.navdp_backbone import *
@@ -15,17 +16,13 @@
1516
from internnav.configs.model.base_encoders import ModelCfg
1617
from internnav.configs.trainer.exp import ExpCfg
1718

18-
19-
2019
class NavDPModelConfig(PretrainedConfig):
2120
model_type = 'navdp'
22-
2321
def __init__(self, **kwargs):
2422
super().__init__(**kwargs)
2523
# pass in navdp_exp_cfg
2624
self.model_cfg = kwargs.get('model_cfg', None)
2725

28-
2926
@classmethod
3027
def from_dict(cls, config_dict):
3128
if 'model_cfg' in config_dict:
@@ -89,7 +86,15 @@ def __init__(self, config: NavDPModelConfig):
8986
self.scratch=self.config.model_cfg['il']['scratch']
9087
self.finetune=self.config.model_cfg['il']['finetune']
9188
self.rgbd_encoder = NavDP_RGBD_Backbone(self.image_size,self.token_dim,memory_size=self.memory_size,finetune=self.finetune,device=self._device)
89+
self.pixel_encoder = NavDP_PixelGoal_Backbone(self.image_size,self.token_dim,device=self._device)
90+
self.image_encoder = NavDP_ImageGoal_Backbone(self.image_size,self.token_dim,device=self._device)
9291
self.point_encoder = nn.Linear(3,self.token_dim)
92+
93+
if not self.finetune:
94+
for p in self.rgbd_encoder.parameters():
95+
p.requires_grad = False
96+
self.rgbd_encoder.eval()
97+
9398
decoder_layer = nn.TransformerDecoderLayer(d_model = self.token_dim,
9499
nhead = self.attention_heads,
95100
dim_feedforward = 4 * self.token_dim,
@@ -101,7 +106,8 @@ def __init__(self, config: NavDPModelConfig):
101106
num_layers = self.temporal_depth)
102107
self.input_embed = nn.Linear(3,self.token_dim)
103108

104-
self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 2)
109+
110+
self.cond_pos_embed = LearnablePositionalEncoding(self.token_dim, self.memory_size * 16 + 4)
105111
self.out_pos_embed = LearnablePositionalEncoding(self.token_dim, self.predict_size)
106112
self.drop = nn.Dropout(self.dropout)
107113
self.time_emb = SinusoidalPosEmb(self.token_dim)
@@ -114,9 +120,13 @@ def __init__(self, config: NavDPModelConfig):
114120
prediction_type='epsilon')
115121
self.tgt_mask = (torch.triu(torch.ones(self.predict_size, self.predict_size)) == 1).transpose(0, 1)
116122
self.tgt_mask = self.tgt_mask.float().masked_fill(self.tgt_mask == 0, float('-inf')).masked_fill(self.tgt_mask == 1, float(0.0))
117-
self.cond_critic_mask = torch.zeros((self.predict_size,2 + self.memory_size * 16))
118-
self.cond_critic_mask[:,0:2] = float('-inf')
119123
self.tgt_mask = self.tgt_mask.to(self._device)
124+
125+
self.cond_critic_mask = torch.zeros((self.predict_size,4 + self.memory_size * 16))
126+
self.cond_critic_mask[:,0:4] = float('-inf')
127+
128+
self.pixel_aux_head = nn.Linear(self.token_dim,3)
129+
self.image_aux_head = nn.Linear(self.token_dim,3)
120130

121131
def to(self, device, *args, **kwargs):
122132
# first call the to method of the parent class
@@ -131,10 +141,6 @@ def to(self, device, *args, **kwargs):
131141
return self
132142

133143
def sample_noise(self,action):
134-
# device = next(self.parameters()).device
135-
# if device is None:
136-
# device = action.device
137-
# action = action.to(self._device)
138144
device = action.device
139145
noise = torch.randn(action.shape, device=device)
140146
timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps,(action.shape[0],), device=device).long()
@@ -146,7 +152,7 @@ def sample_noise(self,action):
146152
def predict_noise(self,last_actions,timestep,goal_embed,rgbd_embed):
147153
action_embeds = self.input_embed(last_actions)
148154
time_embeds = self.time_emb(timestep.to(self._device)).unsqueeze(1)
149-
cond_embedding = torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([time_embeds,goal_embed,rgbd_embed],dim=1))
155+
cond_embedding = torch.cat([time_embeds,goal_embed,goal_embed,goal_embed,rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([time_embeds,goal_embed,goal_embed,goal_embed,rgbd_embed],dim=1))
150156
cond_embedding = cond_embedding.repeat(action_embeds.shape[0],1,1)
151157
input_embedding = action_embeds + self.out_pos_embed(action_embeds)
152158
output = self.decoder(tgt = input_embedding,memory = cond_embedding, tgt_mask = self.tgt_mask.to(self._device))
@@ -159,13 +165,13 @@ def predict_critic(self,predict_trajectory,rgbd_embed):
159165
nogoal_embed = torch.zeros_like(repeat_rgbd_embed[:,0:1])
160166
action_embeddings = self.input_embed(predict_trajectory)
161167
action_embeddings = action_embeddings + self.out_pos_embed(action_embeddings)
162-
cond_embeddings = torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1))
168+
cond_embeddings = torch.cat([nogoal_embed,nogoal_embed,nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1) + self.cond_pos_embed(torch.cat([nogoal_embed,nogoal_embed,nogoal_embed,nogoal_embed,repeat_rgbd_embed],dim=1))
163169
critic_output = self.decoder(tgt = action_embeddings, memory = cond_embeddings, memory_mask = self.cond_critic_mask)
164170
critic_output = self.layernorm(critic_output)
165171
critic_output = self.critic_head(critic_output.mean(dim=1))[:,0]
166172
return critic_output
167173

168-
def forward(self,goal_point,goal_image,input_images,input_depths,output_actions,augment_actions):
174+
def forward(self,goal_point,goal_image,goal_pixel,input_images,input_depths,output_actions,augment_actions):
169175
# """get device safely"""
170176
# # get device safely
171177
# try:
@@ -193,53 +199,61 @@ def forward(self,goal_point,goal_image,input_images,input_depths,output_actions,
193199
input_depths = input_depths.to(device)
194200

195201
ng_noise,ng_time_embed,ng_noisy_action_embed = self.sample_noise(tensor_label_actions)
196-
pg_noise,pg_time_embed,pg_noisy_action_embed = self.sample_noise(tensor_label_actions)
197-
# ig_noise,ig_time_embed,ig_noisy_action_embed = self.sample_noise(tensor_label_actions)
202+
mg_noise,mg_time_embed,mg_noisy_action_embed = self.sample_noise(tensor_label_actions)
198203

199204
rgbd_embed = self.rgbd_encoder(input_images,input_depths)
200205
pointgoal_embed = self.point_encoder(tensor_point_goal).unsqueeze(1)
201206
nogoal_embed = torch.zeros_like(pointgoal_embed)
202-
# imagegoal_embed = torch.zeros_like(pointgoal_embed)
207+
imagegoal_embed = self.image_encoder(goal_image).unsqueeze(1)
208+
pixelgoal_embed = self.pixel_encoder(goal_pixel).unsqueeze(1)
209+
210+
imagegoal_aux_pred = self.image_aux_head(imagegoal_embed[:,0])
211+
pixelgoal_aux_pred = self.pixel_aux_head(pixelgoal_embed[:,0])
203212

204213
label_embed = self.input_embed(tensor_label_actions).detach()
205214
augment_embed = self.input_embed(tensor_augment_actions).detach()
206215

207-
cond_pos_embed = self.cond_pos_embed(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1))
208-
ng_cond_embeddings = self.drop(torch.cat([ng_time_embed,nogoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
209-
pg_cond_embeddings = self.drop(torch.cat([pg_time_embed,pointgoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
210-
# ig_cond_embeddings = self.drop(torch.cat([ig_time_embed,imagegoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
216+
cond_pos_embed = self.cond_pos_embed(torch.cat([ng_time_embed,nogoal_embed,imagegoal_embed,pixelgoal_embed,rgbd_embed],dim=1))
217+
ng_cond_embeddings = self.drop(torch.cat([ng_time_embed,nogoal_embed,nogoal_embed,nogoal_embed,rgbd_embed],dim=1) + cond_pos_embed)
218+
219+
cand_goal_embed = [pointgoal_embed,imagegoal_embed,pixelgoal_embed]
220+
batch_size = pointgoal_embed.shape[0]
221+
222+
# Generate deterministic selections for each sample in the batch using vectorized operations
223+
batch_indices = torch.arange(batch_size, device=pointgoal_embed.device)
224+
pattern_indices = batch_indices % 27 # 3^3 = 27 possible combinations
225+
selections_0 = pattern_indices % 3
226+
selections_1 = (pattern_indices // 3) % 3
227+
selections_2 = (pattern_indices // 9) % 3
228+
goal_embeds = torch.stack(cand_goal_embed, dim=0) # [3, batch_size, 1, token_dim]
229+
selected_goals_0 = goal_embeds[selections_0, torch.arange(batch_size), :, :] # [batch_size, 1, token_dim]
230+
selected_goals_1 = goal_embeds[selections_1, torch.arange(batch_size), :, :]
231+
selected_goals_2 = goal_embeds[selections_2, torch.arange(batch_size), :, :]
232+
mg_cond_embed_tensor = torch.cat([mg_time_embed, selected_goals_0, selected_goals_1, selected_goals_2, rgbd_embed], dim=1)
233+
mg_cond_embeddings = self.drop(mg_cond_embed_tensor + cond_pos_embed)
211234

212235
out_pos_embed = self.out_pos_embed(ng_noisy_action_embed)
213236
ng_action_embeddings = self.drop(ng_noisy_action_embed + out_pos_embed)
214-
pg_action_embeddings = self.drop(pg_noisy_action_embed + out_pos_embed)
215-
# ig_action_embeddings = self.drop(ig_noisy_action_embed + out_pos_embed)
237+
mg_action_embeddings = self.drop(mg_noisy_action_embed + out_pos_embed)
216238
label_action_embeddings = self.drop(label_embed + out_pos_embed)
217239
augment_action_embeddings = self.drop(augment_embed + out_pos_embed)
218240

219-
# ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
220241
ng_output = self.decoder(tgt = ng_action_embeddings,memory = ng_cond_embeddings, tgt_mask = self.tgt_mask)
221242
ng_output = self.layernorm(ng_output)
222243
noise_pred_ng = self.action_head(ng_output)
223244

224-
pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
225-
# pg_output = self.decoder(tgt = pg_action_embeddings,memory = pg_cond_embeddings, tgt_mask = self.tgt_mask)
226-
pg_output = self.layernorm(pg_output)
227-
noise_pred_pg = self.action_head(pg_output)
228-
229-
# ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
230-
# ig_output = self.decoder(tgt = ig_action_embeddings,memory = ig_cond_embeddings, tgt_mask = self.tgt_mask)
231-
# ig_output = self.layernorm(ig_output)
232-
# noise_pred_ig = self.action_head(ig_output)
245+
mg_output = self.decoder(tgt = mg_action_embeddings,memory = mg_cond_embeddings, tgt_mask = self.tgt_mask.to(ng_action_embeddings.device))
246+
mg_output = self.layernorm(mg_output)
247+
noise_pred_mg = self.action_head(mg_output)
233248

234249
cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device))
235-
# cr_label_output = self.decoder(tgt = label_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask)
236250
cr_label_output = self.layernorm(cr_label_output)
237251
cr_label_pred = self.critic_head(cr_label_output.mean(dim=1))[:,0]
238252

239253
cr_augment_output = self.decoder(tgt = augment_action_embeddings, memory = ng_cond_embeddings, memory_mask = self.cond_critic_mask.to(self._device))
240254
cr_augment_output = self.layernorm(cr_augment_output)
241255
cr_augment_pred = self.critic_head(cr_augment_output.mean(dim=1))[:,0]
242-
return noise_pred_ng,noise_pred_pg,cr_label_pred,cr_augment_pred,[ng_noise,pg_noise]
256+
return noise_pred_ng,noise_pred_mg,cr_label_pred,cr_augment_pred,[ng_noise,mg_noise],[imagegoal_aux_pred,pixelgoal_aux_pred]
243257

244258
def _get_device(self):
245259
"""Safe get device information"""

internnav/model/encoder/navdp_backbone.py

Lines changed: 77 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ def forward(self,images,depths):
260260
memory_token = self.former_net(former_query,former_token)
261261
memory_token = self.project_layer(memory_token)
262262
return memory_token
263+
263264
def _get_device(self):
264265
"""get device safely"""
265266
# try to get device through model parameters
@@ -293,6 +294,12 @@ def __init__(self,
293294
embed_size=512,
294295
device='cuda:0'):
295296
super().__init__()
297+
if device is None:
298+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
299+
elif isinstance(device, int):
300+
device = torch.device(f"cuda:{device}")
301+
elif isinstance(device, str):
302+
device = torch.device(device)
296303
self.device = device
297304
self.image_size = image_size
298305
self.embed_size = embed_size
@@ -306,37 +313,103 @@ def __init__(self,
306313
padding = self.imagegoal_encoder.patch_embed.proj.padding)
307314
self.imagegoal_encoder.train()
308315
self.project_layer = nn.Linear(384,embed_size)
316+
self.to(device)
309317

310318
def forward(self,images):
311319
assert len(images.shape) == 4 # B,C,H,W
312-
tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
320+
device = self._get_device()
321+
images = images.to(device)
322+
tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,3,1,2)
313323
image_token = self.imagegoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
314324
image_token = self.project_layer(image_token)
315325
return image_token
326+
327+
def _get_device(self):
328+
"""get device safely"""
329+
# try to get device through model parameters
330+
try:
331+
for param in self.parameters():
332+
return param.device
333+
except StopIteration:
334+
pass
335+
336+
# try to get device through buffer
337+
try:
338+
for buffer in self.buffers():
339+
return buffer.device
340+
except StopIteration:
341+
pass
342+
343+
# try to get device through submodule
344+
for module in self.children():
345+
try:
346+
for param in module.parameters():
347+
return param.device
348+
except StopIteration:
349+
continue
350+
351+
# finally revert to default device
352+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
316353

317354
class NavDP_PixelGoal_Backbone(nn.Module):
318355
def __init__(self,
319356
image_size=224,
320357
embed_size=512,
321358
device='cuda:0'):
322359
super().__init__()
360+
if device is None:
361+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
362+
elif isinstance(device, int):
363+
device = torch.device(f"cuda:{device}")
364+
elif isinstance(device, str):
365+
device = torch.device(device)
323366
self.device = device
324367
self.image_size = image_size
325368
self.embed_size = embed_size
326369
model_configs = {'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}}
327370
self.pixelgoal_encoder = DepthAnythingV2(**model_configs['vits'])
328371
self.pixelgoal_encoder = self.pixelgoal_encoder.pretrained.float()
329-
self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=4,
372+
self.pixelgoal_encoder.patch_embed.proj = nn.Conv2d(in_channels=7,
330373
out_channels = self.pixelgoal_encoder.patch_embed.proj.out_channels,
331374
kernel_size = self.pixelgoal_encoder.patch_embed.proj.kernel_size,
332375
stride = self.pixelgoal_encoder.patch_embed.proj.stride,
333376
padding = self.pixelgoal_encoder.patch_embed.proj.padding)
334377
self.pixelgoal_encoder.train()
335378
self.project_layer = nn.Linear(384,embed_size)
379+
self.to(device)
336380

337381
def forward(self,images):
338382
assert len(images.shape) == 4 # B,C,H,W
339-
tensor_images = torch.as_tensor(images,dtype=torch.float32,device=self.device).permute(0,3,1,2)
383+
device = self._get_device()
384+
images = images.to(device)
385+
tensor_images = torch.as_tensor(images,dtype=torch.float32,device=device).permute(0,3,1,2)
340386
image_token = self.pixelgoal_encoder.get_intermediate_layers(tensor_images)[0].mean(dim=1)
341387
image_token = self.project_layer(image_token)
342-
return image_token
388+
return image_token
389+
390+
def _get_device(self):
391+
"""get device safely"""
392+
# try to get device through model parameters
393+
try:
394+
for param in self.parameters():
395+
return param.device
396+
except StopIteration:
397+
pass
398+
399+
# try to get device through buffer
400+
try:
401+
for buffer in self.buffers():
402+
return buffer.device
403+
except StopIteration:
404+
pass
405+
406+
# try to get device through submodule
407+
for module in self.children():
408+
try:
409+
for param in module.parameters():
410+
return param.device
411+
except StopIteration:
412+
continue
413+
414+
# finally revert to default device
415+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")

0 commit comments

Comments (0)