995. Shows how VLM tools can be used during filtering
1010"""
1111
12- # python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-7B-Instruct --host 0.0.0.0 --port 30000
12+ # python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000
1313
1414import os
1515import time
16+ import argparse
1617from pathlib import Path
17- from typing import Dict , List , Any
18+ from typing import Dict , List , Any , Optional
1819
1920import numpy as np
2021import cv2
3031class DROIDSuccessDetector :
3132 """Enhanced DROID success/failure detector using RoboDM Agent system."""
3233
33- def __init__ (self ):
34- """Initialize the detector with Agent capabilities."""
34+ def __init__ (self , max_trajectories : Optional [int ] = None ):
35+ """Initialize the detector with Agent capabilities.
36+
37+ Args:
38+ max_trajectories: Maximum number of trajectories to process. If None, processes all trajectories.
39+ """
3540 print ("Initializing RoboDM Agent with VLM tools..." )
3641
42+ self .max_trajectories = max_trajectories
43+ if max_trajectories is not None :
44+ print (f"Will limit processing to maximum { max_trajectories } trajectories" )
45+
3746 # Configure tools for the Agent
3847 self .tools_config = {
3948 "tools" : {
4049 "robo2vlm" : {
41- "model" : "Qwen/Qwen2.5-VL-7B -Instruct" ,
50+ "model" : "Qwen/Qwen2.5-VL-32B -Instruct" ,
4251 "temperature" : 0.1 ,
4352 "max_tokens" : 4096 ,
4453 "context_length" : 1024
@@ -85,7 +94,47 @@ def create_robodm_dataset(self, robodm_dir: str) -> VLADataset:
8594 config = config
8695 )
8796
88- print (f"Created VLADataset with { dataset .count ()} trajectory files" )
97+ total_trajectories = dataset .count ()
98+ print (f"Found { total_trajectories } trajectory files" )
99+
100+ # Apply max_trajectories limit if specified
101+ if self .max_trajectories is not None and total_trajectories > self .max_trajectories :
102+ print (f"Limiting to { self .max_trajectories } trajectories (out of { total_trajectories } total)" )
103+ # Use take() to limit the number of trajectories
104+ limited_items = dataset .take (self .max_trajectories )
105+
106+ # Create a new VLADataset from the limited items
107+ # We need to extract file paths from the limited items
108+ if limited_items :
109+ # Extract file paths from the limited items
110+ # The items are currently just string paths from the Ray dataset
111+ limited_file_paths = [item if isinstance (item , str ) else item .get ("item" , str (item ))
112+ for item in limited_items ]
113+
114+ # Create a new VLADataset with limited file paths
115+ import ray .data as rd
116+ limited_ray_dataset = rd .from_items (limited_file_paths )
117+ if config .shuffle :
118+ limited_ray_dataset = limited_ray_dataset .random_shuffle ()
119+
120+ # Create new VLADataset instance with limited data
121+ limited_dataset = VLADataset .__new__ (VLADataset )
122+ limited_dataset .path = dataset .path
123+ limited_dataset .return_type = dataset .return_type
124+ limited_dataset .config = dataset .config
125+ limited_dataset .file_paths = limited_file_paths
126+ limited_dataset .ray_dataset = limited_ray_dataset
127+ limited_dataset .metadata_manager = dataset .metadata_manager
128+ limited_dataset ._schema = None
129+ limited_dataset ._stats = None
130+ limited_dataset ._is_loaded = False
131+ limited_dataset ._has_file_paths = True
132+
133+ dataset = limited_dataset
134+ print (f"Limited dataset created with { dataset .count ()} trajectory files" )
135+ else :
136+ print (f"Processing all { total_trajectories } trajectory files" )
137+
89138 print (f"Dataset type: { type (dataset )} " )
90139 print (f"Has _is_loaded: { hasattr (dataset , '_is_loaded' )} " )
91140 print (f"Is loaded: { dataset ._is_loaded } " )
@@ -227,25 +276,35 @@ def calculate_f1_matrix(self, dataset: VLADataset):
227276 print ("F1 MATRIX CALCULATION" )
228277 print ("=" * 60 )
229278
279+ # Create output directory for F1 matrix results
280+ f1_output_dir = Path ("./f1_matrix_results" )
281+ f1_output_dir .mkdir (exist_ok = True )
282+
230283 # Transform to extract labels and predictions
231284 def extract_labels_and_predictions (trajectory : Dict [str , Any ]) -> Dict [str , Any ]:
232- """Extract ground truth and VLM predictions for F1 calculation."""
285+ """Extract ground truth and VLM predictions for F1 calculation with file saving ."""
233286 from pathlib import Path
234287 import numpy as np
288+ import cv2
235289
236290 file_path = trajectory .get ("__file_path__" , "" )
237291 ground_truth = "success" in file_path .lower ()
292+ traj_name = Path (file_path ).stem
238293
239- # Get VLM prediction (simplified version without saving files)
294+ # Get VLM prediction and save all results
240295 vlm_prediction = False
296+ vlm_response = "No VLM analysis performed"
297+
241298 try :
242299 # Find camera keys
243300 camera_keys = [k for k in trajectory .keys ()
244301 if "observation/images/" in k or "image" in k .lower ()]
302+ print (f"Camera keys: { camera_keys } " )
245303
246304 if camera_keys :
247305 primary_camera = camera_keys [3 ] if len (camera_keys ) > 1 else camera_keys [0 ]
248306 frames = trajectory .get (primary_camera , [])
307+ print (f"Frames: { len (frames )} , { frames [0 ].shape } " )
249308
250309 if len (frames ) >= 4 :
251310 # Select 4 frames: start, 1/3, 2/3, and end
@@ -257,32 +316,71 @@ def extract_labels_and_predictions(trajectory: Dict[str, Any]) -> Dict[str, Any]
257316 resized_frames = []
258317 for frame in selected_frames :
259318 if frame .shape [:2 ] != (h , w ):
260- import cv2
261319 frame = cv2 .resize (frame , (w , h ))
262320 resized_frames .append (frame )
263321
264322 top_row = np .hstack ([resized_frames [0 ], resized_frames [1 ]])
265323 bottom_row = np .hstack ([resized_frames [2 ], resized_frames [3 ]])
266324 stitched_frame = np .vstack ([top_row , bottom_row ])
267325
326+ # Save input image
327+ image_filename = f1_output_dir / f"{ traj_name } _input.jpg"
328+ cv2 .imwrite (str (image_filename ), cv2 .cvtColor (stitched_frame , cv2 .COLOR_RGB2BGR ))
329+
268330 # Use VLM to get prediction
269331 from robodm .agent .vlm_service import get_vlm_service
270332 vlm_service = get_vlm_service ()
271333 vlm_service .initialize ()
272334
273- vlm_prompt = "These are 4 frames from a robot trajectory. Does this trajectory look successful? Answer yes or no."
335+ vlm_prompt = "These are 4 frames from a robot trajectory. Does this trajectory look successful? First answer yes or no, then explain why ."
274336 vlm_response = vlm_service .analyze_image (stitched_frame , vlm_prompt )
275- print (vlm_response )
276337 vlm_prediction = "yes" in vlm_response .lower ()
277338
339+ print (f"🔍 F1 Analysis for { traj_name } : GT={ ground_truth } , VLM={ vlm_prediction } " )
340+
341+ elif len (frames ) > 0 :
342+ # If fewer than 4 frames, just use the last frame
343+ stitched_frame = frames [- 1 ]
344+
345+ # Save input image
346+ image_filename = f1_output_dir / f"{ traj_name } _input.jpg"
347+ cv2 .imwrite (str (image_filename ), cv2 .cvtColor (stitched_frame , cv2 .COLOR_RGB2BGR ))
348+
349+ # Use VLM to get prediction
350+ from robodm .agent .vlm_service import get_vlm_service
351+ vlm_service = get_vlm_service ()
352+ vlm_service .initialize ()
353+
354+ vlm_prompt = "This is the final frame from a robot trajectory. Does this trajectory look successful? Answer yes or no."
355+ vlm_response = vlm_service .analyze_image (stitched_frame , vlm_prompt )
356+ vlm_prediction = "yes" in vlm_response .lower ()
357+
358+ print (f"🔍 F1 Analysis for { traj_name } : GT={ ground_truth } , VLM={ vlm_prediction } " )
359+
278360 except Exception as e :
279- print (f"Error in VLM prediction: { e } " )
280- vlm_prediction = ground_truth # fallback to ground truth
361+ print (f"Error in VLM prediction for { traj_name } : { e } " )
362+ vlm_prediction = ground_truth
363+ vlm_response = f"Error occurred: { str (e )} "
364+
365+ # Save results to file
366+ results_filename = f1_output_dir / f"{ traj_name } _results.txt"
367+ with open (results_filename , 'w' ) as f :
368+ f .write (f"F1 Matrix Calculation Results\n " )
369+ f .write (f"=============================\n " )
370+ f .write (f"Trajectory: { traj_name } \n " )
371+ f .write (f"File path: { file_path } \n " )
372+ f .write (f"Ground truth (success): { ground_truth } \n " )
373+ f .write (f"VLM prediction (success): { vlm_prediction } \n " )
374+ f .write (f"Prediction correct: { ground_truth == vlm_prediction } \n " )
375+ f .write (f"\n VLM Prompt:\n { vlm_prompt if 'vlm_prompt' in locals () else 'No prompt used' } \n " )
376+ f .write (f"\n VLM Response:\n { vlm_response } \n " )
377+ f .write (f"\n Input image saved as: { traj_name } _input.jpg\n " )
281378
282379 return {
283- "trajectory_name" : Path ( file_path ). stem ,
380+ "trajectory_name" : traj_name ,
284381 "ground_truth" : ground_truth ,
285- "vlm_prediction" : vlm_prediction
382+ "vlm_prediction" : vlm_prediction ,
383+ "vlm_response" : vlm_response
286384 }
287385
288386 # Apply transformation to get all predictions using VLADataset's map
@@ -315,6 +413,12 @@ def extract_labels_and_predictions(trajectory: Dict[str, Any]) -> Dict[str, Any]
315413 f1_score = 2 * (precision * recall ) / (precision + recall ) if (precision + recall ) > 0 else 0
316414 accuracy = (true_positives + true_negatives ) / len (results )
317415
416+ print (f"\n Detailed Results:" )
417+ for result in results :
418+ status = "✅" if result ["ground_truth" ] == result ["vlm_prediction" ] else "❌"
419+ print (f"{ status } { result ['trajectory_name' ]} : GT={ result ['ground_truth' ]} , Pred={ result ['vlm_prediction' ]} " )
420+
421+
318422 # Print F1 Matrix
319423 print ("\n Confusion Matrix:" )
320424 print (" Predicted" )
@@ -328,10 +432,7 @@ def extract_labels_and_predictions(trajectory: Dict[str, Any]) -> Dict[str, Any]
328432 print (f"Recall: { recall :.3f} " )
329433 print (f"F1 Score: { f1_score :.3f} " )
330434
331- print (f"\n Detailed Results:" )
332- for result in results :
333- status = "✅" if result ["ground_truth" ] == result ["vlm_prediction" ] else "❌"
334- print (f"{ status } { result ['trajectory_name' ]} : GT={ result ['ground_truth' ]} , Pred={ result ['vlm_prediction' ]} " )
435+
335436
336437 return f1_score
337438
@@ -341,10 +442,22 @@ def main():
341442 print ("RoboDM VLADataset and Agent Demo" )
342443 print ("=" * 60 )
343444
344- robodm_dir = "./robodm_trajectories"
445+ # Configuration
446+ parser = argparse .ArgumentParser (description = "Run the DROID VLM demo" )
447+ parser .add_argument ("--data_dir" , type = str , default = "./robodm_trajectories" , help = "Directory containing RoboDM trajectory files" )
448+ parser .add_argument ("--max_trajectories" , type = int , default = 100 , help = "Maximum number of trajectories to process" )
449+ args = parser .parse_args ()
450+
451+ robodm_dir = args .data_dir
452+ max_trajectories = args .max_trajectories
453+
454+ print (f"Configuration:" )
455+ print (f" Data directory: { robodm_dir } " )
456+ print (f" Max trajectories: { max_trajectories if max_trajectories is not None else 'All' } " )
457+
345458 # Step 3: Create VLADataset (with file paths only)
346459 print ("\n 3. Creating VLADataset..." )
347- detector = DROIDSuccessDetector ()
460+ detector = DROIDSuccessDetector (max_trajectories = max_trajectories )
348461 dataset = detector .create_robodm_dataset (robodm_dir )
349462
350463 # Step 5: Calculate F1 Matrix
0 commit comments