import itertools
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path

import numpy as np
import torch

sys.path.append(str(Path(__file__).parent.parent.parent))

from collections import OrderedDict

from PIL import Image
from transformers import AutoProcessor

from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM
from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions

DEFAULT_IMAGE_TOKEN = "<image>"

class InternVLAN1AsyncAgent:
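    """Asynchronous dual-system agent around InternVLA-N1 for vision-language navigation.

    System 2 (the InternVLA-N1 causal LM) consumes the instruction plus current and
    historical RGB frames and emits either a pixel waypoint goal (with trajectory
    latents) or a sequence of discrete actions; System 1 turns those latents and
    RGB-D observations into continuous trajectories.
    """
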
    def __init__(self, args):
        self.device = torch.device(args.device)
        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        self.model = InternVLAN1ForCausalLM.from_pretrained(
            args.model_path,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map={"": self.device},
        )
        self.model.eval()
        self.model.to(self.device)

        self.processor = AutoProcessor.from_pretrained(args.model_path)
        self.processor.tokenizer.padding_side = 'left'

        self.resize_w = args.resize_w
        self.resize_h = args.resize_h
        self.num_history = args.num_history

        prompt = "You are an autonomous navigation assistant. Your task is to <instruction>. Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task."
        answer = ""
        self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}]
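        # Conjunction phrases used to introduce the current observation in the prompt.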
        self.conjunctions = [
            'you can see ',
            'in front of you is ',
            'there is ',
            'you can spot ',
            'you are toward the ',
            'ahead of you is ',
            'in your sight is ',
        ]

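        # Arrow/STOP tokens that the LLM may emit, mapped to discrete action indices.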
        self.actions2idx = OrderedDict(
            {
                'STOP': [0],
                "↑": [1],
                "←": [2],
                "→": [3],
                "↓": [5],
            }
        )

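        # Per-episode observation buffers and cached System-2 planning state.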
        self.rgb_list = []
        self.depth_list = []
        self.pose_list = []
        self.episode_idx = 0
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None
        self.last_s2_idx = -100

        # output
        self.output_action = None
        self.output_latent = None
        self.output_pixel = None
        self.pixel_goal_rgb = None
        self.pixel_goal_depth = None

    def reset(self):
        self.rgb_list = []
        self.depth_list = []
@@ -86,63 +90,74 @@ def reset(self):
        self.conversation_history = []
        self.llm_output = ""
        self.past_key_values = None

        self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs(self.save_dir, exist_ok=True)

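    # Translate action tokens found in the LLM output into a flat list of action indices.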
    def parse_actions(self, output):
        action_patterns = '|'.join(re.escape(action) for action in self.actions2idx)
        regex = re.compile(action_patterns)
        matches = regex.findall(output)
        actions = [self.actions2idx[match] for match in matches]
        actions = itertools.chain.from_iterable(actions)
        return list(actions)

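    # Record the current observation without querying the model (used between System-2 replans).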
    def step_no_infer(self, rgb, depth, pose):
        image = Image.fromarray(rgb).convert('RGB')
        raw_image_size = image.size
        image = image.resize((self.resize_w, self.resize_h))
        self.rgb_list.append(image)
        image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg")
        self.episode_idx += 1

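    # Convert the tail of a predicted trajectory into clipped (linear, angular) velocities
    # using a simple proportional gain kp.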
    def trajectory_tovw(self, trajectory, kp=1.0):
        subgoal = trajectory[-1]
        linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2]
        linear_vel = np.clip(linear_vel, 0, 0.5)
        angular_vel = np.clip(angular_vel, -0.5, 0.5)
        return linear_vel, angular_vel

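    # One control step: re-run System-2 planning every PLAN_STEP_GAP frames, when the agent
    # looks down, or when no output is pending; otherwise only log the observation and let
    # System 1 roll out the cached latent goal.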
    def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        dual_sys_output = S2Output()
        PLAN_STEP_GAP = 8
        no_output_flag = self.output_action is None and self.output_latent is None
        if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag:
            self.output_action, self.output_latent, self.output_pixel = self.step_s2(
                rgb, depth, pose, instruction, intrinsic, look_down
            )
            self.last_s2_idx = self.episode_idx
            dual_sys_output.output_pixel = self.output_pixel
            self.pixel_goal_rgb = copy.deepcopy(rgb)
            self.pixel_goal_depth = copy.deepcopy(depth)
        else:
            self.step_no_infer(rgb, depth, pose)

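        # Pending discrete actions from System 2 are emitted first; otherwise the cached latent,
        # the pixel-goal frame, and the current frame are handed to System 1 for trajectory rollout.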
        if self.output_action is not None:
            dual_sys_output.output_action = copy.deepcopy(self.output_action)
            self.output_action = None
        elif self.output_latent is not None:
            processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255
            processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224)))
            processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255
            processed_depth = np.array(Image.fromarray(depth).resize((224, 224)))
            rgbs = (
                torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)])
                .unsqueeze(0)
                .to(self.device)
            )
            depths = (
                torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)])
                .unsqueeze(0)
                .unsqueeze(-1)
                .to(self.device)
            )
            trajectories = self.step_s1(self.output_latent, rgbs, depths)

            dual_sys_output.output_action = traj_to_actions(trajectories)

        return dual_sys_output

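    # System-2 planning: build a multimodal prompt from the instruction, history frames, and the
    # current frame, then decode either a pixel waypoint goal (plus trajectory latents) or a
    # discrete action sequence.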
    def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
        image = Image.fromarray(rgb).convert('RGB')
        raw_image_size = image.size
        if not look_down:
@@ -152,9 +167,9 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
        else:
            image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}_look_down.jpg")
        if not look_down:
            self.conversation_history = []
            self.past_key_values = None

        sources = copy.deepcopy(self.conversation)
        sources[0]["value"] = sources[0]["value"].replace('<instruction>.', instruction)
        cur_images = self.rgb_list[-1:]
@@ -164,7 +179,7 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
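        # Sample up to num_history evenly spaced earlier frames as visual history.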
        history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist()
        placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id)
        sources[0]["value"] += f' These are your historical observations: {placeholder}.'

        history_id = sorted(history_id)
        self.input_images = [self.rgb_list[i] for i in history_id] + cur_images
        input_img_id = 0
@@ -174,64 +189,64 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
            input_img_id = -1
            assert self.llm_output != "", "Last llm_output should not be empty when look down"
            sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}]
            self.conversation_history.append(
                {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]}
            )

        prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN
        sources[0]["value"] += f" {prompt}."
        prompt_instruction = copy.deepcopy(sources[0]["value"])
        parts = split_and_clean(prompt_instruction)

        content = []
        for i in range(len(parts)):
            if parts[i] == "<image>":
                content.append({"type": "image", "image": self.input_images[input_img_id]})
                input_img_id += 1
            else:
                content.append({"type": "text", "text": parts[i]})

        self.conversation_history.append({'role': 'user', 'content': content})

        text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True)

        inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device)
        t0 = time.time()
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                use_cache=True,
                past_key_values=self.past_key_values,
                return_dict_in_generate=True,
                raw_input_ids=copy.deepcopy(inputs.input_ids),
            )
        output_ids = outputs.sequences

        t1 = time.time()
        self.llm_output = self.processor.tokenizer.decode(
            output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
        )
        with open(f"{self.save_dir}/llm_output_{self.episode_idx:04d}.txt", 'w') as f:
            f.write(self.llm_output)
        self.last_output_ids = copy.deepcopy(output_ids[0])
        self.past_key_values = copy.deepcopy(outputs.past_key_values)
        print(f"output {self.episode_idx} {self.llm_output} cost:{t1 - t0}s")
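        # If the output contains digits, treat them as a pixel waypoint goal and generate
        # trajectory latents for System 1; otherwise fall back to discrete action parsing.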
        if bool(re.search(r'\d', self.llm_output)):
            coord = [int(c) for c in re.findall(r'\d+', self.llm_output)]
            pixel_goal = [int(coord[1]), int(coord[0])]
            image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0)
            pixel_values = inputs.pixel_values
            t0 = time.time()
            with torch.no_grad():
                traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw)
            return None, traj_latents, pixel_goal

        else:
            action_seq = self.parse_actions(self.llm_output)
            return action_seq, None, None

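    # System-1 rollout: decode continuous trajectories from the System-2 latent and paired RGB-D frames.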
    def step_s1(self, latent, rgb, depth):
        all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True)
        return all_trajs