Skip to content

Commit 8ec2eb3

Browse files
fix: redownload weights only on PytorchStreamReader error (#686)
1 parent 4e42986 commit 8ec2eb3

File tree

3 files changed

+40
-55
lines changed

3 files changed

+40
-55
lines changed

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/base_vision_agent.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515

1616
import os
17-
import shutil
1817
import subprocess
1918
from pathlib import Path
2019

@@ -24,63 +23,65 @@
2423

2524
class BaseVisionAgent(BaseAgent):
2625
WEIGHTS_URL: str = ""
26+
DEFAULT_WEIGHTS_ROOT_PATH: Path = Path.home() / Path(".cache/rai/")
27+
WEIGHTS_DIR_PATH_PART: Path = Path("vision/weights")
2728
WEIGHTS_FILENAME: str = ""
2829

2930
def __init__(
3031
self,
31-
weights_path: str | Path = Path.home() / Path(".cache/rai/"),
32+
weights_root_path: str | Path = DEFAULT_WEIGHTS_ROOT_PATH,
3233
ros2_name: str = "",
3334
):
35+
if not self.WEIGHTS_FILENAME:
36+
raise ValueError("WEIGHTS_FILENAME is not set")
3437
super().__init__()
35-
self._weights_path = Path(weights_path)
36-
os.makedirs(self._weights_path, exist_ok=True)
37-
self._init_weight_path()
38-
self.weight_path = self._weights_path
38+
self.weights_root_path = Path(weights_root_path)
39+
self.weights_root_path.mkdir(parents=True, exist_ok=True)
40+
self.weights_path = (
41+
self.weights_root_path / self.WEIGHTS_DIR_PATH_PART / self.WEIGHTS_FILENAME
42+
)
43+
if not self.weights_path.exists():
44+
self._download_weights()
3945
self.ros2_connector = ROS2Connector(ros2_name, executor_type="single_threaded")
4046

41-
def _init_weight_path(self):
42-
try:
43-
if self.WEIGHTS_FILENAME == "":
44-
raise ValueError("WEIGHTS_FILENAME is not set")
47+
def _load_model_with_error_handling(self, model_class):
48+
"""Load model with automatic error handling for corrupted weights.
4549
46-
install_path = (
47-
self._weights_path / "vision" / "weights" / self.WEIGHTS_FILENAME
48-
)
49-
# make sure the file exists
50-
if install_path.exists() and install_path.is_file():
51-
self._weights_path = install_path
52-
else:
53-
self._remove_weights(path=install_path)
54-
self._download_weights(install_path)
55-
self._weights_path = install_path
50+
Args:
51+
model_class: A class that can be instantiated with weights_path
5652
57-
except Exception:
58-
self.logger.error("Could not find package path")
59-
raise Exception("Could not find package path")
53+
Returns:
54+
The loaded model instance
55+
"""
56+
try:
57+
return model_class(self.weights_path)
58+
except RuntimeError as e:
59+
self.logger.error(f"Could not load model: {e}")
60+
if "PytorchStreamReader" in str(e):
61+
self.logger.error("The weights might be corrupted. Redownloading...")
62+
self._remove_weights()
63+
self._download_weights()
64+
return model_class(self.weights_path)
65+
else:
66+
raise e
6067

61-
def _download_weights(self, path: Path):
68+
def _download_weights(self):
6269
try:
63-
os.makedirs(path.parent, exist_ok=True)
6470
subprocess.run(
6571
[
6672
"wget",
6773
self.WEIGHTS_URL,
6874
"-O",
69-
path,
75+
self.weights_path,
7076
"--progress=dot:giga",
7177
]
7278
)
7379
except Exception:
7480
self.logger.error("Could not download weights")
7581
raise Exception("Could not download weights")
7682

77-
def _remove_weights(self, path: str):
78-
# Sometimes redownloding weights bugged and created a dir
79-
# so check also for dir and remove it in both cases
80-
if os.path.isdir(path):
81-
shutil.rmtree(path)
82-
elif os.path.isfile(path):
83-
os.remove(path)
83+
def _remove_weights(self):
84+
os.remove(self.weights_path)
8485

8586
def stop(self):
8687
self.ros2_connector.shutdown()

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounded_sam.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,11 @@ class GroundedSamAgent(BaseVisionAgent):
3434

3535
def __init__(
3636
self,
37-
weights_path: str | Path = Path.home() / Path(".cache/rai"),
37+
weights_root_path: str | Path = Path.home() / Path(".cache/rai"),
3838
ros2_name: str = GSAM_NODE_NAME,
3939
):
40-
super().__init__(weights_path, ros2_name)
41-
try:
42-
self._segmenter = GDSegmenter(self._weights_path)
43-
except Exception as e:
44-
self.logger.error(
45-
f"Could not load model : {e}. The weights might be corrupted. Redownloading..."
46-
)
47-
self._remove_weights(self.weight_path)
48-
self._init_weight_path()
49-
self.segmenter = GDSegmenter(self.weight_path)
40+
super().__init__(weights_root_path, ros2_name)
41+
self._segmenter = self._load_model_with_error_handling(GDSegmenter)
5042
self.logger.info(f"{self.__class__.__name__} initialized")
5143

5244
def run(self):

src/rai_extensions/rai_open_set_vision/rai_open_set_vision/agents/grounding_dino.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,11 @@ class GroundingDinoAgent(BaseVisionAgent):
2929

3030
def __init__(
3131
self,
32-
weights_path: str | Path = Path.home() / Path(".cache/rai"),
32+
weights_root_path: str | Path = Path.home() / Path(".cache/rai"),
3333
ros2_name: str = GDINO_NODE_NAME,
3434
):
35-
super().__init__(weights_path, ros2_name)
36-
try:
37-
self._boxer = GDBoxer(self._weights_path)
38-
except Exception as e:
39-
self.logger.error(
40-
f"Could not load model: {e}, The weights might be corrupted. Redownloading..."
41-
)
42-
self._remove_weights(self.weight_path)
43-
self._init_weight_path()
44-
self.segmenter = GDBoxer(self.weight_path)
35+
super().__init__(weights_root_path, ros2_name)
36+
self._boxer = self._load_model_with_error_handling(GDBoxer)
4537
self.logger.info(f"{self.__class__.__name__} initialized")
4638

4739
def run(self):

0 commit comments

Comments
 (0)