Skip to content

Commit 0da92d7

Browse files
committed
refactor: replace custom missing provider implementation with utility function
1 parent 9028dc3 commit 0da92d7

File tree

9 files changed

+171
-108
lines changed

9 files changed

+171
-108
lines changed
Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,23 @@
11
from __future__ import annotations
22

3+
from graphium_core.utils import missing_provider
4+
35
from .generic import GenericLiteLLMReranker
46
from .openai import OpenAIReranker
57

68
__all__ = ['GenericLiteLLMReranker', 'OpenAIReranker']
79

810

9-
def _missing_provider(name: str, dependency: str) -> type:
10-
class _MissingProvider: # pragma: no cover - simple error shim
11-
def __init__(self, *args, **kwargs):
12-
raise ImportError(
13-
f'{name} requires the optional dependency "{dependency}". '
14-
f'Install it via `uv sync --extra {dependency}` or the matching extras group.'
15-
)
16-
17-
_MissingProvider.__name__ = f'Missing{name}'
18-
return _MissingProvider
19-
20-
2111
try: # pragma: no cover - optional dependency
2212
from .gemini import GeminiReranker # type: ignore[unused-import]
2313
except ImportError: # pragma: no cover
24-
GeminiReranker = _missing_provider('GeminiReranker', 'google-genai') # type: ignore[assignment]
14+
GeminiReranker = missing_provider('GeminiReranker', 'google-genai') # type: ignore[assignment]
2515
else: # pragma: no cover
2616
__all__.append('GeminiReranker')
2717

2818
try: # pragma: no cover - optional dependency
2919
from .bge import BGEReranker # type: ignore[unused-import]
3020
except ImportError: # pragma: no cover
31-
BGEReranker = _missing_provider('BGEReranker', 'sentence-transformers') # type: ignore[assignment]
21+
BGEReranker = missing_provider('BGEReranker', 'sentence-transformers') # type: ignore[assignment]
3222
else: # pragma: no cover
3323
__all__.append('BGEReranker')
Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,32 @@
11
from __future__ import annotations
22

3+
from graphium_core.utils import missing_provider
4+
35
from .openai import OpenAIEmbedder, OpenAIEmbedderConfig
46

57
__all__ = ['OpenAIEmbedder', 'OpenAIEmbedderConfig']
68

79

8-
def _missing_provider(name: str, dependency: str) -> type:
9-
class _MissingProvider: # pragma: no cover - shim for optional extras
10-
def __init__(self, *args, **kwargs):
11-
raise ImportError(
12-
f'{name} requires the optional dependency "{dependency}". '
13-
f'Install it via `uv sync --extra {dependency}` or the corresponding extras group.'
14-
)
15-
16-
_MissingProvider.__name__ = f'Missing{name}'
17-
return _MissingProvider
18-
19-
2010
try: # pragma: no cover - optional dependency
2111
from .gemini import GeminiEmbedder, GeminiEmbedderConfig # type: ignore[unused-import]
2212
except ImportError: # pragma: no cover
23-
GeminiEmbedder = _missing_provider('GeminiEmbedder', 'google-genai') # type: ignore[assignment]
24-
GeminiEmbedderConfig = _missing_provider('GeminiEmbedderConfig', 'google-genai') # type: ignore[assignment]
13+
GeminiEmbedder = missing_provider('GeminiEmbedder', 'google-genai') # type: ignore[assignment]
14+
GeminiEmbedderConfig = missing_provider('GeminiEmbedderConfig', 'google-genai') # type: ignore[assignment]
2515
else: # pragma: no cover
2616
__all__.extend(['GeminiEmbedder', 'GeminiEmbedderConfig'])
2717

2818
try: # pragma: no cover - optional dependency
2919
from .voyage import VoyageAIEmbedder, VoyageAIEmbedderConfig # type: ignore[unused-import]
3020
except ImportError: # pragma: no cover
31-
VoyageAIEmbedder = _missing_provider('VoyageAIEmbedder', 'voyageai') # type: ignore[assignment]
32-
VoyageAIEmbedderConfig = _missing_provider('VoyageAIEmbedderConfig', 'voyageai') # type: ignore[assignment]
21+
VoyageAIEmbedder = missing_provider('VoyageAIEmbedder', 'voyageai') # type: ignore[assignment]
22+
VoyageAIEmbedderConfig = missing_provider('VoyageAIEmbedderConfig', 'voyageai') # type: ignore[assignment]
3323
else: # pragma: no cover
3424
__all__.extend(['VoyageAIEmbedder', 'VoyageAIEmbedderConfig'])
3525

