|
9 | 9 | 5. Shows how VLM tools can be used during filtering |
10 | 10 | """ |
11 | 11 |
|
12 | | -# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000 |
| 12 | +# python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000 --tp 8 |
13 | 13 |
|
14 | 14 | import os |
15 | 15 | import time |
@@ -141,129 +141,259 @@ def create_robodm_dataset(self, robodm_dir: str) -> VLADataset: |
141 | 141 |
|
142 | 142 | return dataset |
143 | 143 |
|
144 | | - def create_success_filter_function(self) -> callable: |
| 144 | + def calculate_trajectory_captioning_f1(self, dataset: VLADataset): |
145 | 145 | """ |
146 | | - Create a simple filter function for successful trajectories. |
147 | | - |
148 | | - For now, we bypass the planner and write the function directly. |
149 | | - This function can use VLM tools during execution. |
| 146 | + Calculate F1 score for trajectory captioning by comparing VLM-generated captions |
| 147 | +        with ground truth language descriptions from metadata, using an LLM for semantic matching.
150 | 148 | |
| 149 | + Args: |
| 150 | + dataset: VLADataset with loaded trajectories |
| 151 | + |
151 | 152 | Returns: |
152 | | - Filter function that identifies successful trajectories |
| 153 | + float: F1 score for caption similarity |
153 | 154 | """ |
154 | | - def filter_successful_trajectories(trajectory: Dict[str, Any]) -> bool: |
155 | | - """ |
156 | | - Filter function to identify successful trajectories. |
| 155 | + print("\n" + "=" * 60) |
| 156 | + print("TRAJECTORY CAPTIONING F1 CALCULATION") |
| 157 | + print("=" * 60) |
| 158 | + |
| 159 | + # Create output directory for captioning results |
| 160 | + caption_output_dir = Path("./trajectory_captioning_results") |
| 161 | + caption_output_dir.mkdir(exist_ok=True) |
| 162 | + |
| 163 | + def extract_caption_and_description(trajectory: Dict[str, Any]) -> Dict[str, Any]: |
| 164 | + """Extract VLM caption and ground truth description from trajectory.""" |
| 165 | + import json |
| 166 | + from pathlib import Path |
| 167 | + import numpy as np |
| 168 | + import cv2 |
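| | +            # Note: imports and the VLM service are created inside this function (not captured
| | +            # from the enclosing scope) so the closure stays serializable when dataset.map
| | +            # ships it to Ray workers.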
157 | 169 |
|
158 | | - This demonstrates: |
159 | | - 1. Working with trajectory data structure |
160 | | - 2. Using VLM tools during filtering |
161 | | - 3. Checking both labels and visual analysis |
162 | | - """ |
163 | | - # First check if we have a success label in the file path |
164 | 170 | file_path = trajectory.get("__file_path__", "") |
165 | | - has_success_label = "success" in file_path.lower() |
166 | | - trajectory["metadata"] = None # TODO: for now, it has serialization error |
| 171 | + traj_name = Path(file_path).stem |
167 | 172 |
|
168 | | - # For demonstration, we'll use VLM to analyze four frames stitched together |
169 | | - # This gives better context of the trajectory progression |
| 173 | + # Parse metadata to get language description |
| 174 | + ground_truth_description = "" |
| 175 | + try: |
| 176 | + metadata_data = trajectory.get("metadata", None) |
| 177 | + if metadata_data is not None: |
| 178 | + # Handle case where metadata is stored as a numpy array/list from trajectory loading |
| 179 | + if isinstance(metadata_data, (list, np.ndarray)) and len(metadata_data) > 0: |
| 180 | + metadata_str = metadata_data[0] |
| 181 | + else: |
| 182 | + metadata_str = metadata_data |
| 183 | + |
| 184 | + # Parse the JSON string |
| 185 | + if metadata_str: |
| 186 | + metadata = json.loads(metadata_str) |
| 187 | + # Get language instruction from metadata |
| 188 | +                    # Use current_task as it contains the task description in the DROID dataset
| 189 | + ground_truth_description = metadata.get("current_task", "") |
| 190 | + |
| 191 | + # If current_task is not available, try language_instruction fields |
| 192 | + if not ground_truth_description: |
| 193 | + ground_truth_description = ( |
| 194 | + metadata.get("language_instruction", "") or |
| 195 | + metadata.get("language_instruction_2", "") or |
| 196 | + metadata.get("language_instruction_3", "") |
| 197 | + ) |
| 198 | + except Exception as e: |
| 199 | + print(f"Error parsing metadata for {traj_name}: {e}") |
| 200 | + import traceback |
| 201 | + traceback.print_exc() |
| 202 | + |
| 203 | + |
| 204 | + # Get VLM caption |
| 205 | + vlm_caption = "" |
170 | 206 | try: |
171 | 207 | # Find camera keys |
172 | 208 | camera_keys = [k for k in trajectory.keys() |
173 | 209 | if "observation/images/" in k or "image" in k.lower()] |
174 | 210 |
|
175 | 211 | if camera_keys: |
176 | | - # Get the primary camera (usually the second one in DROID) |
177 | 212 |             primary_camera = camera_keys[3] if len(camera_keys) > 3 else camera_keys[0]
178 | | - |
179 | | - # Get four frames evenly spaced throughout the trajectory |
180 | 213 | frames = trajectory.get(primary_camera, []) |
181 | | - if len(frames) >= 4: |
182 | | - # Select 4 frames: start, 1/3, 2/3, and end |
183 | | - indices = [0, len(frames)//3, 2*len(frames)//3, len(frames)-1] |
| 214 | + |
| 215 | + if len(frames) >= 8: |
| 216 | + # Extract frames evenly distributed throughout the trajectory |
| 217 | + num_frames = 6 # Extract 6 frames for captioning |
| 218 | + indices = np.linspace(0, len(frames)-1, num_frames, dtype=int) |
184 | 219 | selected_frames = [frames[i] for i in indices] |
185 | 220 |
|
186 | | - # Use OpenCV to stitch frames together in a 2x2 grid |
187 | | - import cv2 |
| 221 | + # Create 2x3 grid for better trajectory understanding |
| 222 | + # Use original frame sizes without resizing |
188 | 223 |
|
189 | | - # Ensure all frames are the same size |
190 | | - h, w = selected_frames[0].shape[:2] |
191 | | - resized_frames = [] |
192 | | - for frame in selected_frames: |
193 | | - if frame.shape[:2] != (h, w): |
194 | | - frame = cv2.resize(frame, (w, h)) |
195 | | - resized_frames.append(frame) |
196 | | - |
197 | | - # Create 2x2 grid |
198 | | - top_row = np.hstack([resized_frames[0], resized_frames[1]]) |
199 | | - bottom_row = np.hstack([resized_frames[2], resized_frames[3]]) |
| 224 | + # Create 2x3 grid |
| 225 | + top_row = np.hstack(selected_frames[:3]) |
| 226 | + bottom_row = np.hstack(selected_frames[3:]) |
200 | 227 | stitched_frame = np.vstack([top_row, bottom_row]) |
201 | 228 |
|
202 | | - elif len(frames) > 0: |
203 | | - # If fewer than 4 frames, just use the last frame |
204 | | - stitched_frame = frames[-1] |
205 | | - |
206 | | - # IMPORTANT: Create VLM service locally to avoid serialization issues |
207 | | - # Don't capture external tools in the closure as they contain non-serializable objects |
| 229 | + # Save input image |
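| | +                    # (cv2.imwrite expects BGR, so convert the RGB stitched frame before saving)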
| 230 | + image_filename = caption_output_dir / f"{traj_name}_caption_input.jpg" |
| 231 | + cv2.imwrite(str(image_filename), cv2.cvtColor(stitched_frame, cv2.COLOR_RGB2BGR)) |
| 232 | + |
| 233 | + # Use VLM to generate caption |
| 234 | + from robodm.agent.vlm_service import get_vlm_service |
| 235 | + vlm_service = get_vlm_service() |
| 236 | + vlm_service.initialize() |
| 237 | + |
| 238 | + vlm_prompt = ( |
| 239 | + "These are 6 frames from a robot trajectory shown in temporal order " |
| 240 | +                        "(left to right, top to bottom). Please describe in one sentence what task the robot "
| 241 | + "is performing in this trajectory. Be concise and specific about the " |
| 242 | + "actions and objects involved." |
| 243 | + ) |
| 244 | + vlm_caption = vlm_service.analyze_image(stitched_frame, vlm_prompt) |
| 245 | + |
| 246 | + print(f"📝 Captioning {traj_name}") |
| 247 | +                    print(f"   GT: '{ground_truth_description}'")
| 248 | +                    print(f"   VLM: '{vlm_caption}'")
| 249 | + |
| 250 | + else: |
| 251 | + print(f"⚠️ Trajectory {traj_name} has only {len(frames)} frames, skipping captioning") |
| 252 | + |
| 253 | + except Exception as e: |
| 254 | + print(f"Error generating VLM caption for {traj_name}: {e}") |
| 255 | + import traceback |
| 256 | + traceback.print_exc() |
| 257 | + |
| 258 | + # Use LLM to compare descriptions semantically |
| 259 | + is_match = False |
| 260 | + comparison_explanation = "" |
| 261 | + |
| 262 | + if ground_truth_description and vlm_caption: |
| 263 | + try: |
208 | 264 | from robodm.agent.vlm_service import get_vlm_service |
209 | 265 | vlm_service = get_vlm_service() |
210 | | - # vlm_service.initialize() |
211 | | - |
212 | | - # Import Path for local use |
213 | | - from pathlib import Path |
214 | | - import cv2 |
215 | | - |
216 | | - # Create output directory for VLM inputs/outputs |
217 | | - vlm_output_dir = Path("./vlm_analysis_results") |
218 | | - vlm_output_dir.mkdir(exist_ok=True) |
219 | 266 |
|
220 | | - # Create unique filename based on trajectory name |
221 | | - traj_name = Path(file_path).stem |
222 | | - image_filename = vlm_output_dir / f"{traj_name}_input.jpg" |
223 | | - text_filename = vlm_output_dir / f"{traj_name}_output.txt" |
224 | | - |
225 | | - # Save the stitched frame (VLM input) |
226 | | - cv2.imwrite(str(image_filename), cv2.cvtColor(stitched_frame, cv2.COLOR_RGB2BGR)) |
227 | | - |
228 | | - # Use VLM to check for success indicators on the stitched frames |
229 | | - vlm_prompt = "These are 4 frames from the trajectory (start, 1/3, 2/3, end). Anwser the question: Does this trajectory look successful in completing the task? Answer yes or no." |
230 | | - vlm_response = vlm_service.analyze_image(stitched_frame, vlm_prompt) |
231 | | - |
232 | | - # Save the VLM response (VLM output) with additional metadata |
233 | | - with open(text_filename, 'w') as f: |
234 | | - f.write(f"Trajectory: {traj_name}\n") |
235 | | - f.write(f"File path: {file_path}\n") |
236 | | - f.write(f"Has success label: {has_success_label}\n") |
237 | | - f.write(f"Input image saved as: {image_filename.name}\n") |
238 | | - f.write(f"\nVLM Prompt:\n{vlm_prompt}\n") |
239 | | - f.write(f"\nVLM Response:\n{vlm_response}\n") |
240 | | - |
241 | | - print(f"💾 Saved VLM analysis for {traj_name}:") |
242 | | - print(f" Input image: {image_filename}") |
243 | | - print(f" Output text: {text_filename}") |
244 | | - print(vlm_response) |
245 | | - |
246 | | - # Check if VLM thinks it's successful |
247 | | - vlm_success = "yes" in vlm_response.lower() |
| 267 | + comparison_prompt = f"""Compare these two robot task descriptions and determine if they describe the same task: |
| 268 | +
|
| 269 | +Description 1 (Ground Truth): {ground_truth_description} |
| 270 | +
|
| 271 | +Description 2 (VLM Caption): {vlm_caption} |
| 272 | +
|
| 273 | +Respond with only YES or NO followed by a brief explanation. |
| 274 | +
|
| 275 | +Format: |
| 276 | +YES/NO: Your explanation here""" |
| 277 | + |
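| | +                    # Ask the LLM for a YES/NO judgement on whether the two descriptions
| | +                    # refer to the same task.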
| 278 | + comparison_response = vlm_service.generate_code(comparison_prompt) |
248 | 279 |
|
249 | | - # Combine label and VLM analysis |
250 | | - # For demo, we'll trust the label but log VLM disagreements |
251 | | - if has_success_label != vlm_success: |
252 | | - print(f"❌ Label and VLM disagree for {Path(file_path).name}: " |
253 | | - f"label={has_success_label}, vlm={vlm_success}") |
| 280 | + # Parse the response |
| 281 | + response_lower = comparison_response.strip().lower() |
| 282 | + if response_lower.startswith("yes"): |
| 283 | + is_match = True |
| 284 | + comparison_explanation = comparison_response[3:].strip(": ") |
| 285 | + elif response_lower.startswith("no"): |
| 286 | + is_match = False |
| 287 | + comparison_explanation = comparison_response[2:].strip(": ") |
254 | 288 | else: |
255 | | - print(f"✅ Label and VLM agree for {Path(file_path).name}: " |
256 | | - f"label={has_success_label}, vlm={vlm_success}") |
| 289 | + # Try to find YES or NO in the response |
| 290 | +                        is_match = any(tok.strip(".,:") == "yes" for tok in response_lower.split()[:3])
| 291 | + comparison_explanation = comparison_response |
| 292 | + |
| 293 | + print(f" Match: {'YES' if is_match else 'NO'}") |
257 | 294 |
|
258 | | - return has_success_label |
| 295 | + except Exception as e: |
| 296 | + print(f"Error comparing descriptions: {e}") |
| 297 | + comparison_explanation = f"Error: {str(e)}" |
| 298 | + |
| 299 | + # Save results |
| 300 | + results_filename = caption_output_dir / f"{traj_name}_caption_results.txt" |
| 301 | + with open(results_filename, 'w') as f: |
| 302 | + f.write(f"Trajectory Captioning Results\n") |
| 303 | + f.write(f"============================\n") |
| 304 | + f.write(f"Trajectory: {traj_name}\n") |
| 305 | + f.write(f"File path: {file_path}\n") |
| 306 | + f.write(f"\nGround Truth Description:\n{ground_truth_description}\n") |
| 307 | + f.write(f"\nVLM Generated Caption:\n{vlm_caption}\n") |
| 308 | + f.write(f"\nSemantic Comparison:\n") |
| 309 | + f.write(f"Match: {'YES' if is_match else 'NO'}\n") |
| 310 | + f.write(f"Explanation: {comparison_explanation}\n") |
| 311 | + f.write(f"\nInput image saved as: {traj_name}_caption_input.jpg\n") |
| 312 | + |
| 313 | + return { |
| 314 | + "trajectory_name": traj_name, |
| 315 | + "ground_truth_description": ground_truth_description, |
| 316 | + "vlm_caption": vlm_caption, |
| 317 | + "has_ground_truth": bool(ground_truth_description), |
| 318 | + "has_caption": bool(vlm_caption), |
| 319 | + "is_match": is_match, |
| 320 | + "comparison_explanation": comparison_explanation |
| 321 | + } |
| 322 | + |
| 323 | + # Apply transformation to get all captions |
| 324 | + results_dataset = dataset.map(extract_caption_and_description).materialize() |
| 325 | + results = list(results_dataset.iter_rows()) |
| 326 | + |
| 327 | + # Calculate F1 score based on LLM matching |
| 328 | + true_positives = 0 # VLM correctly identifies matching tasks |
| 329 | + false_positives = 0 # VLM incorrectly claims match |
| 330 | + false_negatives = 0 # VLM misses a match |
| 331 | + true_negatives = 0 # VLM correctly identifies non-match (not applicable here) |
| 332 | + |
| 333 | + valid_comparisons = 0 |
| 334 | + |
| 335 | + print("\nDetailed Caption Comparison Results:") |
| 336 | + print("-" * 80) |
| 337 | + |
| 338 | + for result in results: |
| 339 | + if result["has_ground_truth"] and result["has_caption"]: |
| 340 | + valid_comparisons += 1 |
259 | 341 |
|
260 | | - except Exception as e: |
261 | | - print(f"Error in VLM analysis: {e}") |
262 | | - # Fall back to label-based detection |
| 342 | + # Get the match result |
| 343 | + predicted_match = result["is_match"] |
| 344 | + |
| 345 | + # In this context, we assume ground truth is that captions SHOULD match |
| 346 | + # (since VLM is describing the same trajectory) |
| 347 | + actual_match = True |
| 348 | + |
| 349 | + if predicted_match and actual_match: |
| 350 | + true_positives += 1 |
| 351 | + elif not predicted_match and actual_match: |
| 352 | + false_negatives += 1 |
| 353 | + |
| 354 | + status = "✅" if predicted_match else "❌" |
| 355 | + print(f"{status} {result['trajectory_name']}: {'MATCH' if predicted_match else 'NO MATCH'}") |
| 356 | + print(f" Explanation: {result['comparison_explanation']}") |
| 357 | + print() |
| 358 | + |
| 359 | + # Calculate metrics |
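| | +    # Note: actual_match is fixed to True for every valid comparison, so false_positives
| | +    # stays 0; precision is 1.0 whenever any match is found and F1 reduces to 2*TP / (2*TP + FN).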
| 360 | + if valid_comparisons > 0: |
| 361 | + # Precision: Of all predicted matches, how many were correct? |
| 362 | + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 |
| 363 | + |
| 364 | + # Recall: Of all actual matches, how many did we find? |
| 365 | + recall = true_positives / (true_positives + false_negatives) if (true_positives + false_negatives) > 0 else 0 |
263 | 366 |
|
264 | | - return has_success_label |
| 367 | + # F1 Score |
| 368 | + f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 |
| 369 | + else: |
| 370 | + precision = recall = f1_score = 0 |
| 371 | + print("⚠️ No valid comparisons found (missing ground truth or captions)") |
| 372 | + |
| 373 | + print(f"\nOverall Captioning Metrics:") |
| 374 | + print(f"Valid comparisons: {valid_comparisons}/{len(results)}") |
| 375 | + print(f"Matches (True Positives): {true_positives}") |
| 376 | + print(f"No Matches (False Negatives): {false_negatives}") |
| 377 | + print(f"Precision: {precision:.3f}") |
| 378 | + print(f"Recall: {recall:.3f}") |
| 379 | + print(f"F1 Score: {f1_score:.3f}") |
| 380 | + |
| 381 | + # Summary of results |
| 382 | + summary_filename = caption_output_dir / "captioning_f1_summary.txt" |
| 383 | + with open(summary_filename, 'w') as f: |
| 384 | + f.write(f"Trajectory Captioning F1 Summary\n") |
| 385 | + f.write(f"================================\n") |
| 386 | + f.write(f"Total trajectories: {len(results)}\n") |
| 387 | + f.write(f"Valid comparisons: {valid_comparisons}\n") |
| 388 | + f.write(f"Matches (True Positives): {true_positives}\n") |
| 389 | + f.write(f"No Matches (False Negatives): {false_negatives}\n") |
| 390 | + f.write(f"Precision: {precision:.3f}\n") |
| 391 | + f.write(f"Recall: {recall:.3f}\n") |
| 392 | + f.write(f"F1 Score: {f1_score:.3f}\n") |
265 | 393 |
|
266 | | - return filter_successful_trajectories |
| 394 | + print(f"\n✅ Results saved to {caption_output_dir}/") |
| 395 | + |
| 396 | + return f1_score |
267 | 397 |
|
268 | 398 | def calculate_f1_matrix(self, dataset: VLADataset): |
269 | 399 | """ |
@@ -460,9 +590,14 @@ def main(): |
460 | 590 | detector = DROIDSuccessDetector(max_trajectories=max_trajectories) |
461 | 591 | dataset = detector.create_robodm_dataset(robodm_dir) |
462 | 592 |
|
463 | | - # Step 5: Calculate F1 Matrix |
464 | | - print("\n5. Calculating F1 Matrix...") |
465 | | - detector.calculate_f1_matrix(dataset) |
| 593 | + # # Step 5: Calculate F1 Matrix |
| 594 | + # print("\n5. Calculating F1 Matrix...") |
| 595 | + # detector.calculate_f1_matrix(dataset) |
| 596 | + |
| 597 | + # Step 6: Calculate Trajectory Captioning F1 |
| 598 | + print("\n6. Calculating Trajectory Captioning F1...") |
| 599 | + captioning_f1 = detector.calculate_trajectory_captioning_f1(dataset) |
| 600 | + print(f"\nFinal Trajectory Captioning F1 Score: {captioning_f1:.3f}") |
466 | 601 |
|
467 | 602 | # Cleanup Ray |
468 | 603 | if ray.is_initialized(): |
|