Skip to content

Commit 978ffed

Browse files
author
Your Name
committed
minimal example
1 parent 5edfc1a commit 978ffed

File tree

3 files changed

+708
-0
lines changed

3 files changed

+708
-0
lines changed

examples/droid_robosq/README.md

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# DROID VLM Batch Processing
2+
3+
This folder contains examples for batch processing DROID robot trajectories with Vision Language Models (VLM).
4+
5+
## Files
6+
7+
### `droid_download_example.py`
8+
Downloads DROID trajectories from Google Cloud Storage with parallel processing.
9+
10+
**Usage:**
11+
```bash
12+
python droid_download_example.py --local-dir ./droid_data --num-trajectories 50
13+
```
14+
15+
**Features:**
16+
- Parallel downloads from GCS using gsutil
17+
- Handles nested DROID directory structure
18+
- Configurable number of trajectories and parallel workers
19+
20+
### `simple_droid_vlm_example.py`
21+
Batch processes DROID trajectories with VLM using configurable prompts and answer extraction.
22+
23+
**Prerequisites:**
24+
Start qwen VLM server:
25+
```bash
26+
python3 -m sglang.launch_server --model-path Qwen/Qwen2.5-VL-32B-Instruct --host 0.0.0.0 --port 30000 --tp 4
27+
```
28+
29+
**Usage Examples:**
30+
31+
Binary classification:
32+
```bash
33+
python simple_droid_vlm_example.py --data-dir ./droid_data --prompt "Is this trajectory successful?" --answer-type binary --output results.csv
34+
```
35+
36+
Multiple choice:
37+
```bash
38+
python simple_droid_vlm_example.py --data-dir ./droid_data --prompt "What type of task is this?" --answer-type multiple_choice --choices pick place push other --output task_analysis.csv
39+
```
40+
41+
Numerical scoring:
42+
```bash
43+
python simple_droid_vlm_example.py --data-dir ./droid_data --prompt "Rate the trajectory quality from 1-10" --answer-type number --output quality_scores.csv
44+
```
45+
46+
With reasoning:
47+
```bash
48+
python simple_droid_vlm_example.py --data-dir ./droid_data --prompt "Is this successful?" --answer-type binary --reasoning --output detailed_results.csv
49+
```
50+
51+
## Answer Types
52+
53+
- **`binary`**: Extracts yes/no responses
54+
- **`number`**: Extracts numerical values
55+
- **`multiple_choice`**: Selects from provided choices
56+
- **`text`**: Extracts free text (first sentence)
57+
58+
## Output Format
59+
60+
CSV with columns:
61+
- `trajectory_path`: Path to trajectory directory
62+
- `trajectory_name`: Trajectory identifier
63+
- `extracted_answer`: Parsed answer based on type
64+
- `original_answer`: Full VLM response
65+
- `error`: Error message if processing failed
66+
67+
## Quick Start
68+
69+
1. Download trajectories:
70+
```bash
71+
python droid_download_example.py --local-dir ./droid_data --num-trajectories 10
72+
```
73+
74+
2. Start VLM server (see prerequisites above)
75+
76+
3. Process with VLM:
77+
```bash
78+
python simple_droid_vlm_example.py --data-dir ./droid_data --prompt "Is this trajectory successful?" --answer-type binary --output results.csv
79+
```
80+
81+
## Features
82+
83+
- ✅ Real VLM integration with qwen/sglang
84+
- ✅ User-configurable prompts and answer extraction
85+
- ✅ Structured CSV output
86+
- ✅ Multiple answer type support
87+
- ✅ Parallel processing capability
88+
- ✅ Error handling and logging
Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#!/usr/bin/env python3
2+
"""
3+
DROID Trajectory Download Example
4+
5+
Concise example for downloading DROID trajectories to local storage with parallel processing.
6+
Downloads trajectories from GCS and converts them to robodm format for efficient processing.
7+
8+
Usage:
9+
python droid_download_example.py --gcs-pattern "gs://gresearch/robotics/droid_raw/1.0.1/*/success/*" --local-dir ./droid_data --num-trajectories 50
10+
"""
11+
12+
import argparse
13+
import logging
14+
import os
15+
import subprocess
16+
import tempfile
17+
from pathlib import Path
18+
from typing import List, Optional, Tuple
19+
20+
import ray
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
@ray.remote
26+
def download_single_trajectory(gcs_path: str, local_dir: str, temp_dir: str) -> Tuple[bool, Optional[str], str, str]:
27+
"""Download a single DROID trajectory from GCS to local directory."""
28+
try:
29+
# Extract meaningful name from nested structure: date/trajectory_name
30+
path_parts = gcs_path.rstrip("/").split("/")
31+
date_part = path_parts[-2] # e.g., "2023-07-07"
32+
traj_part = path_parts[-1] # e.g., "Fri_Jul__7_09:42:23_2023"
33+
trajectory_name = f"{date_part}_{traj_part}"
34+
local_trajectory_dir = Path(local_dir) / trajectory_name
35+
local_trajectory_dir.mkdir(parents=True, exist_ok=True)
36+
37+
# Use gsutil for efficient GCS download
38+
# Remove trailing slash and add /* for contents
39+
clean_gcs_path = gcs_path.rstrip("/")
40+
cmd = ["gsutil", "-m", "cp", "-r", f"{clean_gcs_path}/*", str(local_trajectory_dir)]
41+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
42+
43+
if result.returncode == 0:
44+
logger.info(f"Downloaded {trajectory_name}")
45+
return True, str(local_trajectory_dir), "", trajectory_name
46+
else:
47+
error_msg = f"gsutil failed: {result.stderr}"
48+
logger.error(f"Failed to download {trajectory_name}: {error_msg}")
49+
return False, None, error_msg, trajectory_name
50+
51+
except Exception as e:
52+
error_msg = f"Exception during download: {str(e)}"
53+
logger.error(f"Error downloading {gcs_path}: {error_msg}")
54+
return False, None, error_msg, trajectory_name
55+
56+
57+
def scan_droid_trajectories(gcs_pattern: str, max_trajectories: Optional[int] = None) -> List[str]:
58+
"""Scan for DROID trajectories matching the GCS pattern."""
59+
try:
60+
# First get date directories
61+
cmd = ["gsutil", "ls", "-d", gcs_pattern]
62+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=60)
63+
64+
if result.returncode != 0:
65+
raise RuntimeError(f"gsutil ls failed: {result.stderr}")
66+
67+
date_dirs = [line.strip() for line in result.stdout.strip().split('\n') if line.strip()]
68+
69+
# Now get actual trajectory directories from each date
70+
all_trajectories = []
71+
for date_dir in date_dirs:
72+
if max_trajectories and len(all_trajectories) >= max_trajectories:
73+
break
74+
75+
cmd = ["gsutil", "ls", "-d", f"{date_dir.rstrip('/')}/*"]
76+
result = subprocess.run(cmd, capture_output=True, text=True, timeout=30)
77+
78+
if result.returncode == 0:
79+
traj_dirs = [line.strip() for line in result.stdout.strip().split('\n') if line.strip()]
80+
all_trajectories.extend(traj_dirs)
81+
82+
if max_trajectories and len(all_trajectories) > max_trajectories:
83+
all_trajectories = all_trajectories[:max_trajectories]
84+
85+
return all_trajectories
86+
87+
except Exception as e:
88+
logger.error(f"Failed to scan trajectories: {e}")
89+
return []
90+
91+
92+
def download_droid_trajectories(
93+
gcs_pattern: str,
94+
local_dir: str,
95+
num_trajectories: Optional[int] = None,
96+
parallel_downloads: int = 8
97+
) -> Tuple[List[str], List[str]]:
98+
"""
99+
Download DROID trajectories from GCS to local directory.
100+
101+
Args:
102+
gcs_pattern: GCS pattern for trajectory paths (e.g., "gs://path/*/success/*")
103+
local_dir: Local directory to store downloaded trajectories
104+
num_trajectories: Maximum number of trajectories to download
105+
parallel_downloads: Number of parallel downloads
106+
107+
Returns:
108+
Tuple of (successful_paths, failed_paths)
109+
"""
110+
if not ray.is_initialized():
111+
ray.init()
112+
113+
# Create local directory
114+
Path(local_dir).mkdir(parents=True, exist_ok=True)
115+
116+
# Scan for trajectories
117+
print(f"Scanning for trajectories matching: {gcs_pattern}")
118+
trajectory_paths = scan_droid_trajectories(gcs_pattern, num_trajectories)
119+
120+
if not trajectory_paths:
121+
print("No trajectories found matching the pattern")
122+
return [], []
123+
124+
print(f"Found {len(trajectory_paths)} trajectories to download")
125+
126+
# Create temporary directory for downloads
127+
with tempfile.TemporaryDirectory() as temp_dir:
128+
# Start parallel downloads
129+
print(f"Starting {parallel_downloads} parallel downloads...")
130+
131+
download_futures = []
132+
for gcs_path in trajectory_paths:
133+
future = download_single_trajectory.remote(gcs_path, local_dir, temp_dir)
134+
download_futures.append((future, gcs_path))
135+
136+
# Process results as they complete
137+
successful_paths = []
138+
failed_paths = []
139+
140+
for future, gcs_path in download_futures:
141+
try:
142+
success, local_path, error_msg, traj_name = ray.get(future)
143+
if success:
144+
successful_paths.append(local_path)
145+
print(f"✅ {traj_name}")
146+
else:
147+
failed_paths.append(gcs_path)
148+
print(f"❌ {traj_name}: {error_msg}")
149+
except Exception as e:
150+
failed_paths.append(gcs_path)
151+
print(f"❌ {gcs_path}: {e}")
152+
153+
print(f"\nDownload complete: {len(successful_paths)} successful, {len(failed_paths)} failed")
154+
return successful_paths, failed_paths
155+
156+
157+
def main():
158+
"""Download DROID trajectories from GCS."""
159+
parser = argparse.ArgumentParser(description="Download DROID trajectories from GCS")
160+
parser.add_argument("--gcs-pattern", default = "gs://gresearch/robotics/droid_raw/1.0.1/*/success/*",
161+
help="GCS pattern for trajectory paths")
162+
parser.add_argument("--local-dir", required=True,
163+
help="Local directory to store trajectories")
164+
parser.add_argument("--num-trajectories", type=int, default=None,
165+
help="Maximum number of trajectories to download")
166+
parser.add_argument("--parallel-downloads", type=int, default=8,
167+
help="Number of parallel downloads")
168+
169+
args = parser.parse_args()
170+
171+
# Setup logging
172+
logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
173+
174+
print("DROID Trajectory Downloader")
175+
print("=" * 50)
176+
print(f"GCS pattern: {args.gcs_pattern}")
177+
print(f"Local directory: {args.local_dir}")
178+
print(f"Max trajectories: {args.num_trajectories or 'All'}")
179+
print(f"Parallel downloads: {args.parallel_downloads}")
180+
print()
181+
182+
# Download trajectories
183+
successful_paths, failed_paths = download_droid_trajectories(
184+
gcs_pattern=args.gcs_pattern,
185+
local_dir=args.local_dir,
186+
num_trajectories=args.num_trajectories,
187+
parallel_downloads=args.parallel_downloads
188+
)
189+
190+
# Summary
191+
print(f"\n📊 Download Summary:")
192+
print(f"Successful: {len(successful_paths)}")
193+
print(f"Failed: {len(failed_paths)}")
194+
195+
if successful_paths:
196+
print(f"\nTrajectories saved to: {args.local_dir}")
197+
print("Ready for processing with robodm!")
198+
199+
200+
if __name__ == "__main__":
201+
main()

0 commit comments

Comments
 (0)