@@ -10,6 +10,7 @@
 from abc import abstractmethod
 from dataclasses import dataclass
 from multiprocessing import Process, Queue
+from pathlib import Path
 from queue import Empty
 from time import sleep, time
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
@@ -25,6 +26,7 @@
     _BOTO3_AVAILABLE,
     _DEFAULT_FAST_DEV_RUN_ITEMS,
     _INDEX_FILENAME,
+    _IS_IN_STUDIO,
     _LIGHTNING_CLOUD_LATEST,
     _TORCH_GREATER_EQUAL_2_1_0,
 )
@@ -66,17 +68,21 @@ def _get_home_folder() -> str:
     return os.getenv("DATA_OPTIMIZER_HOME_FOLDER", os.path.expanduser("~"))
 
 
+def _get_default_cache() -> str:
+    return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()
+
+
 def _get_cache_dir(name: Optional[str] = None) -> str:
     """Returns the cache directory used by the Cache to store the chunks."""
-    cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", "/cache/chunks")
+    cache_dir = os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", f"{_get_default_cache()}/chunks")
     if name is None:
         return cache_dir
     return os.path.join(cache_dir, name.lstrip("/"))
 
 
 def _get_cache_data_dir(name: Optional[str] = None) -> str:
     """Returns the cache data directory used by the DataProcessor workers to download the files."""
-    cache_dir = os.getenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", "/cache/data")
+    cache_dir = os.getenv("DATA_OPTIMIZER_DATA_CACHE_FOLDER", f"{_get_default_cache()}/data")
     if name is None:
         return os.path.join(cache_dir)
     return os.path.join(cache_dir, name.lstrip("/"))
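With this hunk, the default cache root depends on where the process runs: `/cache` inside a Lightning Studio, the OS temp directory everywhere else, with the `DATA_OPTIMIZER_CACHE_FOLDER` / `DATA_OPTIMIZER_DATA_CACHE_FOLDER` env vars still taking precedence. A minimal sketch of the resolution order, assuming `_IS_IN_STUDIO` is a boolean constant (faked here via a hypothetical env flag; the real detection lives in the imported constants):

```python
import os
import tempfile

# Hypothetical stand-in for the imported _IS_IN_STUDIO constant.
_IS_IN_STUDIO = bool(os.getenv("FAKE_STUDIO_FLAG"))


def _get_default_cache() -> str:
    # "/cache" is assumed to exist on Studio machines; elsewhere fall back to the OS tempdir.
    return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()


def _get_cache_dir() -> str:
    # An explicit env var always wins over the computed default.
    return os.getenv("DATA_OPTIMIZER_CACHE_FOLDER", f"{_get_default_cache()}/chunks")


print(_get_cache_dir())  # e.g. "/tmp/chunks" on a laptop, "/cache/chunks" in a Studio
```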
@@ -222,18 +228,20 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
                 )
             except Exception as e:
                 print(e)
-        elif output_dir.path and os.path.isdir(output_dir.path):
+
+        elif output_dir.path:
             if tmpdir is None:
-                shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+                output_filepath = os.path.join(output_dir.path, os.path.basename(local_filepath))
             else:
                 output_filepath = os.path.join(output_dir.path, local_filepath.replace(tmpdir, "")[1:])
-                os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
-                shutil.copyfile(local_filepath, output_filepath)
+
+            os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+            shutil.move(local_filepath, output_filepath)
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
 
         # Inform the remover to delete the file
-        if remove_queue:
+        if remove_queue and os.path.exists(local_filepath):
             remove_queue.put([local_filepath])
 
 
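Two coupled changes here: the local "upload" now uses `shutil.move` instead of `shutil.copyfile` (building `output_filepath` in both branches, then doing a single makedirs/move), so the source file may already be gone when the remover is notified; hence the new `os.path.exists` guard. A small runnable sketch of why the guard is needed (directories are temporary, file name illustrative):

```python
import os
import shutil
import tempfile

src_dir, dst_dir = tempfile.mkdtemp(), tempfile.mkdtemp()
local_filepath = os.path.join(src_dir, "chunk-0-0.bin")
with open(local_filepath, "wb") as f:
    f.write(b"\x00" * 16)

output_filepath = os.path.join(dst_dir, os.path.basename(local_filepath))
os.makedirs(os.path.dirname(output_filepath), exist_ok=True)

# Unlike copyfile, move removes the source (a cheap rename on the same filesystem).
shutil.move(local_filepath, output_filepath)

assert not os.path.exists(local_filepath)  # source gone: an unguarded delete would fail
assert os.path.exists(output_filepath)
```

Dropping the old `os.path.isdir(output_dir.path)` condition also lets the target directory be created lazily by `os.makedirs` instead of having to exist up front.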
@@ -290,7 +298,10 @@ def _get_num_bytes(item: Any, base_path: str) -> int:
 
     num_bytes = 0
     for element in flattened_item:
-        if isinstance(element, str) and element.startswith(base_path) and os.path.exists(element):
+        if isinstance(element, str):
+            element = Path(element).resolve()
+            if not element.exists():
+                continue
             file_bytes = os.path.getsize(element)
             if file_bytes == 0:
                 raise RuntimeError(f"The file {element} has 0 bytes!")
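`Path(element).resolve()` makes the size check work for relative paths too: it normalizes against the current working directory and follows symlinks before `exists()` is consulted. Note that the old `element.startswith(base_path)` filter is gone, so any string that resolves to an existing file is now counted. A quick illustration (the file name is made up):

```python
import os
from pathlib import Path

candidate = "data/sample.txt"  # hypothetical relative path found in an item

resolved = Path(candidate).resolve()  # e.g. /home/user/project/data/sample.txt
if resolved.exists():
    # os.path.getsize accepts Path objects (os.PathLike)
    print(resolved, os.path.getsize(resolved), "bytes")
else:
    print(f"skipping string that is not an existing file: {candidate!r}")
```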
@@ -475,16 +486,22 @@ def _collect_paths(self) -> None:
         for item in self.items:
             flattened_item, spec = tree_flatten(item)
 
+            def is_path(element: Any) -> bool:
+                if not isinstance(element, str):
+                    return False
+
+                element: str = str(Path(element).resolve())
+                return (
+                    element.startswith(self.input_dir.path)
+                    if self.input_dir.path is not None
+                    else os.path.exists(element)
+                )
+
             # For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
             # Other alternatives would be too slow.
             # TODO: Try using a dictionary for higher accuracy.
             indexed_paths = {
-                index: element
-                for index, element in enumerate(flattened_item)
-                if isinstance(element, str)
-                and (
-                    element.startswith(self.input_dir.path) if self.input_dir is not None else os.path.exists(element)
-                )  # For speed reasons
+                index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
             }
 
             if len(indexed_paths) == 0:
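The comprehension now stores resolved absolute paths and delegates filtering to a local `is_path` predicate, which also fixes the old guard (`self.input_dir is not None` becomes `self.input_dir.path is not None`, matching how the value is actually used). A standalone sketch of the predicate, with a hypothetical `input_dir_path` argument standing in for `self.input_dir.path`:

```python
import os
from pathlib import Path
from typing import Any, Optional


def is_path(element: Any, input_dir_path: Optional[str]) -> bool:
    if not isinstance(element, str):
        return False
    resolved = str(Path(element).resolve())
    # With a known input dir, a prefix check is enough (fast, no filesystem stat);
    # without one, fall back to a real existence check.
    return resolved.startswith(input_dir_path) if input_dir_path is not None else os.path.exists(resolved)


flattened_item = ["/data/train/0.jpg", 42, "not/a/real/file"]
indexed_paths = {
    index: str(Path(element).resolve())
    for index, element in enumerate(flattened_item)
    if is_path(element, "/data")
}
print(indexed_paths)  # {0: '/data/train/0.jpg'}
```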
@@ -947,7 +964,7 @@ def run(self, data_recipe: DataRecipe) -> None:
         print("Workers are finished.")
         result = data_recipe._done(len(user_items), self.delete_cached_files, self.output_dir)
 
-        if num_nodes == node_rank + 1:
+        if num_nodes == node_rank + 1 and self.output_dir.url:
             _create_dataset(
                 input_dir=self.input_dir.path,
                 storage_dir=self.output_dir.path,
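With the added `self.output_dir.url` condition, `_create_dataset` only runs when the output targets remote storage; purely local runs (now possible outside a Studio) skip dataset registration. A minimal sketch of the guard, with `Dir` reduced to the two fields used here (the real dataclass has more):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Dir:
    # Simplified stand-in for the project's Dir dataclass.
    path: Optional[str] = None
    url: Optional[str] = None


def maybe_create_dataset(output_dir: Dir, num_nodes: int, node_rank: int) -> None:
    # Only the last node registers the dataset, and only for remote outputs.
    if num_nodes == node_rank + 1 and output_dir.url:
        print(f"registering dataset stored at {output_dir.url}")
    else:
        print("local-only output: skipping dataset registration")


maybe_create_dataset(Dir(path="/tmp/out"), num_nodes=1, node_rank=0)  # skipped
maybe_create_dataset(Dir(path="/cache/out", url="s3://bucket/out"), num_nodes=1, node_rank=0)  # registered
```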