
Commit 28deddd

Author: Your Name (committed)
Commit message: vlm captioning
1 parent 658d87b

File tree

3 files changed: +234 -99 lines


examples/droid/.gitignore

Lines changed: 1 addition & 0 deletions

@@ -3,3 +3,4 @@ robodm_trajectories/
 vlm_analysis_results/
 full_robodm_trajectories/
 f1_matrix_results/
+trajectory_captioning_results/

examples/droid/droid_vlm_demo.py

Lines changed: 233 additions & 98 deletions
@@ -9,7 +9,7 @@
 5. Shows how VLM tools can be used during filtering
 """
 
-# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000
+# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000 --tp 8
 
 import os
 import time
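Note on the launch command change: --tp 8 enables tensor parallelism, sharding the 32B-parameter Qwen2.5-VL model across 8 GPUs, since a model of that size typically does not fit on a single GPU at full precision. The demo itself talks to this server through robodm's vlm_service wrapper; for reference, here is a minimal sketch of querying such a server directly. sglang exposes an OpenAI-compatible endpoint, but the openai client package and the prompt below are our assumptions, not code from this commit:

    # Hypothetical standalone client; assumes the sglang server above is running.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="Qwen/Qwen2.5-VL-32B-Instruct",
        messages=[{"role": "user", "content": "Describe the robot's task in one sentence."}],
    )
    print(resp.choices[0].message.content)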
@@ -141,129 +141,259 @@ def create_robodm_dataset(self, robodm_dir: str) -> VLADataset:
 
         return dataset
 
-    def create_success_filter_function(self) -> callable:
+    def calculate_trajectory_captioning_f1(self, dataset: VLADataset):
         """
-        Create a simple filter function for successful trajectories.
-
-        For now, we bypass the planner and write the function directly.
-        This function can use VLM tools during execution.
+        Calculate F1 score for trajectory captioning by comparing VLM-generated captions
+        with ground truth language descriptions from metadata using LLM for semantic matching.
 
+        Args:
+            dataset: VLADataset with loaded trajectories
+
         Returns:
-            Filter function that identifies successful trajectories
+            float: F1 score for caption similarity
         """
-        def filter_successful_trajectories(trajectory: Dict[str, Any]) -> bool:
-            """
-            Filter function to identify successful trajectories.
+        print("\n" + "=" * 60)
+        print("TRAJECTORY CAPTIONING F1 CALCULATION")
+        print("=" * 60)
+
+        # Create output directory for captioning results
+        caption_output_dir = Path("./trajectory_captioning_results")
+        caption_output_dir.mkdir(exist_ok=True)
+
+        def extract_caption_and_description(trajectory: Dict[str, Any]) -> Dict[str, Any]:
+            """Extract VLM caption and ground truth description from trajectory."""
+            import json
+            from pathlib import Path
+            import numpy as np
+            import cv2
 
-            This demonstrates:
-            1. Working with trajectory data structure
-            2. Using VLM tools during filtering
-            3. Checking both labels and visual analysis
-            """
-            # First check if we have a success label in the file path
             file_path = trajectory.get("__file_path__", "")
-            has_success_label = "success" in file_path.lower()
-            trajectory["metadata"] = None # TODO: for now, it has serialization error
+            traj_name = Path(file_path).stem
 
-            # For demonstration, we'll use VLM to analyze four frames stitched together
-            # This gives better context of the trajectory progression
+            # Parse metadata to get language description
+            ground_truth_description = ""
+            try:
+                metadata_data = trajectory.get("metadata", None)
+                if metadata_data is not None:
+                    # Handle case where metadata is stored as a numpy array/list from trajectory loading
+                    if isinstance(metadata_data, (list, np.ndarray)) and len(metadata_data) > 0:
+                        metadata_str = metadata_data[0]
+                    else:
+                        metadata_str = metadata_data
+
+                    # Parse the JSON string
+                    if metadata_str:
+                        metadata = json.loads(metadata_str)
+                        # Get language instruction from metadata
+                        # Use current_task as it contains the task description in DROID dataset
+                        ground_truth_description = metadata.get("current_task", "")
+
+                        # If current_task is not available, try language_instruction fields
+                        if not ground_truth_description:
+                            ground_truth_description = (
+                                metadata.get("language_instruction", "") or
+                                metadata.get("language_instruction_2", "") or
+                                metadata.get("language_instruction_3", "")
+                            )
+            except Exception as e:
+                print(f"Error parsing metadata for {traj_name}: {e}")
+                import traceback
+                traceback.print_exc()
+
+
+            # Get VLM caption
+            vlm_caption = ""
             try:
                 # Find camera keys
                 camera_keys = [k for k in trajectory.keys()
                                if "observation/images/" in k or "image" in k.lower()]
 
                 if camera_keys:
-                    # Get the primary camera (usually the second one in DROID)
                     primary_camera = camera_keys[3] if len(camera_keys) > 1 else camera_keys[0]
-
-                    # Get four frames evenly spaced throughout the trajectory
                     frames = trajectory.get(primary_camera, [])
-                    if len(frames) >= 4:
-                        # Select 4 frames: start, 1/3, 2/3, and end
-                        indices = [0, len(frames)//3, 2*len(frames)//3, len(frames)-1]
+
+                    if len(frames) >= 8:
+                        # Extract frames evenly distributed throughout the trajectory
+                        num_frames = 6 # Extract 6 frames for captioning
+                        indices = np.linspace(0, len(frames)-1, num_frames, dtype=int)
                         selected_frames = [frames[i] for i in indices]
 
-                        # Use OpenCV to stitch frames together in a 2x2 grid
-                        import cv2
+                        # Create 2x3 grid for better trajectory understanding
+                        # Use original frame sizes without resizing
 
-                        # Ensure all frames are the same size
-                        h, w = selected_frames[0].shape[:2]
-                        resized_frames = []
-                        for frame in selected_frames:
-                            if frame.shape[:2] != (h, w):
-                                frame = cv2.resize(frame, (w, h))
-                            resized_frames.append(frame)
-
-                        # Create 2x2 grid
-                        top_row = np.hstack([resized_frames[0], resized_frames[1]])
-                        bottom_row = np.hstack([resized_frames[2], resized_frames[3]])
+                        # Create 2x3 grid
+                        top_row = np.hstack(selected_frames[:3])
+                        bottom_row = np.hstack(selected_frames[3:])
                         stitched_frame = np.vstack([top_row, bottom_row])
 
-                    elif len(frames) > 0:
-                        # If fewer than 4 frames, just use the last frame
-                        stitched_frame = frames[-1]
-
-                    # IMPORTANT: Create VLM service locally to avoid serialization issues
-                    # Don't capture external tools in the closure as they contain non-serializable objects
+                        # Save input image
+                        image_filename = caption_output_dir / f"{traj_name}_caption_input.jpg"
+                        cv2.imwrite(str(image_filename), cv2.cvtColor(stitched_frame, cv2.COLOR_RGB2BGR))
+
+                        # Use VLM to generate caption
+                        from robodm.agent.vlm_service import get_vlm_service
+                        vlm_service = get_vlm_service()
+                        vlm_service.initialize()
+
+                        vlm_prompt = (
+                            "These are 6 frames from a robot trajectory shown in temporal order "
+                            "(left to right, top to bottom). Please describe with one sentence what task the robot "
+                            "is performing in this trajectory. Be concise and specific about the "
+                            "actions and objects involved."
+                        )
+                        vlm_caption = vlm_service.analyze_image(stitched_frame, vlm_prompt)
+
+                        print(f"📝 Captioning {traj_name}")
+                        print(f"   GT: '{ground_truth_description}...'")
+                        print(f"   VLM: '{vlm_caption}...'")
+
+                    else:
+                        print(f"⚠️ Trajectory {traj_name} has only {len(frames)} frames, skipping captioning")
+
+            except Exception as e:
+                print(f"Error generating VLM caption for {traj_name}: {e}")
+                import traceback
+                traceback.print_exc()
+
+            # Use LLM to compare descriptions semantically
+            is_match = False
+            comparison_explanation = ""
+
+            if ground_truth_description and vlm_caption:
+                try:
                     from robodm.agent.vlm_service import get_vlm_service
                     vlm_service = get_vlm_service()
-                    # vlm_service.initialize()
-
-                    # Import Path for local use
-                    from pathlib import Path
-                    import cv2
-
-                    # Create output directory for VLM inputs/outputs
-                    vlm_output_dir = Path("./vlm_analysis_results")
-                    vlm_output_dir.mkdir(exist_ok=True)
 
-                    # Create unique filename based on trajectory name
-                    traj_name = Path(file_path).stem
-                    image_filename = vlm_output_dir / f"{traj_name}_input.jpg"
-                    text_filename = vlm_output_dir / f"{traj_name}_output.txt"
-
-                    # Save the stitched frame (VLM input)
-                    cv2.imwrite(str(image_filename), cv2.cvtColor(stitched_frame, cv2.COLOR_RGB2BGR))
-
-                    # Use VLM to check for success indicators on the stitched frames
-                    vlm_prompt = "These are 4 frames from the trajectory (start, 1/3, 2/3, end). Anwser the question: Does this trajectory look successful in completing the task? Answer yes or no."
-                    vlm_response = vlm_service.analyze_image(stitched_frame, vlm_prompt)
-
-                    # Save the VLM response (VLM output) with additional metadata
-                    with open(text_filename, 'w') as f:
-                        f.write(f"Trajectory: {traj_name}\n")
-                        f.write(f"File path: {file_path}\n")
-                        f.write(f"Has success label: {has_success_label}\n")
-                        f.write(f"Input image saved as: {image_filename.name}\n")
-                        f.write(f"\nVLM Prompt:\n{vlm_prompt}\n")
-                        f.write(f"\nVLM Response:\n{vlm_response}\n")
-
-                    print(f"💾 Saved VLM analysis for {traj_name}:")
-                    print(f"   Input image: {image_filename}")
-                    print(f"   Output text: {text_filename}")
-                    print(vlm_response)
-
-                    # Check if VLM thinks it's successful
-                    vlm_success = "yes" in vlm_response.lower()
+                    comparison_prompt = f"""Compare these two robot task descriptions and determine if they describe the same task:
+
+Description 1 (Ground Truth): {ground_truth_description}
+
+Description 2 (VLM Caption): {vlm_caption}
+
+Respond with only YES or NO followed by a brief explanation.
+
+Format:
+YES/NO: Your explanation here"""
+
+                    comparison_response = vlm_service.generate_code(comparison_prompt)
 
-                    # Combine label and VLM analysis
-                    # For demo, we'll trust the label but log VLM disagreements
-                    if has_success_label != vlm_success:
-                        print(f"❌ Label and VLM disagree for {Path(file_path).name}: "
-                              f"label={has_success_label}, vlm={vlm_success}")
+                    # Parse the response
+                    response_lower = comparison_response.strip().lower()
+                    if response_lower.startswith("yes"):
+                        is_match = True
+                        comparison_explanation = comparison_response[3:].strip(": ")
+                    elif response_lower.startswith("no"):
+                        is_match = False
+                        comparison_explanation = comparison_response[2:].strip(": ")
                     else:
-                        print(f"✅ Label and VLM agree for {Path(file_path).name}: "
-                              f"label={has_success_label}, vlm={vlm_success}")
+                        # Try to find YES or NO in the response
+                        is_match = "yes" in response_lower.split()[0:3]
+                        comparison_explanation = comparison_response
+
+                    print(f"   Match: {'YES' if is_match else 'NO'}")
 
-                    return has_success_label
+                except Exception as e:
+                    print(f"Error comparing descriptions: {e}")
+                    comparison_explanation = f"Error: {str(e)}"
+
+            # Save results
+            results_filename = caption_output_dir / f"{traj_name}_caption_results.txt"
+            with open(results_filename, 'w') as f:
+                f.write(f"Trajectory Captioning Results\n")
+                f.write(f"============================\n")
+                f.write(f"Trajectory: {traj_name}\n")
+                f.write(f"File path: {file_path}\n")
+                f.write(f"\nGround Truth Description:\n{ground_truth_description}\n")
+                f.write(f"\nVLM Generated Caption:\n{vlm_caption}\n")
+                f.write(f"\nSemantic Comparison:\n")
+                f.write(f"Match: {'YES' if is_match else 'NO'}\n")
+                f.write(f"Explanation: {comparison_explanation}\n")
+                f.write(f"\nInput image saved as: {traj_name}_caption_input.jpg\n")
+
+            return {
+                "trajectory_name": traj_name,
+                "ground_truth_description": ground_truth_description,
+                "vlm_caption": vlm_caption,
+                "has_ground_truth": bool(ground_truth_description),
+                "has_caption": bool(vlm_caption),
+                "is_match": is_match,
+                "comparison_explanation": comparison_explanation
+            }
+
+        # Apply transformation to get all captions
+        results_dataset = dataset.map(extract_caption_and_description).materialize()
+        results = list(results_dataset.iter_rows())
+
+        # Calculate F1 score based on LLM matching
+        true_positives = 0 # VLM correctly identifies matching tasks
+        false_positives = 0 # VLM incorrectly claims match
+        false_negatives = 0 # VLM misses a match
+        true_negatives = 0 # VLM correctly identifies non-match (not applicable here)
+
+        valid_comparisons = 0
+
+        print("\nDetailed Caption Comparison Results:")
+        print("-" * 80)
+
+        for result in results:
+            if result["has_ground_truth"] and result["has_caption"]:
+                valid_comparisons += 1
 
-            except Exception as e:
-                print(f"Error in VLM analysis: {e}")
-                # Fall back to label-based detection
+                # Get the match result
+                predicted_match = result["is_match"]
+
+                # In this context, we assume ground truth is that captions SHOULD match
+                # (since VLM is describing the same trajectory)
+                actual_match = True
+
+                if predicted_match and actual_match:
+                    true_positives += 1
+                elif not predicted_match and actual_match:
+                    false_negatives += 1
+
+                status = "✅" if predicted_match else "❌"
+                print(f"{status} {result['trajectory_name']}: {'MATCH' if predicted_match else 'NO MATCH'}")
+                print(f"   Explanation: {result['comparison_explanation']}")
+                print()
+
+        # Calculate metrics
+        if valid_comparisons > 0:
+            # Precision: Of all predicted matches, how many were correct?
+            precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0
+
+            # Recall: Of all actual matches, how many did we find?
+            recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0
 
-                return has_success_label
+            # F1 Score
+            f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
+        else:
+            precision = recall = f1_score = 0
+            print("⚠️ No valid comparisons found (missing ground truth or captions)")
+
+        print(f"\nOverall Captioning Metrics:")
+        print(f"Valid comparisons: {valid_comparisons}/{len(results)}")
+        print(f"Matches (True Positives): {true_positives}")
+        print(f"No Matches (False Negatives): {false_negatives}")
+        print(f"Precision: {precision:.3f}")
+        print(f"Recall: {recall:.3f}")
+        print(f"F1 Score: {f1_score:.3f}")
+
+        # Summary of results
+        summary_filename = caption_output_dir / "captioning_f1_summary.txt"
+        with open(summary_filename, 'w') as f:
+            f.write(f"Trajectory Captioning F1 Summary\n")
+            f.write(f"================================\n")
+            f.write(f"Total trajectories: {len(results)}\n")
+            f.write(f"Valid comparisons: {valid_comparisons}\n")
+            f.write(f"Matches (True Positives): {true_positives}\n")
+            f.write(f"No Matches (False Negatives): {false_negatives}\n")
+            f.write(f"Precision: {precision:.3f}\n")
+            f.write(f"Recall: {recall:.3f}\n")
+            f.write(f"F1 Score: {f1_score:.3f}\n")
 
-        return filter_successful_trajectories
+        print(f"\n✅ Results saved to {caption_output_dir}/")
+
+        return f1_score
 
     def calculate_f1_matrix(self, dataset: VLADataset):
         """
@@ -460,9 +590,14 @@ def main():
     detector = DROIDSuccessDetector(max_trajectories=max_trajectories)
     dataset = detector.create_robodm_dataset(robodm_dir)
 
-    # Step 5: Calculate F1 Matrix
-    print("\n5. Calculating F1 Matrix...")
-    detector.calculate_f1_matrix(dataset)
+    # # Step 5: Calculate F1 Matrix
+    # print("\n5. Calculating F1 Matrix...")
+    # detector.calculate_f1_matrix(dataset)
+
+    # Step 6: Calculate Trajectory Captioning F1
+    print("\n6. Calculating Trajectory Captioning F1...")
+    captioning_f1 = detector.calculate_trajectory_captioning_f1(dataset)
+    print(f"\nFinal Trajectory Captioning F1 Score: {captioning_f1:.3f}")
 
     # Cleanup Ray
     if ray.is_initialized():
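A note on the metric computed in step 6: because extract_caption_and_description always compares a caption against the ground truth for the same trajectory, actual_match is hard-coded to True and false_positives can never increment. Precision is therefore 1.0 whenever at least one match is found, and the reported F1 reduces to 2m / (1 + m), where m is the fraction of valid comparisons the LLM judged a match. In other words, the score is a monotone transform of caption-match accuracy rather than a genuine two-class F1. A sketch of the arithmetic (function name ours):

    def captioning_f1(matches: int, valid: int) -> float:
        # Mirrors the commit's computation: FP is always 0, so
        # precision == 1.0 when matches > 0 and F1 == 2m / (1 + m).
        if valid == 0 or matches == 0:
            return 0.0
        precision = 1.0
        recall = matches / valid
        return 2 * precision * recall / (precision + recall)

    # 7 of 10 captions judged matching -> recall 0.7, F1 ~= 0.824.
    assert abs(captioning_f1(7, 10) - 2 * 0.7 / 1.7) < 1e-9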

robodm/dataset.py

Lines changed: 0 additions & 1 deletion
@@ -127,7 +127,6 @@ def _load_trajectory(self, item) -> Dict[str, Any]:
             data = traj.load(return_type=self.return_type)
             # Add file path metadata for tracking
             data["__file_path__"] = str(file_path)
-            data["metadata"] = None
 
             return data
         except Exception as e:
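This one-line removal is what makes the new captioning path work: _load_trajectory no longer stomps the loaded metadata with None (and the demo's old filter no longer sets trajectory["metadata"] = None either), so the JSON metadata stored in each trajectory survives loading and extract_caption_and_description can recover the task text. A sketch of that recovery under the same assumptions the demo makes (helper name ours; the loader may hand back the JSON string wrapped in a list or ndarray):

    import json
    import numpy as np

    def parse_task_description(metadata_field):
        # Unwrap list/ndarray packaging, then read current_task, falling
        # back to language_instruction, as the demo does.
        if metadata_field is None:
            return ""
        if isinstance(metadata_field, (list, np.ndarray)) and len(metadata_field) > 0:
            metadata_field = metadata_field[0]
        if not metadata_field:
            return ""
        metadata = json.loads(metadata_field)
        return (metadata.get("current_task", "")
                or metadata.get("language_instruction", ""))

    assert parse_task_description(
        [json.dumps({"current_task": "pick up the mug"})]
    ) == "pick up the mug"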