3626
try: # pragma: no cover - optional dependency
3727
from .embeddinggemma import EmbeddingGemmaConfig, EmbeddingGemmaEmbedder # type: ignore[unused-import]
3828
except ImportError: # pragma: no cover
39-
EmbeddingGemmaEmbedder = _missing_provider('EmbeddingGemmaEmbedder', 'sentence-transformers') # type: ignore[assignment]
40-
EmbeddingGemmaConfig = _missing_provider('EmbeddingGemmaConfig', 'sentence-transformers') # type: ignore[assignment]
29+
EmbeddingGemmaEmbedder = missing_provider('EmbeddingGemmaEmbedder', 'sentence-transformers') # type: ignore[assignment]
30+
EmbeddingGemmaConfig = missing_provider('EmbeddingGemmaConfig', 'sentence-transformers') # type: ignore[assignment]
4131
else: # pragma: no cover
4232
__all__.extend(['EmbeddingGemmaEmbedder', 'EmbeddingGemmaConfig'])

graphium_core/search/edges.py

Lines changed: 12 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
43
from collections.abc import Coroutine
54
from typing import Any
65

@@ -24,7 +23,7 @@
2423
node_distance_reranker,
2524
rrf,
2625
)
27-
from graphium_core.utils.async_utils import semaphore_gather
26+
from graphium_core.search.shared import gather_search_results
2827

2928

