Skip to content

Commit c4c0275

Browse files
Merge pull request #90 from open-sciencelab/refactor/refactor-search
Refactor/refactor search
2 parents b1e7ef8 + ccaa726 commit c4c0275

File tree

4 files changed

+118
-36
lines changed

4 files changed

+118
-36
lines changed
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
pipeline:
22
- name: read
33
params:
4-
input_file: resources/input_examples/search_demo.jsonl # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
4+
input_file: resources/input_examples/search_demo.json # input file path, support json, jsonl, txt, pdf. See resources/input_examples for examples
55

66
- name: search
77
params:
88
data_sources: [uniprot] # data source for searcher, support: wikipedia, google, uniprot
9+
uniprot_params:
10+
use_local_blast: true # whether to use local blast for uniprot search
11+
local_blast_db: /your_path/uniprot_sprot

graphgen/graphgen.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
self.working_dir, namespace="graph"
6969
)
7070
self.search_storage: JsonKVStorage = JsonKVStorage(
71-
self.working_dir, namespace="searcher"
71+
self.working_dir, namespace="search"
7272
)
7373
self.rephrase_storage: JsonKVStorage = JsonKVStorage(
7474
self.working_dir, namespace="rephrase"
@@ -190,7 +190,7 @@ async def search(self, search_config: Dict):
190190
return
191191
search_results = await search_all(
192192
seed_data=seeds,
193-
**search_config,
193+
search_config=search_config,
194194
)
195195

196196
_add_search_keys = await self.search_storage.filter_keys(

graphgen/models/searcher/db/uniprot_searcher.py

Lines changed: 106 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1+
import asyncio
2+
import os
13
import re
4+
import subprocess
5+
import tempfile
6+
from concurrent.futures import ThreadPoolExecutor
7+
from functools import lru_cache
28
from io import StringIO
39
from typing import Dict, Optional
410

@@ -16,6 +22,11 @@
1622
from graphgen.utils import logger
1723

1824

25+
@lru_cache(maxsize=None)
26+
def _get_pool():
27+
return ThreadPoolExecutor(max_workers=10)
28+
29+
1930
class UniProtSearch(BaseSearcher):
2031
"""
2132
UniProt Search client to searcher with UniProt.
@@ -24,6 +35,14 @@ class UniProtSearch(BaseSearcher):
2435
3) Search with FASTA sequence (BLAST searcher).
2536
"""
2637

38+
def __init__(self, use_local_blast: bool = False, local_blast_db: str = "sp_db"):
39+
super().__init__()
40+
self.use_local_blast = use_local_blast
41+
self.local_blast_db = local_blast_db
42+
if self.use_local_blast and not os.path.isfile(f"{self.local_blast_db}.phr"):
43+
logger.error("Local BLAST database files not found. Please check the path.")
44+
self.use_local_blast = False
45+
2746
def get_by_accession(self, accession: str) -> Optional[dict]:
2847
try:
2948
handle = ExPASy.get_sprot_raw(accession)
@@ -101,38 +120,86 @@ def get_by_fasta(self, fasta_sequence: str, threshold: float) -> Optional[Dict]:
101120
logger.error("Empty FASTA sequence provided.")
102121
return None
103122

104-
# UniProtKB/Swiss-Prot BLAST API
105-
try:
106-
logger.debug("Performing BLAST searcher for the given sequence: %s", seq)
107-
result_handle = NCBIWWW.qblast(
108-
program="blastp",
109-
database="swissprot",
110-
sequence=seq,
111-
hitlist_size=1,
112-
expect=threshold,
113-
)
114-
blast_record = NCBIXML.read(result_handle)
115-
except RequestException:
116-
raise
117-
except Exception as e: # pylint: disable=broad-except
118-
logger.error("BLAST searcher failed: %s", e)
119-
return None
123+
accession = None
124+
if self.use_local_blast:
125+
accession = self._local_blast(seq, threshold)
126+
if accession:
127+
logger.debug("Local BLAST found accession: %s", accession)
128+
129+
if not accession:
130+
logger.debug("Falling back to NCBIWWW.qblast.")
131+
132+
# UniProtKB/Swiss-Prot BLAST API
133+
try:
134+
logger.debug(
135+
"Performing BLAST searcher for the given sequence: %s", seq
136+
)
137+
result_handle = NCBIWWW.qblast(
138+
program="blastp",
139+
database="swissprot",
140+
sequence=seq,
141+
hitlist_size=1,
142+
expect=threshold,
143+
)
144+
blast_record = NCBIXML.read(result_handle)
145+
except RequestException:
146+
raise
147+
except Exception as e: # pylint: disable=broad-except
148+
logger.error("BLAST searcher failed: %s", e)
149+
return None
120150

121-
if not blast_record.alignments:
122-
logger.info("No BLAST hits found for the given sequence.")
123-
return None
151+
if not blast_record.alignments:
152+
logger.info("No BLAST hits found for the given sequence.")
153+
return None
124154

125-
best_alignment = blast_record.alignments[0]
126-
best_hsp = best_alignment.hsps[0]
127-
if best_hsp.expect > threshold:
128-
logger.info("No BLAST hits below the threshold E-value.")
129-
return None
130-
hit_id = best_alignment.hit_id
155+
best_alignment = blast_record.alignments[0]
156+
best_hsp = best_alignment.hsps[0]
157+
if best_hsp.expect > threshold:
158+
logger.info("No BLAST hits below the threshold E-value.")
159+
return None
160+
hit_id = best_alignment.hit_id
131161

132-
# like sp|P01308.1|INS_HUMAN
133-
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
162+
# like sp|P01308.1|INS_HUMAN
163+
accession = hit_id.split("|")[1].split(".")[0] if "|" in hit_id else hit_id
134164
return self.get_by_accession(accession)
135165

166+
def _local_blast(self, seq: str, threshold: float) -> Optional[str]:
167+
"""
168+
Perform local BLAST search using local BLAST database.
169+
:param seq: The protein sequence.
170+
:param threshold: E-value threshold for BLAST searcher.
171+
:return: The accession number of the best hit or None if not found.
172+
"""
173+
try:
174+
with tempfile.NamedTemporaryFile(
175+
mode="w+", suffix=".fa", delete=False
176+
) as tmp:
177+
tmp.write(f">query\n{seq}\n")
178+
tmp_name = tmp.name
179+
180+
cmd = [
181+
"blastp",
182+
"-db",
183+
self.local_blast_db,
184+
"-query",
185+
tmp_name,
186+
"-evalue",
187+
str(threshold),
188+
"-max_target_seqs",
189+
"1",
190+
"-outfmt",
191+
"6 sacc", # only return accession
192+
]
193+
logger.debug("Running local blastp: %s", " ".join(cmd))
194+
out = subprocess.check_output(cmd, text=True).strip()
195+
os.remove(tmp_name)
196+
if out:
197+
return out.split("\n", maxsplit=1)[0]
198+
return None
199+
except Exception as exc: # pylint: disable=broad-except
200+
logger.error("Local blastp failed: %s", exc)
201+
return None
202+
136203
@retry(
137204
stop=stop_after_attempt(5),
138205
wait=wait_exponential(multiplier=1, min=4, max=10),
@@ -156,20 +223,29 @@ async def search(
156223
query = query.strip()
157224

158225
logger.debug("UniProt searcher query: %s", query)
226+
227+
loop = asyncio.get_running_loop()
228+
159229
# check if fasta sequence
160230
if query.startswith(">") or re.fullmatch(
161231
r"[ACDEFGHIKLMNPQRSTVWY\s]+", query, re.I
162232
):
163-
result = self.get_by_fasta(query, threshold)
233+
coro = loop.run_in_executor(
234+
_get_pool(), self.get_by_fasta, query, threshold
235+
)
164236

165237
# check if accession number
166238
elif re.fullmatch(r"[A-NR-Z0-9]{6,10}", query, re.I):
167-
result = self.get_by_accession(query)
239+
coro = loop.run_in_executor(_get_pool(), self.get_by_accession, query)
168240

169241
else:
170242
# otherwise treat as keyword
171-
result = self.get_best_hit(query)
243+
coro = loop.run_in_executor(_get_pool(), self.get_best_hit, query)
172244

245+
result = await coro
173246
if result:
174247
result["_search_query"] = query
175248
return result
249+
250+
251+
# TODO: use local UniProt database for large-scale searchs

graphgen/operators/search/search_all.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,25 @@
1414

1515
async def search_all(
1616
seed_data: dict,
17-
data_sources: list[str],
17+
search_config: dict,
1818
) -> dict:
1919
"""
2020
Perform searches across multiple search types and aggregate the results.
2121
:param seed_data: A dictionary containing seed data with entity names.
22-
:param data_sources: A list of search types to perform (e.g., "wikipedia", "google", "bing", "uniprot").
22+
:param search_config: A dictionary specifying which data sources to use for searching.
2323
:return: A dictionary with
2424
"""
2525

2626
results = {}
27+
data_sources = search_config.get("data_sources", [])
2728

2829
for data_source in data_sources:
2930
if data_source == "uniprot":
3031
from graphgen.models import UniProtSearch
3132

32-
uniprot_search_client = UniProtSearch()
33+
uniprot_search_client = UniProtSearch(
34+
**search_config.get("uniprot_params", {})
35+
)
3336

3437
data = list(seed_data.values())
3538
data = [d["content"] for d in data if "content" in d]

0 commit comments

Comments
 (0)