
Commit 66b1444

[fix] precommit fix

1 parent cb3a53a

File tree

6 files changed: +304 -275 lines

internnav/agent/internvla_n1_agent_realworld.py

Lines changed: 93 additions & 78 deletions
@@ -2,82 +2,86 @@
 import itertools
 import os
 import re
-import time
-import torch
 import sys
-import numpy as np
+import time
 from datetime import datetime
 from pathlib import Path
 
+import numpy as np
+import torch
+
 sys.path.append(str(Path(__file__).parent.parent.parent))
 
-from PIL import Image, ImageFile, ImageDraw, ImageFont
-from internnav.model.utils.vln_utils import split_and_clean, S2Output, traj_to_actions
 from collections import OrderedDict
 
-from transformers import (
-    AutoTokenizer,
-    AutoProcessor,
-)
-from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM
+from PIL import Image
+from transformers import AutoProcessor
 
+from internnav.model.basemodel.internvla_n1.internvla_n1 import InternVLAN1ForCausalLM
+from internnav.model.utils.vln_utils import S2Output, split_and_clean, traj_to_actions
 
 DEFAULT_IMAGE_TOKEN = "<image>"
+
+
 class InternVLAN1AsyncAgent:
     def __init__(self, args):
         self.device = torch.device(args.device)
         self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
         self.model = InternVLAN1ForCausalLM.from_pretrained(
-            args.model_path, torch_dtype=torch.bfloat16,
-            attn_implementation="flash_attention_2", device_map={"": self.device}
+            args.model_path,
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2",
+            device_map={"": self.device},
         )
         self.model.eval()
         self.model.to(self.device)
-
+
         self.processor = AutoProcessor.from_pretrained(args.model_path)
         self.processor.tokenizer.padding_side = 'left'
-
+
         self.resize_w = args.resize_w
         self.resize_h = args.resize_h
         self.num_history = args.num_history
-
+
         prompt = f"You are an autonomous navigation assistant. Your task is to <instruction>. Where should you go next to stay on track? Please output the next waypoint's coordinates in the image. Please output STOP when you have successfully completed the task."
         answer = ""
         self.conversation = [{"from": "human", "value": prompt}, {"from": "gpt", "value": answer}]
         self.conjunctions = [
-            'you can see ',
-            'in front of you is ',
-            'there is ',
-            'you can spot ',
-            'you are toward the ',
-            'ahead of you is ',
-            'in your sight is '
-        ]
-
-        self.actions2idx = OrderedDict({
-            'STOP': [0],
-            "↑": [1],
-            "←": [2],
-            "→": [3],
-            "↓": [5],
-        })
-
+            'you can see ',
+            'in front of you is ',
+            'there is ',
+            'you can spot ',
+            'you are toward the ',
+            'ahead of you is ',
+            'in your sight is ',
+        ]
+
+        self.actions2idx = OrderedDict(
+            {
+                'STOP': [0],
+                "↑": [1],
+                "←": [2],
+                "→": [3],
+                "↓": [5],
+            }
+        )
+
         self.rgb_list = []
         self.depth_list = []
         self.pose_list = []
-        self.episode_idx = 0
+        self.episode_idx = 0
         self.conversation_history = []
         self.llm_output = ""
         self.past_key_values = None
         self.last_s2_idx = -100
-
+
         # output
         self.output_action = None
         self.output_latent = None
         self.output_pixel = None
         self.pixel_goal_rgb = None
         self.pixel_goal_depth = None
-
+
     def reset(self):
         self.rgb_list = []
         self.depth_list = []
@@ -86,63 +90,74 @@ def reset(self):
         self.conversation_history = []
         self.llm_output = ""
         self.past_key_values = None
-
+
         self.save_dir = "test_data/" + datetime.now().strftime("%Y%m%d_%H%M%S")
         os.makedirs(self.save_dir, exist_ok=True)
+
     def parse_actions(self, output):
         action_patterns = '|'.join(re.escape(action) for action in self.actions2idx)
         regex = re.compile(action_patterns)
         matches = regex.findall(output)
         actions = [self.actions2idx[match] for match in matches]
         actions = itertools.chain.from_iterable(actions)
         return list(actions)
-
+
     def step_no_infer(self, rgb, depth, pose):
         image = Image.fromarray(rgb).convert('RGB')
         raw_image_size = image.size
         image = image.resize((self.resize_w, self.resize_h))
         self.rgb_list.append(image)
         image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}.jpg")
         self.episode_idx += 1
-
-    def trajectory_tovw(self, trajectory, kp = 1.0):
+
+    def trajectory_tovw(self, trajectory, kp=1.0):
         subgoal = trajectory[-1]
         linear_vel, angular_vel = kp * np.linalg.norm(subgoal[:2]), kp * subgoal[2]
         linear_vel = np.clip(linear_vel, 0, 0.5)
         angular_vel = np.clip(angular_vel, -0.5, 0.5)
         return linear_vel, angular_vel
 
-    def step(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
+    def step(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
         dual_sys_output = S2Output()
         PLAN_STEP_GAP = 8
-        no_output_flag = (self.output_action is None and self.output_latent is None)
+        no_output_flag = self.output_action is None and self.output_latent is None
         if (self.episode_idx - self.last_s2_idx > PLAN_STEP_GAP) or look_down or no_output_flag:
-            self.output_action, self.output_latent, self.output_pixel = self.step_s2(rgb, depth, pose, instruction, intrinsic, look_down)
+            self.output_action, self.output_latent, self.output_pixel = self.step_s2(
+                rgb, depth, pose, instruction, intrinsic, look_down
+            )
             self.last_s2_idx = self.episode_idx
             dual_sys_output.output_pixel = self.output_pixel
             self.pixel_goal_rgb = copy.deepcopy(rgb)
             self.pixel_goal_depth = copy.deepcopy(depth)
         else:
             self.step_no_infer(rgb, depth, pose)
-
-
+
         if self.output_action is not None:
             dual_sys_output.output_action = copy.deepcopy(self.output_action)
-            self.output_action = None
-        elif self.output_latent is not None:
+            self.output_action = None
+        elif self.output_latent is not None:
             processed_pixel_rgb = np.array(Image.fromarray(self.pixel_goal_rgb).resize((224, 224))) / 255
             processed_pixel_depth = np.array(Image.fromarray(self.pixel_goal_depth).resize((224, 224)))
             processed_rgb = np.array(Image.fromarray(rgb).resize((224, 224))) / 255
             processed_depth = np.array(Image.fromarray(depth).resize((224, 224)))
-            rgbs = torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)]).unsqueeze(0).to(self.device)
-            depths = torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)]).unsqueeze(0).unsqueeze(-1).to(self.device)
+            rgbs = (
+                torch.stack([torch.from_numpy(processed_pixel_rgb), torch.from_numpy(processed_rgb)])
+                .unsqueeze(0)
+                .to(self.device)
+            )
+            depths = (
+                torch.stack([torch.from_numpy(processed_pixel_depth), torch.from_numpy(processed_depth)])
+                .unsqueeze(0)
+                .unsqueeze(-1)
+                .to(self.device)
+            )
             trajectories = self.step_s1(self.output_latent, rgbs, depths)
-
+
             dual_sys_output.output_action = traj_to_actions(trajectories)
 
         return dual_sys_output
-
-    def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
+
+    def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down=False):
         image = Image.fromarray(rgb).convert('RGB')
         raw_image_size = image.size
         if not look_down:
@@ -152,9 +167,9 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
         else:
             image.save(f"{self.save_dir}/debug_raw_{self.episode_idx:04d}_look_down.jpg")
         if not look_down:
-            self.conversation_history = []
+            self.conversation_history = []
             self.past_key_values = None
-
+
         sources = copy.deepcopy(self.conversation)
         sources[0]["value"] = sources[0]["value"].replace('<instruction>.', instruction)
         cur_images = self.rgb_list[-1:]
@@ -164,7 +179,7 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
             history_id = np.unique(np.linspace(0, self.episode_idx - 1, self.num_history, dtype=np.int32)).tolist()
             placeholder = (DEFAULT_IMAGE_TOKEN + '\n') * len(history_id)
             sources[0]["value"] += f' These are your historical observations: {placeholder}.'
-
+
             history_id = sorted(history_id)
             self.input_images = [self.rgb_list[i] for i in history_id] + cur_images
             input_img_id = 0
@@ -174,64 +189,64 @@ def step_s2(self, rgb, depth, pose, instruction, intrinsic, look_down = False):
             input_img_id = -1
             assert self.llm_output != "", "Last llm_output should not be empty when look down"
             sources = [{"from": "human", "value": ""}, {"from": "gpt", "value": ""}]
-            self.conversation_history.append({ 'role': 'assistant', 'content': [{ 'type': 'text', 'text': self.llm_output}]})
-
+            self.conversation_history.append(
+                {'role': 'assistant', 'content': [{'type': 'text', 'text': self.llm_output}]}
+            )
+
         prompt = self.conjunctions[0] + DEFAULT_IMAGE_TOKEN
         sources[0]["value"] += f" {prompt}."
         prompt_instruction = copy.deepcopy(sources[0]["value"])
         parts = split_and_clean(prompt_instruction)
-
+
         content = []
-        for i in range (len(parts)):
+        for i in range(len(parts)):
             if parts[i] == "<image>":
                 content.append({"type": "image", "image": self.input_images[input_img_id]})
-                input_img_id +=1
+                input_img_id += 1
             else:
-                content.append({"type": "text", "text": parts[i]})
-
+                content.append({"type": "text", "text": parts[i]})
+
         self.conversation_history.append({'role': 'user', 'content': content})
-
-        text = self.processor.apply_chat_template(
-            self.conversation_history, tokenize=False, add_generation_prompt=True
-        )
-
+
+        text = self.processor.apply_chat_template(self.conversation_history, tokenize=False, add_generation_prompt=True)
+
         inputs = self.processor(text=[text], images=self.input_images, return_tensors="pt").to(self.device)
         t0 = time.time()
         with torch.no_grad():
             outputs = self.model.generate(
-                **inputs,
-                max_new_tokens=128,
+                **inputs,
+                max_new_tokens=128,
                 do_sample=False,
                 use_cache=True,
                 past_key_values=self.past_key_values,
                 return_dict_in_generate=True,
                 raw_input_ids=copy.deepcopy(inputs.input_ids),
             )
         output_ids = outputs.sequences
-
+
         t1 = time.time()
-        self.llm_output = self.processor.tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+        self.llm_output = self.processor.tokenizer.decode(
+            output_ids[0][inputs.input_ids.shape[1] :], skip_special_tokens=True
+        )
         with open(f"{self.save_dir}/llm_output_{self.episode_idx:04d}.txt", 'w') as f:
             f.write(self.llm_output)
         self.last_output_ids = copy.deepcopy(output_ids[0])
         self.past_key_values = copy.deepcopy(outputs.past_key_values)
         print(f"output {self.episode_idx} {self.llm_output} cost:{t1-t0}s")
-        if bool(re.search(r'\d', self.llm_output)):
+        if bool(re.search(r'\d', self.llm_output)):
             coord = [int(c) for c in re.findall(r'\d+', self.llm_output)]
             pixel_goal = [int(coord[1]), int(coord[0])]
-            image_grid_thw = torch.cat(
-                [thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0
-            )
+            image_grid_thw = torch.cat([thw.unsqueeze(0) for thw in inputs.image_grid_thw], dim=0)
             pixel_values = inputs.pixel_values
             t0 = time.time()
             with torch.no_grad():
                 traj_latents = self.model.generate_latents(output_ids, pixel_values, image_grid_thw)
             return None, traj_latents, pixel_goal
-
+
         else:
             action_seq = self.parse_actions(self.llm_output)
-            return action_seq, None, None
-
-    def step_s1(self, latent, rgb, depth):
+            return action_seq, None, None
+
+    def step_s1(self, latent, rgb, depth):
         all_trajs = self.model.generate_traj(latent, rgb, depth, use_async=True)
         return all_trajs
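
Context note (illustrative, not part of this commit): the standalone Python sketch below mirrors the parse_actions logic reformatted above, to show how the action tokens emitted by the S2 model ("STOP", "↑", "←", "→", "↓") are mapped to discrete action ids. The sample output string is invented for demonstration.

# Standalone sketch mirroring parse_actions above; the sample string is made up.
import itertools
import re
from collections import OrderedDict

actions2idx = OrderedDict(
    {
        'STOP': [0],
        "↑": [1],
        "←": [2],
        "→": [3],
        "↓": [5],
    }
)


def parse_actions(output):
    # Alternation over the known action tokens; 'STOP' comes first so it wins over single arrows.
    action_patterns = '|'.join(re.escape(action) for action in actions2idx)
    matches = re.findall(action_patterns, output)
    # Each token maps to a list of action ids; flatten them in order of appearance.
    return list(itertools.chain.from_iterable(actions2idx[m] for m in matches))


print(parse_actions("↑↑→ STOP"))  # -> [1, 1, 3, 0]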