3029
def _collect_seed_bfs_nodes(search_results: list[list[EntityEdge]]) -> list[str]:
@@ -98,36 +97,6 @@ def _build_edge_search_tasks(
9897
return tasks, requires_seeded_bfs
9998

10099

101-
async def _gather_edge_search_results(
102-
driver: GraphDriver,
103-
tasks: list[Coroutine[Any, Any, list[EntityEdge]]],
104-
requires_seeded_bfs: bool,
105-
config: EdgeSearchConfig,
106-
search_filter: SearchFilters,
107-
group_ids: list[str] | None,
108-
limit: int,
109-
) -> list[list[EntityEdge]]:
110-
search_results = list(await semaphore_gather(*tasks)) if tasks else []
111-
112-
if not requires_seeded_bfs:
113-
return search_results
114-
115-
seed_node_uuids = _collect_seed_bfs_nodes(search_results)
116-
if not seed_node_uuids:
117-
return search_results
118-
119-
bfs_results = await edge_bfs_search(
120-
driver,
121-
seed_node_uuids,
122-
config.bfs_max_depth,
123-
search_filter,
124-
group_ids,
125-
2 * limit,
126-
)
127-
search_results.append(bfs_results)
128-
return search_results
129-
130-
131100
def _rerank_edges_with_rrf(
132101
search_result_uuids: list[list[str]],
133102
edge_uuid_map: dict[str, EntityEdge],
@@ -311,14 +280,18 @@ async def edge_search(
311280
limit,
312281
)
313282

314-
search_results = await _gather_edge_search_results(
315-
driver,
283+
search_results = await gather_search_results(
316284
search_tasks,
317285
requires_seeded_bfs,
318-
config,
319-
search_filter,
320-
group_ids,
321-
limit,
286+
_collect_seed_bfs_nodes,
287+
lambda seed_node_uuids: edge_bfs_search(
288+
driver,
289+
seed_node_uuids,
290+
config.bfs_max_depth,
291+
search_filter,
292+
group_ids,
293+
2 * limit,
294+
),
322295
)
323296

324297
reranked_edges, edge_scores = await _rerank_edges(
@@ -336,4 +309,4 @@ async def edge_search(
336309
return reranked_edges[:limit], edge_scores[:limit]
337310

338311

339-
__all__ = ['edge_search', '_build_edge_search_tasks', '_gather_edge_search_results']
312+
__all__ = ['edge_search', '_build_edge_search_tasks']

graphium_core/search/nodes.py

Lines changed: 11 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
node_similarity_search,
3333
rrf,
3434
)
35-
from graphium_core.utils.async_utils import semaphore_gather
35+
from graphium_core.search.shared import gather_search_results
3636

3737

3838
async def node_search(
@@ -61,14 +61,18 @@ async def node_search(
6161
limit,
6262
)
6363

64-
search_results = await _gather_node_search_results(
65-
driver,
64+
search_results = await gather_search_results(
6665
search_tasks,
6766
requires_seeded_bfs,
68-
config,
69-
search_filter,
70-
group_ids,
71-
limit,
67+
_collect_seed_nodes,
68+
lambda seed_nodes: node_bfs_search(
69+
driver,
70+
seed_nodes,
71+
search_filter,
72+
config.bfs_max_depth,
73+
group_ids,
74+
2 * limit,
75+
),
7276
)
7377

7478
reranked_nodes, node_scores = await _rerank_nodes(
@@ -195,36 +199,6 @@ def _build_node_search_tasks(
195199
return tasks, requires_seeded_bfs
196200

197201

198-
async def _gather_node_search_results(
199-
driver: GraphDriver,
200-
tasks: list[Coroutine[Any, Any, list[EntityNode]]],
201-
requires_seeded_bfs: bool,
202-
config: NodeSearchConfig,
203-
search_filter: SearchFilters,
204-
group_ids: list[str] | None,
205-
limit: int,
206-
) -> list[list[EntityNode]]:
207-
search_results = list(await semaphore_gather(*tasks)) if tasks else []
208-
209-
if not requires_seeded_bfs:
210-
return search_results
211-
212-
seed_nodes = _collect_seed_nodes(search_results)
213-
if not seed_nodes:
214-
return search_results
215-
216-
bfs_results = await node_bfs_search(
217-
driver,
218-
seed_nodes,
219-
search_filter,
220-
config.bfs_max_depth,
221-
group_ids,
222-
2 * limit,
223-
)
224-
search_results.append(bfs_results)
225-
return search_results
226-
227-
228202
def _collect_seed_nodes(search_results: list[list[EntityNode]]) -> list[str]:
229203
seen: set[str] = set()
230204
seeds: list[str] = []

graphium_core/search/shared.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from __future__ import annotations
2+
3+
from collections.abc import Awaitable, Callable, Coroutine, Sequence
4+
from typing import Any, TypeVar
5+
6+
from graphium_core.utils.async_utils import semaphore_gather
7+
8+
T = TypeVar('T')
9+
10+
11+
async def gather_search_results(
12+
tasks: Sequence[Coroutine[Any, Any, list[T]]],
13+
requires_seeded_bfs: bool,
14+
collect_seed_ids: Callable[[list[list[T]]], list[str]],
15+
bfs_fetch: Callable[[list[str]], Awaitable[list[T]]],
16+
) -> list[list[T]]:
17+
"""Execute search tasks and optionally augment results with a BFS fallback."""
18+
search_results = list(await semaphore_gather(*tasks)) if tasks else []
19+
20+
if not requires_seeded_bfs:
21+
return search_results
22+
23+
seed_ids = collect_seed_ids(search_results)
24+
if not seed_ids:
25+
return search_results
26+
27+
bfs_results = await bfs_fetch(seed_ids)
28+
search_results.append(bfs_results)
29+
return search_results
30+
31+
32+
__all__ = ['gather_search_results']

graphium_core/utils/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .optional_dependencies import missing_provider
2+
3+
__all__ = ['missing_provider']
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import Final
4+
5+
_DEFAULT_GUIDANCE: Final[str] = 'or the matching extras group.'
6+
7+
8+
def missing_provider(
9+
name: str,
10+
dependency: str,
11+
*,
12+
extras_group: str | None = None,
13+
guidance: str | None = None,
14+
) -> type:
15+
"""Return a shim type that raises a helpful ImportError when instantiated.
16+
17+
Both embedder and reranker packages use the same pattern for optional providers.
18+
This helper keeps the messaging consistent while avoiding copy/paste implementations.
19+
"""
20+
21+
extras_label = extras_group or dependency
22+
message = (
23+
f'{name} requires the optional dependency "{dependency}". '
24+
f'Install it via `uv sync --extra {extras_label}` {guidance or _DEFAULT_GUIDANCE}'
25+
)
26+
27+
class _MissingProvider: # pragma: no cover - shim for optional extras
28+
def __init__(self, *args, **kwargs):
29+
raise ImportError(message)
30+
31+
_MissingProvider.__name__ = f'Missing{name}'
32+
return _MissingProvider
33+
34+
35+
__all__ = ['missing_provider']

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ dev = [
7070
"pytest-xdist>=3.6.1",
7171
"ruff>=0.7.1",
7272
"opentelemetry-sdk>=1.20.0",
73+
"pylint>=4.0.1",
7374
]
7475

7576
[build-system]

0 commit comments

Comments
 (0)