|
14 | 14 |
|
15 | 15 |
|
16 | 16 | import os
|
17 |
| -import shutil |
18 | 17 | import subprocess
|
19 | 18 | from pathlib import Path
|
20 | 19 |
|
|
24 | 23 |
|
25 | 24 | class BaseVisionAgent(BaseAgent):
|
26 | 25 | WEIGHTS_URL: str = ""
|
| 26 | + DEFAULT_WEIGHTS_ROOT_PATH: Path = Path.home() / Path(".cache/rai/") |
| 27 | + WEIGHTS_DIR_PATH_PART: Path = Path("vision/weights") |
27 | 28 | WEIGHTS_FILENAME: str = ""
|
28 | 29 |
|
29 | 30 | def __init__(
|
30 | 31 | self,
|
31 |
| - weights_path: str | Path = Path.home() / Path(".cache/rai/"), |
| 32 | + weights_root_path: str | Path = DEFAULT_WEIGHTS_ROOT_PATH, |
32 | 33 | ros2_name: str = "",
|
33 | 34 | ):
|
| 35 | + if not self.WEIGHTS_FILENAME: |
| 36 | + raise ValueError("WEIGHTS_FILENAME is not set") |
34 | 37 | super().__init__()
|
35 |
| - self._weights_path = Path(weights_path) |
36 |
| - os.makedirs(self._weights_path, exist_ok=True) |
37 |
| - self._init_weight_path() |
38 |
| - self.weight_path = self._weights_path |
| 38 | + self.weights_root_path = Path(weights_root_path) |
| 39 | + self.weights_root_path.mkdir(parents=True, exist_ok=True) |
| 40 | + self.weights_path = ( |
| 41 | + self.weights_root_path / self.WEIGHTS_DIR_PATH_PART / self.WEIGHTS_FILENAME |
| 42 | + ) |
| 43 | + if not self.weights_path.exists(): |
| 44 | + self._download_weights() |
39 | 45 | self.ros2_connector = ROS2Connector(ros2_name, executor_type="single_threaded")
|
40 | 46 |
|
41 |
| - def _init_weight_path(self): |
42 |
| - try: |
43 |
| - if self.WEIGHTS_FILENAME == "": |
44 |
| - raise ValueError("WEIGHTS_FILENAME is not set") |
| 47 | + def _load_model_with_error_handling(self, model_class): |
| 48 | + """Load model with automatic error handling for corrupted weights. |
45 | 49 |
|
46 |
| - install_path = ( |
47 |
| - self._weights_path / "vision" / "weights" / self.WEIGHTS_FILENAME |
48 |
| - ) |
49 |
| - # make sure the file exists |
50 |
| - if install_path.exists() and install_path.is_file(): |
51 |
| - self._weights_path = install_path |
52 |
| - else: |
53 |
| - self._remove_weights(path=install_path) |
54 |
| - self._download_weights(install_path) |
55 |
| - self._weights_path = install_path |
| 50 | + Args: |
| 51 | + model_class: A class that can be instantiated with weights_path |
56 | 52 |
|
57 |
| - except Exception: |
58 |
| - self.logger.error("Could not find package path") |
59 |
| - raise Exception("Could not find package path") |
| 53 | + Returns: |
| 54 | + The loaded model instance |
| 55 | + """ |
| 56 | + try: |
| 57 | + return model_class(self.weights_path) |
| 58 | + except RuntimeError as e: |
| 59 | + self.logger.error(f"Could not load model: {e}") |
| 60 | + if "PytorchStreamReader" in str(e): |
| 61 | + self.logger.error("The weights might be corrupted. Redownloading...") |
| 62 | + self._remove_weights() |
| 63 | + self._download_weights() |
| 64 | + return model_class(self.weights_path) |
| 65 | + else: |
| 66 | + raise e |
60 | 67 |
|
61 |
| - def _download_weights(self, path: Path): |
| 68 | + def _download_weights(self): |
62 | 69 | try:
|
63 |
| - os.makedirs(path.parent, exist_ok=True) |
64 | 70 | subprocess.run(
|
65 | 71 | [
|
66 | 72 | "wget",
|
67 | 73 | self.WEIGHTS_URL,
|
68 | 74 | "-O",
|
69 |
| - path, |
| 75 | + self.weights_path, |
70 | 76 | "--progress=dot:giga",
|
71 | 77 | ]
|
72 | 78 | )
|
73 | 79 | except Exception:
|
74 | 80 | self.logger.error("Could not download weights")
|
75 | 81 | raise Exception("Could not download weights")
|
76 | 82 |
|
77 |
| - def _remove_weights(self, path: str): |
78 |
| - # Sometimes redownloding weights bugged and created a dir |
79 |
| - # so check also for dir and remove it in both cases |
80 |
| - if os.path.isdir(path): |
81 |
| - shutil.rmtree(path) |
82 |
| - elif os.path.isfile(path): |
83 |
| - os.remove(path) |
| 83 | + def _remove_weights(self): |
| 84 | + os.remove(self.weights_path) |
84 | 85 |
|
85 | 86 | def stop(self):
|
86 | 87 | self.ros2_connector.shutdown()
|
0 commit comments