1+ #!/usr/bin/env python3
2+ """
3+ DROID Trajectory Download Example
4+
5+ Concise example for downloading DROID trajectories to local storage with parallel processing.
6+ Downloads trajectories from GCS and converts them to robodm format for efficient processing.
7+
8+ Usage:
9+ python droid_download_example.py --gcs-pattern "gs://gresearch/robotics/droid_raw/1.0.1/*/success/*" --local-dir ./droid_data --num-trajectories 50
10+ """
11+
12+ import argparse
13+ import logging
14+ import os
15+ import subprocess
16+ import tempfile
17+ from pathlib import Path
18+ from typing import List , Optional , Tuple
19+
20+ import ray
21+
22+ logger = logging .getLogger (__name__ )
23+
24+
25+ @ray .remote
26+ def download_single_trajectory (gcs_path : str , local_dir : str , temp_dir : str ) -> Tuple [bool , Optional [str ], str , str ]:
27+ """Download a single DROID trajectory from GCS to local directory."""
28+ try :
29+ # Extract meaningful name from nested structure: date/trajectory_name
30+ path_parts = gcs_path .rstrip ("/" ).split ("/" )
31+ date_part = path_parts [- 2 ] # e.g., "2023-07-07"
32+ traj_part = path_parts [- 1 ] # e.g., "Fri_Jul__7_09:42:23_2023"
33+ trajectory_name = f"{ date_part } _{ traj_part } "
34+ local_trajectory_dir = Path (local_dir ) / trajectory_name
35+ local_trajectory_dir .mkdir (parents = True , exist_ok = True )
36+
37+ # Use gsutil for efficient GCS download
38+ # Remove trailing slash and add /* for contents
39+ clean_gcs_path = gcs_path .rstrip ("/" )
40+ cmd = ["gsutil" , "-m" , "cp" , "-r" , f"{ clean_gcs_path } /*" , str (local_trajectory_dir )]
41+ result = subprocess .run (cmd , capture_output = True , text = True , timeout = 300 )
42+
43+ if result .returncode == 0 :
44+ logger .info (f"Downloaded { trajectory_name } " )
45+ return True , str (local_trajectory_dir ), "" , trajectory_name
46+ else :
47+ error_msg = f"gsutil failed: { result .stderr } "
48+ logger .error (f"Failed to download { trajectory_name } : { error_msg } " )
49+ return False , None , error_msg , trajectory_name
50+
51+ except Exception as e :
52+ error_msg = f"Exception during download: { str (e )} "
53+ logger .error (f"Error downloading { gcs_path } : { error_msg } " )
54+ return False , None , error_msg , trajectory_name
55+
56+
57+ def scan_droid_trajectories (gcs_pattern : str , max_trajectories : Optional [int ] = None ) -> List [str ]:
58+ """Scan for DROID trajectories matching the GCS pattern."""
59+ try :
60+ # First get date directories
61+ cmd = ["gsutil" , "ls" , "-d" , gcs_pattern ]
62+ result = subprocess .run (cmd , capture_output = True , text = True , timeout = 60 )
63+
64+ if result .returncode != 0 :
65+ raise RuntimeError (f"gsutil ls failed: { result .stderr } " )
66+
67+ date_dirs = [line .strip () for line in result .stdout .strip ().split ('\n ' ) if line .strip ()]
68+
69+ # Now get actual trajectory directories from each date
70+ all_trajectories = []
71+ for date_dir in date_dirs :
72+ if max_trajectories and len (all_trajectories ) >= max_trajectories :
73+ break
74+
75+ cmd = ["gsutil" , "ls" , "-d" , f"{ date_dir .rstrip ('/' )} /*" ]
76+ result = subprocess .run (cmd , capture_output = True , text = True , timeout = 30 )
77+
78+ if result .returncode == 0 :
79+ traj_dirs = [line .strip () for line in result .stdout .strip ().split ('\n ' ) if line .strip ()]
80+ all_trajectories .extend (traj_dirs )
81+
82+ if max_trajectories and len (all_trajectories ) > max_trajectories :
83+ all_trajectories = all_trajectories [:max_trajectories ]
84+
85+ return all_trajectories
86+
87+ except Exception as e :
88+ logger .error (f"Failed to scan trajectories: { e } " )
89+ return []
90+
91+
92+ def download_droid_trajectories (
93+ gcs_pattern : str ,
94+ local_dir : str ,
95+ num_trajectories : Optional [int ] = None ,
96+ parallel_downloads : int = 8
97+ ) -> Tuple [List [str ], List [str ]]:
98+ """
99+ Download DROID trajectories from GCS to local directory.
100+
101+ Args:
102+ gcs_pattern: GCS pattern for trajectory paths (e.g., "gs://path/*/success/*")
103+ local_dir: Local directory to store downloaded trajectories
104+ num_trajectories: Maximum number of trajectories to download
105+ parallel_downloads: Number of parallel downloads
106+
107+ Returns:
108+ Tuple of (successful_paths, failed_paths)
109+ """
110+ if not ray .is_initialized ():
111+ ray .init ()
112+
113+ # Create local directory
114+ Path (local_dir ).mkdir (parents = True , exist_ok = True )
115+
116+ # Scan for trajectories
117+ print (f"Scanning for trajectories matching: { gcs_pattern } " )
118+ trajectory_paths = scan_droid_trajectories (gcs_pattern , num_trajectories )
119+
120+ if not trajectory_paths :
121+ print ("No trajectories found matching the pattern" )
122+ return [], []
123+
124+ print (f"Found { len (trajectory_paths )} trajectories to download" )
125+
126+ # Create temporary directory for downloads
127+ with tempfile .TemporaryDirectory () as temp_dir :
128+ # Start parallel downloads
129+ print (f"Starting { parallel_downloads } parallel downloads..." )
130+
131+ download_futures = []
132+ for gcs_path in trajectory_paths :
133+ future = download_single_trajectory .remote (gcs_path , local_dir , temp_dir )
134+ download_futures .append ((future , gcs_path ))
135+
136+ # Process results as they complete
137+ successful_paths = []
138+ failed_paths = []
139+
140+ for future , gcs_path in download_futures :
141+ try :
142+ success , local_path , error_msg , traj_name = ray .get (future )
143+ if success :
144+ successful_paths .append (local_path )
145+ print (f"✅ { traj_name } " )
146+ else :
147+ failed_paths .append (gcs_path )
148+ print (f"❌ { traj_name } : { error_msg } " )
149+ except Exception as e :
150+ failed_paths .append (gcs_path )
151+ print (f"❌ { gcs_path } : { e } " )
152+
153+ print (f"\n Download complete: { len (successful_paths )} successful, { len (failed_paths )} failed" )
154+ return successful_paths , failed_paths
155+
156+
157+ def main ():
158+ """Download DROID trajectories from GCS."""
159+ parser = argparse .ArgumentParser (description = "Download DROID trajectories from GCS" )
160+ parser .add_argument ("--gcs-pattern" , default = "gs://gresearch/robotics/droid_raw/1.0.1/*/success/*" ,
161+ help = "GCS pattern for trajectory paths" )
162+ parser .add_argument ("--local-dir" , required = True ,
163+ help = "Local directory to store trajectories" )
164+ parser .add_argument ("--num-trajectories" , type = int , default = None ,
165+ help = "Maximum number of trajectories to download" )
166+ parser .add_argument ("--parallel-downloads" , type = int , default = 8 ,
167+ help = "Number of parallel downloads" )
168+
169+ args = parser .parse_args ()
170+
171+ # Setup logging
172+ logging .basicConfig (level = logging .INFO , format = '%(levelname)s: %(message)s' )
173+
174+ print ("DROID Trajectory Downloader" )
175+ print ("=" * 50 )
176+ print (f"GCS pattern: { args .gcs_pattern } " )
177+ print (f"Local directory: { args .local_dir } " )
178+ print (f"Max trajectories: { args .num_trajectories or 'All' } " )
179+ print (f"Parallel downloads: { args .parallel_downloads } " )
180+ print ()
181+
182+ # Download trajectories
183+ successful_paths , failed_paths = download_droid_trajectories (
184+ gcs_pattern = args .gcs_pattern ,
185+ local_dir = args .local_dir ,
186+ num_trajectories = args .num_trajectories ,
187+ parallel_downloads = args .parallel_downloads
188+ )
189+
190+ # Summary
191+ print (f"\n 📊 Download Summary:" )
192+ print (f"Successful: { len (successful_paths )} " )
193+ print (f"Failed: { len (failed_paths )} " )
194+
195+ if successful_paths :
196+ print (f"\n Trajectories saved to: { args .local_dir } " )
197+ print ("Ready for processing with robodm!" )
198+
199+
200+ if __name__ == "__main__" :
201+ main ()
0 commit comments