Commit e2539ef

Your Name committed
Refactor VLM processing to focus on MP4 input and remove unused parameters
- Removed the `image_key` and `language_key` parameters from the VLM processing functions, as they are not applicable for DROID directories with MP4 files.
- Updated the pipeline to handle multiple trials for VLM evaluations, including saving per-trial metrics and aggregate results.
- Simplified the example usage in `simple_vlm_processing.py` to reflect the new input format and removed the state visualization functionality.
- Enhanced the documentation to clarify the new input requirements and processing methods.
1 parent 83b235e commit e2539ef
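
For reference, the aggregation the new multi-trial loop performs is a plain mean plus an unbiased sample variance (ddof=1) over per-trial accuracies. A minimal standalone sketch of that computation follows; the accuracy values are made up for illustration:

import numpy as np

# Hypothetical per-trial accuracies for one configuration (three trials).
accuracies = [0.72, 0.68, 0.75]

if len(accuracies) > 1:
    mean_acc = float(np.mean(accuracies))
    # ddof=1 -> unbiased sample variance, matching the updated script.
    var_acc = float(np.var(accuracies, ddof=1))
else:
    # With a single trial (or none) there is no spread to estimate.
    mean_acc = float(accuracies[0]) if accuracies else 0.0
    var_acc = 0.0

print(f"mean={mean_acc:.3f}, var={var_acc:.6f}")  # mean=0.717, var=0.001233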

File tree

3 files changed: +150 −297 lines

examples/droid_h5/droid_pipeline.py

Lines changed: 0 additions & 2 deletions

@@ -548,8 +548,6 @@ def run_complete_pipeline(
     # Try to use the actual VLM processing with trajectory directories
     vlm_results = process_trajectories_parallel(
         trajectory_paths_for_vlm,
-        image_key="",  # Not used for DROID directories with video_path_key
-        language_key=language_key,
         question=question,
         max_workers=max_workers,
         output_dir=f"{output_dir}/vlm_detailed_results",
examples/droid_h5/evaluate_vlm_configs.py

Lines changed: 114 additions & 30 deletions

@@ -119,6 +119,7 @@ def main():
     parser.add_argument("--seed", type=int, help="Random seed")
     parser.add_argument("--max-workers", type=int, default=4, help="Parallel workers for VLM")
     parser.add_argument("--eval-root", default="./eval_runs", help="Root folder for evaluation outputs")
+    parser.add_argument("--num-trials", type=int, default=1, help="Number of trials per configuration")
 
     parser.add_argument("--frame-counts", type=int, nargs='+', default=[4, 8, 16, 32],
                         help="Frame counts to evaluate")
@@ -187,50 +188,133 @@ def main():
         run_out_dir = runs_root / run_name
         os.makedirs(run_out_dir, exist_ok=True)
 
-        print(f"\n🚀 Run: {run_name}")
-        results = process_trajectories_parallel(
-            trajectory_paths=successful_local_paths,
-            image_key="",  # not used for DROID directories when MP4s present
-            language_key=args.language_key,
-            question=args.question,
-            max_workers=args.max_workers,
-            output_dir=str(run_out_dir),
-            video_path_key=cam_key,
-            num_frames=n,
-            passing_method=method,
-            concat_grid_cols=None
-        )
-
-        # Persist raw results
-        with open(run_out_dir / "vlm_results.json", 'w') as f:
-            json.dump(results, f, indent=2)
-
-        total, predicted, correct, acc = compute_accuracy(results, gt_by_name)
-        print(f"📈 Accuracy: {acc:.3f} ({correct}/{predicted}) | total {total}")
-
-        # Save metrics per run
+        per_trial_metrics = []
+
+        for trial_idx in range(max(1, int(args.num_trials))):
+            trial_num = trial_idx + 1
+            trial_dir = run_out_dir / f"trial_{trial_num:02d}"
+            os.makedirs(trial_dir, exist_ok=True)
+
+            print(f"\n🚀 Run: {run_name} [trial {trial_num}/{args.num_trials}]")
+            results = process_trajectories_parallel(
+                trajectory_paths=successful_local_paths,
+                question=args.question,
+                max_workers=args.max_workers,
+                output_dir=str(trial_dir),
+                video_path_key=cam_key,
+                num_frames=n,
+                passing_method=method,
+                concat_grid_cols=None
+            )
+
+            # Persist raw results per trial
+            with open(trial_dir / "vlm_results.json", 'w') as f:
+                json.dump(results, f, indent=2)
+
+            total, predicted, correct, acc = compute_accuracy(results, gt_by_name)
+            print(f"📈 Trial {trial_num} accuracy: {acc:.3f} ({correct}/{predicted}) | total {total}")
+
+            # Save per-trial metrics
+            with open(trial_dir / "metrics.csv", 'w', newline='') as f:
+                writer = csv.writer(f)
+                writer.writerow(["method", "frames", "camera_key", "trial", "total", "predicted", "correct", "accuracy"])
+                writer.writerow([method, n, cam_key or "auto", trial_num, total, predicted, correct, f"{acc:.6f}"])
+
+            per_trial_metrics.append({
+                "trial": trial_num,
+                "total": total,
+                "predicted": predicted,
+                "correct": correct,
+                "accuracy": acc,
+                "run_dir": str(trial_dir)
+            })
+
+            # Also add to overall summary (per-trial row)
+            summary_rows.append({
+                "method": method,
+                "frames": n,
+                "camera_key": cam_key or "auto",
+                "trial": trial_num,
+                "total": total,
+                "predicted": predicted,
+                "correct": correct,
+                "accuracy": acc,
+                "is_aggregate": False,
+                "num_trials": int(args.num_trials),
+                "accuracy_mean": None,
+                "accuracy_variance": None,
+                "run_dir": str(trial_dir)
+            })
+
+        # Aggregate across trials
+        accuracies = [m["accuracy"] for m in per_trial_metrics]
+        if len(accuracies) > 1:
+            mean_acc = float(np.mean(accuracies))
+            var_acc = float(np.var(accuracies, ddof=1))
+        else:
+            mean_acc = float(accuracies[0]) if accuracies else 0.0
+            var_acc = 0.0
+
+        print(f"📊 Aggregate over {len(accuracies)} trial(s): mean={mean_acc:.3f}, var={var_acc:.6f}")
+
+        # Persist aggregate metrics JSON at config root
+        aggregate_payload = {
+            "method": method,
+            "frames": n,
+            "camera_key": cam_key or "auto",
+            "num_trials": int(args.num_trials),
+            "per_trial": per_trial_metrics,
+            "accuracy_mean": mean_acc,
+            "accuracy_variance": var_acc,
+        }
+        with open(run_out_dir / "aggregate_metrics.json", 'w') as f:
+            json.dump(aggregate_payload, f, indent=2)
+
+        # Write combined metrics (per-trial rows + aggregate row) at config root
         with open(run_out_dir / "metrics.csv", 'w', newline='') as f:
             writer = csv.writer(f)
-            writer.writerow(["method", "frames", "camera_key", "total", "predicted", "correct", "accuracy"])
-            writer.writerow([method, n, cam_key or "auto", total, predicted, correct, f"{acc:.6f}"])
+            writer.writerow(["method", "frames", "camera_key", "trial", "total", "predicted", "correct", "accuracy", "is_aggregate", "num_trials", "accuracy_mean", "accuracy_variance"])
+            for m in per_trial_metrics:
+                writer.writerow([method, n, cam_key or "auto", m["trial"], m["total"], m["predicted"], m["correct"], f"{m['accuracy']:.6f}", 0, int(args.num_trials), "", ""])
+            writer.writerow([method, n, cam_key or "auto", "all", "", "", "", f"{mean_acc:.6f}", 1, int(args.num_trials), f"{mean_acc:.6f}", f"{var_acc:.6f}"])
 
+        # Add aggregate row to overall summary
         summary_rows.append({
             "method": method,
             "frames": n,
             "camera_key": cam_key or "auto",
-            "total": total,
-            "predicted": predicted,
-            "correct": correct,
-            "accuracy": acc,
+            "trial": "all",
+            "total": None,
+            "predicted": None,
+            "correct": None,
+            "accuracy": mean_acc,
+            "is_aggregate": True,
+            "num_trials": int(args.num_trials),
+            "accuracy_mean": mean_acc,
+            "accuracy_variance": var_acc,
             "run_dir": str(run_out_dir)
         })
 
     # Write overall summary
     with open(eval_root / "summary.csv", 'w', newline='') as f:
         writer = csv.writer(f)
-        writer.writerow(["method", "frames", "camera_key", "total", "predicted", "correct", "accuracy", "run_dir"])
+        writer.writerow(["method", "frames", "camera_key", "trial", "total", "predicted", "correct", "accuracy", "is_aggregate", "num_trials", "accuracy_mean", "accuracy_variance", "run_dir"])
         for r in summary_rows:
-            writer.writerow([r["method"], r["frames"], r["camera_key"], r["total"], r["predicted"], r["correct"], f"{r['accuracy']:.6f}", r["run_dir"]])
+            writer.writerow([
+                r["method"],
+                r["frames"],
+                r["camera_key"],
+                r.get("trial", ""),
+                r.get("total", ""),
+                r.get("predicted", ""),
+                r.get("correct", ""),
+                f"{r['accuracy']:.6f}",
+                int(bool(r.get("is_aggregate", False))),
+                r.get("num_trials", ""),
+                f"{r['accuracy_mean']:.6f}" if r.get("accuracy_mean") is not None else "",
+                f"{r['accuracy_variance']:.6f}" if r.get("accuracy_variance") is not None else "",
+                r["run_dir"],
+            ])
 
     elapsed = time.time() - start_all
     print(f"\n🎉 Evaluation complete in {elapsed/60:.1f} minutes")