Skip to content

Commit a33e7e4

Browse files
committed
fix(nlq2sparql): make tools package lazy; fix wikidata_tool imports and syntax for module execution
1 parent d272c9f commit a33e7e4

File tree

2 files changed

+85
-37
lines changed

2 files changed

+85
-37
lines changed

code/nlq2sparql/tools/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
"""Tool subpackage for NLQ2SPARQL (Wikidata resolution, etc.)."""
1+
"""Tool subpackage for NLQ2SPARQL (Wikidata resolution, etc.).
22
3-
from .wikidata_tool import find_entity_id, find_property_id
3+
Avoid importing heavy or side-effectful modules at package import time.
4+
Consumers should import needed symbols directly from their modules, e.g.:
5+
from code.nlq2sparql.tools.wikidata_tool import find_entity_id
6+
"""
47

5-
__all__ = ["find_entity_id", "find_property_id"]
8+
__all__: list[str] = []
Lines changed: 79 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,83 +1,128 @@
1-
"""Wikidata Tool Functions for NLQ2SPARQL LLM Integrations
1+
"""Wikidata Tool Functions for NLQ2SPARQL LLM Integrations.
22
3-
Adapter-backed client creation to avoid touching shared libs.
3+
Async helpers that delegate to a loop-safe Wikidata client via the
4+
local integrations adapter. Provides two primary functions:
5+
- find_entity_id(label) -> QID
6+
- find_property_id(label) -> PID
47
"""
58
from __future__ import annotations
6-
import asyncio, logging
9+
10+
import asyncio
11+
import logging
712
from functools import lru_cache
8-
from pathlib import Path
9-
from typing import Optional, Sequence
10-
from integrations.wikidata_adapter import (
11-
get_wikidata_client,
12-
close_wikidata_client,
13-
)
13+
from typing import Dict, Optional, Sequence
14+
15+
try:
16+
from ..integrations.wikidata_adapter import (
17+
get_wikidata_client,
18+
close_wikidata_client,
19+
)
20+
except Exception:
21+
# Fallback absolute import if package context differs (e.g., scripts)
22+
from code.nlq2sparql.integrations.wikidata_adapter import ( # type: ignore
23+
get_wikidata_client,
24+
close_wikidata_client,
25+
)
26+
1427
logger = logging.getLogger(__name__)
1528

29+
1630
async def _get_client():
1731
"""Get a loop-safe Wikidata client via the local adapter."""
1832
return await get_wikidata_client()
33+
34+
1935
async def _search_entities_precise(term: str, entity_type: str, limit: int = 1):
2036
client = await _get_client()
2137
try:
2238
return await client.wbsearchentities(term, entity_type=entity_type, limit=limit)
2339
except Exception as e:
24-
logger.error("wbsearchentities failed for '%s': %s", term, e); return []
40+
logger.error("wbsearchentities failed for '%s': %s", term, e)
41+
return []
42+
43+
2544
async def _search_entities_fuzzy(term: str, entity_type: str, limit: int = 5):
26-
if entity_type != 'item': return []
45+
if entity_type != "item":
46+
return []
2747
client = await _get_client()
2848
try:
29-
return await client.search(term, limit=limit, entity_type='items')
49+
return await client.search(term, limit=limit, entity_type="items")
3050
except Exception as e:
31-
logger.error("Elastic search failed for '%s': %s", term, e); return []
51+
logger.error("Elastic search failed for '%s': %s", term, e)
52+
return []
53+
54+
3255
@lru_cache(maxsize=512)
33-
def _normalized_input(s: str) -> str: return ' '.join(s.strip().split())
34-
from typing import Dict
56+
def _normalized_input(s: str) -> str:
57+
return " ".join(s.strip().split())
58+
3559

3660
def _pick_best_candidate(term: str, candidates: Sequence[Dict]) -> Optional[str]:
37-
if not candidates: return None
61+
if not candidates:
62+
return None
3863
term_lower = term.lower().strip()
3964
norm = []
4065
for c in candidates:
41-
cid = c.get('id') or ''
42-
label = c.get('label') or c.get('snippet') or ''
43-
if cid and label: norm.append((cid,label))
44-
if not norm: return None
45-
for cid,label in norm:
46-
if label.lower()==term_lower: return cid
47-
for cid,label in norm:
48-
if label.lower().startswith(term_lower): return cid
66+
cid = c.get("id") or ""
67+
label = c.get("label") or c.get("snippet") or ""
68+
if cid and label:
69+
norm.append((cid, label))
70+
if not norm:
71+
return None
72+
for cid, label in norm:
73+
if label.lower() == term_lower:
74+
return cid
75+
for cid, label in norm:
76+
if label.lower().startswith(term_lower):
77+
return cid
4978
return norm[0][0]
79+
80+
5081
async def find_entity_id(entity_label: str) -> Optional[str]:
51-
if not entity_label or not entity_label.strip(): return None
82+
if not entity_label or not entity_label.strip():
83+
return None
5284
term = _normalized_input(entity_label)
5385
# test suite expects limit=1 call on wbsearchentities
54-
precise = await _search_entities_precise(term,'item',1)
86+
precise = await _search_entities_precise(term, "item", 1)
5587
qid = _pick_best_candidate(term, precise)
56-
if qid: return qid
57-
fuzzy = await _search_entities_fuzzy(term,'item',5)
88+
if qid:
89+
return qid
90+
fuzzy = await _search_entities_fuzzy(term, "item", 5)
5891
return _pick_best_candidate(term, fuzzy)
92+
93+
5994
async def find_property_id(property_label: str) -> Optional[str]:
60-
if not property_label or not property_label.strip(): return None
95+
if not property_label or not property_label.strip():
96+
return None
6197
term = _normalized_input(property_label)
62-
precise = await _search_entities_precise(term,'property',1)
98+
precise = await _search_entities_precise(term, "property", 1)
6399
return _pick_best_candidate(term, precise)
100+
101+
64102
async def _close_session():
65103
await close_wikidata_client()
104+
105+
66106
class WikidataTool:
67107
"""Lightweight OO wrapper kept for backward compatibility with tests.
68108
69109
Delegates to module-level async functions.
70110
"""
111+
71112
async def find_entity_id(self, label: str): # pragma: no cover simple delegation
72113
return await find_entity_id(label)
73114

74115
async def find_property_id(self, label: str): # pragma: no cover
75116
return await find_property_id(label)
76117

77-
__all__ = ['find_entity_id','find_property_id','WikidataTool']
78-
if __name__ == '__main__':
118+
119+
__all__ = ["find_entity_id", "find_property_id", "WikidataTool"]
120+
121+
122+
if __name__ == "__main__":
79123
async def _demo():
80-
print('Entity Bach ->', await find_entity_id('Bach'))
81-
print('Property composer ->', await find_property_id('composer'))
124+
print("Entity Bach ->", await find_entity_id("Bach"))
125+
print("Property composer ->", await find_property_id("composer"))
82126
await _close_session()
127+
83128
asyncio.run(_demo())

0 commit comments

Comments
 (0)