Skip to content

Commit 2b03aa0

Browse files
authored
improve performance (#34)
* Improve fetching performance * Add verbose flag * Add option for enhanced performance * Update version
1 parent 6620030 commit 2b03aa0

File tree

4 files changed

+70
-46
lines changed

4 files changed

+70
-46
lines changed

src/bibx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
"read_wos",
2727
]
2828

29-
__version__ = "0.4.0"
29+
__version__ = "0.4.1"
3030

3131

3232
def query_openalex(

src/bibx/builders/openalex.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from collections import Counter
23
from enum import Enum
34
from typing import Optional
45
from urllib.parse import urlparse
@@ -11,11 +12,14 @@
1112

1213
logger = logging.getLogger(__name__)
1314

15+
MAX_REFERENCES = 400
16+
1417

1518
class HandleReferences(Enum):
1619
"""How to handle references when building an openalex collection."""
1720

1821
BASIC = "basic"
22+
COMMON = "common"
1923
FULL = "full"
2024

2125

@@ -39,14 +43,22 @@ def build(self) -> Collection:
3943
logger.info("building collection for query %s", self.query)
4044
works = self.client.list_recent_articles(self.query, self.limit)
4145
cache = {work.id: work for work in works}
46+
references: list[str] = []
47+
for work in works:
48+
references.extend(work.referenced_works)
49+
if self.references == HandleReferences.COMMON:
50+
counter = Counter(references)
51+
most_common = {key for key, _ in counter.most_common(MAX_REFERENCES)}
52+
missing = most_common - set(cache.keys())
53+
logger.info("fetching %d missing references", len(missing))
54+
missing_works = self.client.list_articles_by_openalex_id(list(missing))
55+
cache.update({work.id: work for work in missing_works})
4256
if self.references == HandleReferences.FULL:
43-
references: list[str] = []
44-
for work in works:
45-
references.extend(work.referenced_works)
4657
missing = set(references) - set(cache.keys())
4758
logger.info("fetching %d missing references", len(missing))
4859
missing_works = self.client.list_articles_by_openalex_id(list(missing))
4960
cache.update({work.id: work for work in missing_works})
61+
5062
article_cache = {
5163
openalexid: self._work_to_article(work)
5264
for openalexid, work in cache.items()

src/bibx/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
from collections.abc import Callable
23
from enum import Enum
34
from typing import TextIO
@@ -83,8 +84,14 @@ def openalex(
8384
help="how to handle references",
8485
default=HandleReferences.BASIC,
8586
),
87+
verbose: bool = typer.Option(
88+
help="be more verbose",
89+
default=False,
90+
),
8691
) -> None:
8792
"""Run the sap algorithm on a seed file of any supported format."""
93+
if verbose:
94+
logging.basicConfig(level=logging.INFO)
8895
c = query_openalex(" ".join(query), references=references)
8996
s = Sap()
9097
graph = s.create_graph(c)

src/bibx/clients/openalex.py

Lines changed: 47 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from concurrent.futures import ThreadPoolExecutor, as_completed, wait
23
from enum import Enum
34
from typing import Optional, Union
45

@@ -122,6 +123,18 @@ def __init__(
122123
}
123124
)
124125

126+
def _fetch_works(self, params: dict[str, Union[str, int]]) -> WorkResponse:
    """Perform one GET against the ``/works`` endpoint and validate the payload.

    :param params: query-string parameters forwarded verbatim to the API.
    :returns: the parsed :class:`WorkResponse`.
    :raises OpenAlexError: if the HTTP request fails, the response status is
        an error, or the JSON body does not validate against the model.
    """
    resp = self.session.get(f"{self.base_url}/works", params=params)
    try:
        resp.raise_for_status()
        # .json() may raise a requests-level decode error; model_validate a
        # pydantic ValidationError — both are wrapped in the package error.
        return WorkResponse.model_validate(resp.json())
    except (requests.RequestException, ValidationError) as error:
        raise OpenAlexError(str(error)) from error
125138
def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
126139
"""List recent articles from the openalex API."""
127140
select = ",".join(Work.model_fields.keys())
@@ -134,56 +147,48 @@ def list_recent_articles(self, query: str, limit: int = 600) -> list[Work]:
134147
)
135148
pages = (limit // MAX_WORKS_PER_PAGE) + 1
136149
results: list[Work] = []
137-
for page in range(1, pages + 1):
138-
logger.info("fetching page %d with filter %s", page, filter_)
139-
params: dict[str, Union[str, int]] = {
140-
"select": select,
141-
"filter": filter_,
142-
"sort": "publication_year:desc",
143-
"per_page": MAX_WORKS_PER_PAGE,
144-
"page": page,
145-
}
146-
response = self.session.get(
147-
f"{self.base_url}/works",
148-
params=params,
149-
)
150-
try:
151-
response.raise_for_status()
152-
data = response.json()
153-
work_response = WorkResponse.model_validate(data)
154-
logger.info(
155-
"fetched %d works in page %d", len(work_response.results), page
150+
with ThreadPoolExecutor(max_workers=min(pages, 25)) as executor:
151+
futures = [
152+
executor.submit(
153+
self._fetch_works,
154+
{
155+
"select": select,
156+
"filter": filter_,
157+
"sort": "publication_year:desc",
158+
"per_page": MAX_WORKS_PER_PAGE,
159+
"page": page,
160+
},
156161
)
162+
for page in range(1, pages + 1)
163+
]
164+
wait(futures)
165+
for future in futures:
166+
work_response = future.result()
157167
results.extend(work_response.results)
158-
if page * MAX_WORKS_PER_PAGE >= min(work_response.meta.count, limit):
168+
if len(results) >= limit:
159169
break
160-
except (requests.RequestException, ValidationError) as error:
161-
raise OpenAlexError(str(error)) from error
162170
return results[:limit]
163171

164172
def list_articles_by_openalex_id(self, ids: list[str]) -> list[Work]:
    """List articles by openalex id."""
    # Request only the fields the Work model declares.
    select = ",".join(Work.model_fields.keys())
    results: list[Work] = []
    # Fan the id batches out over a small thread pool; each batch becomes a
    # single filtered /works request handled by _fetch_works.
    with ThreadPoolExecutor(max_workers=5) as executor:
        pending = [
            executor.submit(
                self._fetch_works,
                {
                    "select": select,
                    "filter": f"ids.openalex:{'|'.join(batch)},type:types/article",
                    "per_page": MAX_IDS_PER_REQUEST,
                },
            )
            for batch in chunks(ids, MAX_IDS_PER_REQUEST)
        ]
        # Collect responses as they finish; result() re-raises any
        # OpenAlexError from the worker thread.
        for done in as_completed(pending):
            work_response = done.result()
            logger.info(
                "got %s works from the openalex api", len(work_response.results)
            )
            results.extend(work_response.results)
    return results

0 commit comments

Comments
 (0)