Skip to content

Commit 0629874

Browse files
committed
fix: lazy import error
1 parent 5ec528e commit 0629874

File tree

6 files changed

+224
-14
lines changed

6 files changed

+224
-14
lines changed

src/rs_embed/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
"""
99

1010
from ._version import __version__
11-
from .api import describe_model, export_batch, get_embedding, get_embeddings_batch, list_models
11+
from .api import (
12+
describe_model,
13+
export_batch,
14+
get_embedding,
15+
get_embeddings_batch,
16+
list_models,
17+
reset_runtime,
18+
)
1219
from .core.specs import (
1320
BBox,
1421
FetchSpec,
@@ -53,6 +60,7 @@
5360
"get_embeddings_batch",
5461
"list_models",
5562
"describe_model",
63+
"reset_runtime",
5664
# Export API
5765
"export_batch",
5866
"export_npz",

src/rs_embed/api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
provider_factory_for_backend,
9292
)
9393
from .tools.runtime import (
94+
reset_runtime as _reset_runtime_shared,
9495
run_embedding_request as _run_embedding_request_shared,
9596
)
9697

@@ -156,6 +157,20 @@ def describe_model(model: str) -> dict[str, Any]:
156157
return cls().describe()
157158

158159

160+
def reset_runtime() -> dict[str, int]:
    """Reset this process's lazy-import and runtime caches.

    Intended mainly for notebook sessions: call it after a model import has
    failed, or whenever fresh embedder instances are wanted, instead of
    restarting the kernel.

    Returns
    -------
    dict[str, int]
        Summary counts describing how many runtime/import caches were cleared.
    """
    # Thin public wrapper: the actual work is done by the ``reset_runtime``
    # imported from ``.tools.runtime``.
    cleared_summary = _reset_runtime_shared()
    return cleared_summary
172+
173+
159174
def get_embedding(
160175
model: str,
161176
*,

src/rs_embed/core/registry.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from __future__ import annotations
33

44
import importlib
5+
import sys
56
from typing import Any
67

78
from rs_embed.embedders.catalog import MODEL_SPECS, canonical_model_id
@@ -45,15 +46,24 @@ def _try_lazy_load_model(name: str) -> None:
4546
return
4647
module_name, class_name = spec
4748
fqmn = f"rs_embed.embedders.{module_name}"
49+
sys_modules_before = frozenset(sys.modules.keys())
4850
try:
4951
mod = importlib.import_module(fqmn)
5052
except Exception as e:
53+
_cleanup_failed_embedder_import(
54+
fqmn=fqmn,
55+
sys_modules_before=sys_modules_before,
56+
)
5157
_REGISTRY_IMPORT_ERRORS[model_id] = e
5258
return
5359

5460
try:
5561
cls = getattr(mod, class_name)
5662
except Exception as e:
63+
_cleanup_failed_embedder_import(
64+
fqmn=fqmn,
65+
sys_modules_before=sys_modules_before,
66+
)
5767
_REGISTRY_IMPORT_ERRORS[model_id] = e
5868
return
5969

@@ -64,6 +74,30 @@ def _try_lazy_load_model(name: str) -> None:
6474
_REGISTRY_IMPORT_ERRORS.pop(model_id, None)
6575

6676

77+
def _cleanup_failed_embedder_import(
78+
*,
79+
fqmn: str,
80+
sys_modules_before: frozenset[str],
81+
) -> None:
82+
"""Remove modules imported during a failed lazy-load attempt.
83+
84+
A failed import can leave half-initialized embedder/vendor modules in
85+
``sys.modules`` for the lifetime of the notebook kernel. Clean only the
86+
modules created during this import attempt so the next retry gets a fresh
87+
import path.
88+
"""
89+
module_names = tuple(sys.modules.keys())
90+
for mod_name in module_names:
91+
if mod_name in sys_modules_before:
92+
continue
93+
if mod_name == fqmn or mod_name.startswith(f"{fqmn}."):
94+
sys.modules.pop(mod_name, None)
95+
continue
96+
if mod_name.startswith("rs_embed.embedders._vendor"):
97+
sys.modules.pop(mod_name, None)
98+
importlib.invalidate_caches()
99+
100+
67101
def get_embedder_cls(name: str) -> type[Any]:
68102
"""Resolve and return the embedder class for ``name``.
69103

