@@ -190,7 +190,14 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
     s3 = S3Client()
 
     while True:
-        local_filepath: Optional[str] = upload_queue.get()
+        data: Optional[Union[str, Tuple[str, str]]] = upload_queue.get()
+
+        tmpdir = None
+
+        if isinstance(data, str) or data is None:
+            local_filepath = data
+        else:
+            tmpdir, local_filepath = data
 
         # Terminate the process if we received a termination signal
         if local_filepath is None:
@@ -202,15 +209,25 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
 
         if obj.scheme == "s3":
             try:
+                if tmpdir is None:
+                    output_filepath = os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath))
+                else:
+                    output_filepath = os.path.join(str(obj.path).lstrip("/"), local_filepath.replace(tmpdir, "")[1:])
+
                 s3.client.upload_file(
                     local_filepath,
                     obj.netloc,
-                    os.path.join(str(obj.path).lstrip("/"), os.path.basename(local_filepath)),
+                    output_filepath,
                 )
             except Exception as e:
                 print(e)
         elif output_dir.path and os.path.isdir(output_dir.path):
-            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            if tmpdir is None:
+                shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            else:
+                output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:])
+                os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+                shutil.copyfile(local_filepath, output_filepath)
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
 
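Note on the hunk above: when the payload is a (tmpdir, local_filepath) pair, the destination path is built by stripping the tmpdir prefix rather than flattening everything to the basename, so nested output directories survive the upload. A minimal, self-contained sketch of that mapping (the helper name `_resolve_remote_key` is illustrative only, not part of this change):

# Illustration of the destination-path logic in the hunk above.
# `_resolve_remote_key` is a hypothetical helper, not code from the diff.
import os
from typing import Optional


def _resolve_remote_key(prefix: str, local_filepath: str, tmpdir: Optional[str] = None) -> str:
    if tmpdir is None:
        # Legacy behaviour: flatten to the file's basename.
        return os.path.join(prefix.lstrip("/"), os.path.basename(local_filepath))
    # New behaviour: keep the path relative to the temporary directory,
    # so nested output folders are preserved on the remote side.
    return os.path.join(prefix.lstrip("/"), local_filepath.replace(tmpdir, "")[1:])


print(_resolve_remote_key("/bucket-prefix", "/tmp/work/images/0.jpg", "/tmp/work"))
# -> "bucket-prefix/images/0.jpg"

An equivalent and slightly more robust formulation would be os.path.relpath(local_filepath, tmpdir), which avoids the edge case where the tmpdir string occurs again later in the path.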
@@ -435,12 +452,15 @@ def _create_cache(self) -> None:
         )
         self.cache._reader._rank = _get_node_rank() * self.num_workers + self.worker_index
 
-    def _try_upload(self, filepath: Optional[str]) -> None:
-        if not filepath or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
+    def _try_upload(self, data: Optional[Union[str, Tuple[str, str]]]) -> None:
+        if not data or (self.output_dir.url if self.output_dir.url else self.output_dir.path) is None:
             return
 
-        assert os.path.exists(filepath), filepath
-        self.to_upload_queues[self._counter % self.num_uploaders].put(filepath)
+        if isinstance(data, str):
+            assert os.path.exists(data), data
+        else:
+            assert os.path.exists(data[-1]), data
+        self.to_upload_queues[self._counter % self.num_uploaders].put(data)
 
     def _collect_paths(self) -> None:
         if self.input_dir.path is None:
@@ -582,7 +602,7 @@ def _handle_data_transform_recipe(self, index: int) -> None:
                 filepaths.append(os.path.join(directory, filename))
 
         for filepath in filepaths:
-            self._try_upload(filepath)
+            self._try_upload((output_dir, filepath))
 
 
 class DataWorkerProcess(BaseWorker, Process):
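For context on the last hunk: _handle_data_transform_recipe now hands _try_upload an (output_dir, filepath) pair for every file produced by the recipe, and _upload_fn consumes either such a pair, a bare filepath, or None as the stop signal. A rough sketch of the producer side under those assumptions; the name `walk_and_enqueue` and the sample directory are hypothetical:

# Sketch of the queue contract introduced in this diff. Only the payload
# shapes mirror the change; everything else here is illustrative.
import os
from multiprocessing import Queue
from typing import Optional, Tuple, Union

# The two payload shapes the uploader accepts, plus None as the termination signal.
UploadItem = Optional[Union[str, Tuple[str, str]]]


def walk_and_enqueue(output_dir: str, upload_queue: Queue) -> None:
    """Enqueue every file found under `output_dir` as an (output_dir, filepath) pair."""
    for directory, _, filenames in os.walk(output_dir):
        for filename in filenames:
            filepath = os.path.join(directory, filename)
            # The pair lets the uploader rebuild the path relative to output_dir
            # instead of collapsing everything onto the basename.
            upload_queue.put((output_dir, filepath))


if __name__ == "__main__":
    queue: Queue = Queue()
    walk_and_enqueue("/tmp/transformed", queue)  # hypothetical directory
    queue.put(None)  # stop signal, matching the `if local_filepath is None` check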