src/rs_embed/tools/runtime.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import inspect
11+
import sys
1112
import time
1213
from collections.abc import Callable
1314
from dataclasses import dataclass
@@ -19,6 +20,7 @@
1920

2021
from ..core.embedding import Embedding
2122
from ..core.errors import ModelError
23+
from ..core import registry as _runtime_registry
2224
from ..core.registry import get_embedder_cls
2325
from ..core.specs import OutputSpec, SensorSpec, SpatialSpec, TemporalSpec
2426
from ..core.types import FetchResult
@@ -59,6 +61,50 @@ def get_embedder_bundle_cached(model: str, backend: str, device: str, sensor_k:
5961
return emb, RLock()
6062

6163

64+
def _clear_loaded_embedder_module_caches() -> int:
65+
"""Clear ``lru_cache`` wrappers found on already-imported embedder modules."""
66+
cleared = 0
67+
seen: set[int] = set()
68+
for module_name, module in tuple(sys.modules.items()):
69+
if module is None or not module_name.startswith("rs_embed.embedders."):
70+
continue
71+
for obj in vars(module).values():
72+
cache_clear = getattr(obj, "cache_clear", None)
73+
if not callable(cache_clear):
74+
continue
75+
obj_id = id(obj)
76+
if obj_id in seen:
77+
continue
78+
cache_clear()
79+
seen.add(obj_id)
80+
cleared += 1
81+
return cleared
82+
83+
84+
def reset_runtime() -> dict[str, int]:
    """Drop embedder runtime caches while keeping registered classes intact.

    A notebook-friendly escape hatch for stale runtime state left behind by
    failed lazy imports or cached loaders. The recorded lazy-import errors in
    ``core.registry`` are forgotten, this module's cached helpers are emptied,
    and ``lru_cache`` wrappers on loaded embedder modules are cleared. The
    registry's class table itself is not touched.

    Returns
    -------
    dict[str, int]
        ``import_errors_cleared``, ``runtime_caches_cleared`` and
        ``embedder_module_caches_cleared`` counts.
    """
    recorded_errors = _runtime_registry._REGISTRY_IMPORT_ERRORS
    n_import_errors = len(recorded_errors)
    recorded_errors.clear()

    # The module-level cached helpers, cleared in a fixed order.
    runtime_caches = (
        get_embedder_bundle_cached,
        _embedder_method_accepts_parameter,
        embedder_accepts_input_chw,
        embedder_accepts_model_config,
    )
    for cached_fn in runtime_caches:
        cached_fn.cache_clear()
    n_module_caches = _clear_loaded_embedder_module_caches()

    return {
        "import_errors_cleared": int(n_import_errors),
        "runtime_caches_cleared": len(runtime_caches),
        "embedder_module_caches_cleared": int(n_module_caches),
    }
106+
107+
62108
def sensor_key(sensor: SensorSpec | None) -> tuple:
63109
if sensor is None:
64110
return ("__none__",)

tests/test_api.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@
44
GEE, torch, or any real model weights.
55
"""
66

7+
import functools
8+
import sys
9+
import types
10+
711
import numpy as np
812
import pytest
913

1014
import rs_embed.api as api
1115
import rs_embed.tools.runtime as rt
12-
from rs_embed import list_models
16+
from rs_embed import list_models, reset_runtime
1317
from rs_embed.api import (
1418
_assert_supported,
1519
_validate_specs,
@@ -238,6 +242,35 @@ def test_get_embedding_unknown_model():
238242
get_embedding("nonexistent", spatial=_SPATIAL)
239243

240244

245+
def test_reset_runtime_clears_runtime_and_embedder_module_caches(monkeypatch):
    """``reset_runtime`` empties every runtime cache but keeps registered classes."""
    rt.get_embedder_bundle_cached("mock_model", "auto", "auto", sensor_key(None))
    # Pass the embedder classes themselves: ``type(SomeClass)`` is the
    # metaclass ``type``, not an embedder class.
    rt.embedder_accepts_input_chw(_MockEmbedder)
    rt.embedder_accepts_model_config(_MockVariantEmbedder)
    registry._REGISTRY_IMPORT_ERRORS["remoteclip"] = RuntimeError("boom")

    # Fake embedder module carrying a populated ``lru_cache`` wrapper.
    fake_mod = types.ModuleType("rs_embed.embedders._reset_runtime_fake")

    @functools.lru_cache(maxsize=4)
    def _cached_loader(x):
        return x

    fake_mod._cached_loader = _cached_loader
    fake_mod._cached_loader(1)
    monkeypatch.setitem(sys.modules, fake_mod.__name__, fake_mod)

    # Preconditions: without these, the ``currsize == 0`` checks below would
    # pass even if nothing had ever been cached.
    assert rt.get_embedder_bundle_cached.cache_info().currsize >= 1
    assert rt.embedder_accepts_input_chw.cache_info().currsize >= 1
    assert rt.embedder_accepts_model_config.cache_info().currsize >= 1
    assert _cached_loader.cache_info().currsize == 1

    summary = reset_runtime()

    assert summary["import_errors_cleared"] == 1
    assert summary["runtime_caches_cleared"] == 4
    assert summary["embedder_module_caches_cleared"] >= 1
    assert rt.get_embedder_bundle_cached.cache_info().currsize == 0
    assert rt.embedder_accepts_input_chw.cache_info().currsize == 0
    assert rt.embedder_accepts_model_config.cache_info().currsize == 0
    assert _cached_loader.cache_info().currsize == 0
    assert registry._REGISTRY_IMPORT_ERRORS == {}
    # Registered classes must survive the reset.
    assert registry.get_embedder_cls("mock_model") is _MockEmbedder
272+
273+
241274
def test_get_embedding_modality_resolves_default_sensor():
242275

243276
registry.register("mock_multi")(_MockMultimodalEmbedder)

tests/test_registry.py

Lines changed: 86 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
import sys
2+
import types
3+
14
import pytest
25

36
from rs_embed.core import registry
7+
from rs_embed.embedders import catalog
48

59
# ── fixture to isolate registry between tests ──────────────────────
610

@@ -17,6 +21,26 @@ def clean_registry():
1721
registry._REGISTRY_IMPORT_ERRORS.clear()
1822

1923

24+
def _install_fake_lazy_model(
    monkeypatch,
    *,
    model_id: str,
    module_name: str,
    class_name: str,
    alias: str | None = None,
):
    """Install a throwaway lazy-loadable model for the duration of one test.

    Adds ``model_id`` to ``catalog.MODEL_SPECS`` (and ``alias`` to
    ``catalog.MODEL_ALIASES`` when given), then places a synthetic module that
    holds an empty class named ``class_name`` into ``sys.modules``. All changes
    go through ``monkeypatch``.

    Returns the module's dotted name and the fake class.
    """
    monkeypatch.setitem(catalog.MODEL_SPECS, model_id, (module_name, class_name))
    if alias is not None:
        monkeypatch.setitem(catalog.MODEL_ALIASES, alias, model_id)

    dotted_name = f"rs_embed.embedders.{module_name}"
    fake_module = types.ModuleType(dotted_name)
    fake_cls = type(class_name, (), {})
    vars(fake_module)[class_name] = fake_cls
    monkeypatch.setitem(sys.modules, dotted_name, fake_module)
    return dotted_name, fake_cls
42+
43+
2044
# ══════════════════════════════════════════════════════════════════════
2145
# register + get_embedder_cls
2246
# ══════════════════════════════════════════════════════════════════════
@@ -112,6 +136,12 @@ def test_get_embedder_cls_includes_last_import_error(monkeypatch):
112136

113137

114138
def test_get_embedder_cls_lazy_imports_builtin_without_bulk_package_import(monkeypatch):
139+
fqmn, cls_expected = _install_fake_lazy_model(
140+
monkeypatch,
141+
model_id="fake_lazy",
142+
module_name="onthefly_test_lazy",
143+
class_name="FakeLazyEmbedder",
144+
)
115145
calls = []
116146
orig_import_module = registry.importlib.import_module
117147

@@ -121,17 +151,52 @@ def _spy(name, *args, **kwargs):
121151

122152
monkeypatch.setattr(registry.importlib, "import_module", _spy)
123153

124-
cls = registry.get_embedder_cls("remoteclip")
125-
assert cls.__name__ == "RemoteCLIPS2RGBEmbedder"
126-
assert "remoteclip" in registry.list_models()
127-
assert "remoteclip_s2rgb" not in registry.list_models()
154+
cls = registry.get_embedder_cls("fake_lazy")
155+
assert cls is cls_expected
156+
assert "fake_lazy" in registry.list_models()
128157
assert "rs_embed.embedders" not in calls
129-
assert "rs_embed.embedders.onthefly_remoteclip" in calls
158+
assert fqmn in calls
159+
160+
161+
def test_get_embedder_cls_cleans_failed_lazy_import_modules(monkeypatch):
    """A failed lazy import must not leave its modules behind in ``sys.modules``."""
    from rs_embed.core.errors import ModelError

    model_id = "fake_lazy_fail"
    module_name = "onthefly_test_lazy_fail"
    class_name = "FakeLazyFailEmbedder"
    monkeypatch.setitem(catalog.MODEL_SPECS, model_id, (module_name, class_name))
    fqmn = f"rs_embed.embedders.{module_name}"
    vendor_name = "rs_embed.embedders._vendor.fake_test_lazy_fail"
    real_import_module = registry.importlib.import_module

    def _boom(name, *args, **kwargs):
        if name != fqmn:
            return real_import_module(name, *args, **kwargs)
        # Mimic an import that registers modules and then blows up.
        for leaked in (fqmn, vendor_name):
            sys.modules[leaked] = types.ModuleType(leaked)
        raise RuntimeError("boom")

    monkeypatch.setattr(registry.importlib, "import_module", _boom)

    with pytest.raises(ModelError, match="RuntimeError: boom"):
        registry.get_embedder_cls(model_id)

    for leaked in (fqmn, vendor_name):
        assert leaked not in sys.modules
186+
187+
188+
def test_get_embedder_cls_accepts_legacy_alias(monkeypatch):
    """A legacy alias resolves to the same class as its canonical model id."""
    fake_model = {
        "model_id": "fake_alias_target",
        "module_name": "onthefly_test_lazy_alias",
        "class_name": "FakeAliasEmbedder",
        "alias": "fake_alias_old",
    }
    _, cls_expected = _install_fake_lazy_model(monkeypatch, **fake_model)

    via_canonical = registry.get_embedder_cls("fake_alias_target")
    via_alias = registry.get_embedder_cls("fake_alias_old")

    assert via_canonical is cls_expected
    assert via_alias is via_canonical
136201

137202

@@ -151,9 +216,18 @@ def test_get_embedder_cls_accepts_satmaepp_s2_aliases():
151216
assert cls_alt is cls_new
152217

153218

154-
def test_get_embedder_cls_can_reregister_when_registry_was_cleared():
155-
cls1 = registry.get_embedder_cls("remoteclip")
219+
def test_get_embedder_cls_can_reregister_when_registry_was_cleared(monkeypatch):
    """Lookups still resolve after the in-memory registry has been wiped."""
    fake_model = {
        "model_id": "fake_reregister",
        "module_name": "onthefly_test_lazy_reregister",
        "class_name": "FakeReregisterEmbedder",
        "alias": "fake_reregister_old",
    }
    _, cls_expected = _install_fake_lazy_model(monkeypatch, **fake_model)

    first_lookup = registry.get_embedder_cls("fake_reregister")
    # Wipe the registry, then resolve through the alias and the canonical id.
    registry._REGISTRY.clear()
    after_clear = registry.get_embedder_cls("fake_reregister_old")

    assert first_lookup is cls_expected
    assert after_clear is first_lookup
    assert registry.get_embedder_cls("fake_reregister") is first_lookup

0 commit comments

Comments
 (0)