Skip to content

Commit 2dfb44c

Browse files
authored
Use headless search endpoint for Paper Finder (#14)
1 parent c643767 commit 2dfb44c

File tree

12 files changed

+260
-269
lines changed

12 files changed

+260
-269
lines changed

.claude-plugin/marketplace.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
"name": "asta",
1212
"source": "./",
1313
"description": "Paper search, citations, literature reports, and Semantic Scholar API tools",
14-
"version": "0.4.0",
14+
"version": "0.5.0",
1515
"author": {
1616
"name": "AI2 Asta Team"
1717
},

.claude-plugin/plugin.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "asta",
3-
"version": "0.4.0",
3+
"version": "0.5.0",
44
"description": "Asta science tools for Claude Code - paper search, citations, and more",
55
"author": {
66
"name": "AI2 Asta Team"

DEVELOPER.md

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -220,14 +220,17 @@ from asta.core import AstaPaperFinder
220220

221221
client = AstaPaperFinder()
222222

223-
# Simple blocking search
223+
# Simple synchronous search using headless endpoint
224224
result = client.find_papers("query", timeout=300)
225-
# Returns: {widget_id, file_path, paper_count}
226-
227-
# Non-blocking start
228-
thread_id = client.start_search("query")
229-
widget_id = client.get_widget_id(thread_id)
230-
results = client.poll_for_results(widget_id, timeout=300)
225+
# Returns: {query, widget, status, timestamp, paper_count}
226+
227+
# With operation mode control
228+
result = client.find_papers(
229+
"query",
230+
timeout=300,
231+
operation_mode="fast", # "infer", "fast", or "diligent"
232+
include_full_metadata=True
233+
)
231234
```
232235

233236
### SemanticScholarClient

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "asta"
7-
version = "0.4.0"
7+
version = "0.5.0"
88
description = "Asta CLI for scientific literature review"
99
readme = "README.md"
1010
requires-python = ">=3.11"

src/asta/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
"""Asta - Science literature research tools"""
22

3-
__version__ = "0.4.0"
3+
__version__ = "0.5.0"

src/asta/core/client.py

Lines changed: 44 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -7,149 +7,86 @@
77
import time
88
import urllib.error
99
import urllib.request
10-
import uuid
1110
from pathlib import Path
1211
from typing import Any
1312

1413

1514
class AstaPaperFinder:
16-
"""Client for Asta Paper Finder API"""
15+
"""Client for Asta Paper Finder API using headless endpoint"""
1716

18-
def __init__(self, base_url: str = "REDACTED_ASTA_PROD_URL"):
17+
def __init__(self, base_url: str = "REDACTED_MABOOL_WORKERS_URL"):
1918
self.base_url = base_url
20-
self.mabool_url = "REDACTED_MABOOL_DEMO_URL"
21-
self.user_id = str(uuid.uuid4())
2219
self.headers = {
23-
"X-Anonymous-User-ID": self.user_id,
2420
"Content-Type": "application/json",
2521
}
2622

2723
def _request(
2824
self, url: str, method: str = "GET", data: dict | None = None
29-
) -> dict[str, Any] | list:
25+
) -> dict[str, Any]:
3026
"""Make an HTTP request and return JSON response"""
3127
body = json.dumps(data).encode() if data else None
3228
req = urllib.request.Request(
3329
url, data=body, headers=self.headers, method=method
3430
)
35-
response = urllib.request.urlopen(req)
36-
return json.loads(response.read())
37-
38-
def create_thread(self) -> str:
39-
"""Create a new thread"""
40-
result = self._request(f"{self.base_url}/api/chat/thread", method="PUT")
41-
return result["thread"]["key"]
42-
43-
def send_message(
44-
self, text: str, thread_id: str, profile: str = "paper-finder-only"
45-
) -> dict[str, Any]:
46-
"""Send a message to the thread"""
47-
return self._request(
48-
f"{self.base_url}/api/chat/message",
49-
method="POST",
50-
data={"text": text, "thread_id": thread_id, "profile": profile},
51-
)
52-
53-
def get_widget_id(self, thread_id: str, max_retries: int = 20) -> str | None:
54-
"""Get the widget ID from thread events"""
55-
url = f"{self.base_url}/api/rest/thread/{thread_id}/event/widget_paper_finder"
56-
for _ in range(max_retries):
57-
try:
58-
req = urllib.request.Request(url, headers=self.headers)
59-
response = urllib.request.urlopen(req)
60-
data = json.loads(response.read())
61-
last_event = data.get("last_event")
62-
if last_event and isinstance(last_event, dict):
63-
event_data = last_event.get("data")
64-
if event_data and isinstance(event_data, dict):
65-
widget_id = event_data.get("id")
66-
if widget_id:
67-
return widget_id
68-
except urllib.error.HTTPError:
69-
pass
70-
time.sleep(2)
71-
return None
72-
73-
def get_widget_results(self, widget_id: str) -> dict[str, Any] | list:
74-
"""Get widget results from mabool service"""
75-
url = f"{self.mabool_url}/api/2/rounds/{widget_id}/result/widget"
76-
req = urllib.request.Request(url, headers=self.headers)
77-
response = urllib.request.urlopen(req)
78-
return json.loads(response.read())
79-
80-
def poll_for_results(self, widget_id: str, timeout: int = 300):
81-
"""Poll for results until completion or timeout"""
82-
start_time = time.time()
83-
poll_interval = 2
84-
85-
while time.time() - start_time < timeout:
31+
try:
32+
response = urllib.request.urlopen(req)
33+
return json.loads(response.read())
34+
except urllib.error.HTTPError as e:
35+
error_body = e.read().decode("utf-8")
8636
try:
87-
result = self.get_widget_results(widget_id)
88-
89-
# Handle if result is a list - got the papers directly
90-
if isinstance(result, list):
91-
return {
92-
"roundStatus": {"kind": "completed"},
93-
"results": result,
94-
"thread_id": None,
95-
"widget_id": widget_id,
96-
}
97-
98-
# Handle dict response with roundStatus
99-
status = result.get("roundStatus", {}).get("kind", "unknown")
100-
101-
if status == "completed":
102-
return result
103-
elif status == "failed":
104-
error = result.get("roundStatus", {}).get("error", "Unknown error")
105-
raise Exception(f"Paper finder failed: {error}")
106-
107-
except urllib.error.HTTPError as e:
108-
if e.code != 404:
109-
raise
110-
111-
time.sleep(poll_interval)
112-
113-
raise TimeoutError(f"Timeout after {timeout} seconds")
114-
115-
def start_search(self, query: str) -> str:
116-
"""Start a paper search and return thread_id immediately (non-blocking)"""
117-
thread_id = self.create_thread()
118-
self.send_message(query, thread_id)
119-
return thread_id
37+
error_data = json.loads(error_body)
38+
error_msg = error_data.get("detail", str(e))
39+
except json.JSONDecodeError:
40+
error_msg = error_body or str(e)
41+
raise Exception(f"API request failed: {error_msg}") from e
12042

12143
def find_papers(
122-
self, query: str, timeout: int = 300, save_to_file: Path | None = None
44+
self,
45+
query: str,
46+
timeout: int = 300,
47+
save_to_file: Path | None = None,
48+
operation_mode: str = "infer",
49+
include_full_metadata: bool = True,
12350
) -> dict[str, Any]:
124-
"""Complete workflow to find papers (blocking).
51+
"""Execute a one-shot paper search using the headless endpoint.
12552
12653
Args:
12754
query: Search query
128-
timeout: Maximum time to wait for results
55+
timeout: Maximum time to wait for results (seconds)
12956
save_to_file: Optional path to save results. If None, no file is saved.
57+
operation_mode: Search strategy - 'infer', 'fast', or 'diligent' (default: 'infer')
58+
include_full_metadata: Whether to return full paper details (default: True)
13059
13160
Returns:
132-
Complete search results including widget data
61+
Complete search results with papers
13362
"""
134-
thread_id = self.start_search(query)
63+
url = f"{self.base_url}/api/3/headless/paper-search"
64+
65+
request_body = {
66+
"query": query,
67+
"operation_mode": operation_mode,
68+
"include_full_metadata": include_full_metadata,
69+
"timeout_seconds": timeout,
70+
}
13571

136-
# Get widget ID
137-
widget_id = self.get_widget_id(thread_id)
138-
if not widget_id:
139-
raise Exception("Failed to get widget ID after retries")
72+
# Make the synchronous request
73+
result = self._request(url, method="POST", data=request_body)
14074

141-
# Poll for results
142-
widget_result = self.poll_for_results(widget_id, timeout)
75+
# Check for errors
76+
if "error" in result and result["error"]:
77+
error = result["error"]
78+
raise Exception(f"Paper search failed: {error}")
14379

144-
papers = widget_result.get("results", [])
80+
papers = result.get("papers", [])
14581

146-
# Build complete search data
82+
# Build search data in format compatible with existing models
14783
search_data = {
14884
"query": query,
149-
"thread_id": thread_id,
150-
"widget_id": widget_id,
85+
"widget": {
86+
"results": papers,
87+
"response_text": result.get("response_text", ""),
88+
},
15189
"status": "completed",
152-
"widget": widget_result,
15390
"timestamp": time.time(),
15491
"paper_count": len(papers),
15592
}

src/asta/literature/find.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@
1919
default=300,
2020
help="Maximum time to wait for results (seconds)",
2121
)
22-
def find(query: str, timeout: int):
22+
@click.option(
23+
"--mode",
24+
type=click.Choice(["infer", "fast", "diligent"]),
25+
default="infer",
26+
help="Search strategy: infer (auto-detect), fast (quick results), or diligent (comprehensive)",
27+
)
28+
def find(query: str, timeout: int, mode: str):
2329
"""Find papers matching QUERY using Asta Paper Finder.
2430
2531
Saves results to .asta/literature/find/ with an auto-generated filename.
@@ -31,10 +37,18 @@ def find(query: str, timeout: int):
3137
3238
# With custom timeout
3339
asta literature find "transformers" --timeout 60
40+
41+
# Use fast mode for quick results
42+
asta literature find "deep learning" --mode fast
43+
44+
# Use diligent mode for comprehensive search
45+
asta literature find "neural networks" --mode diligent
3446
"""
3547
try:
3648
client = AstaPaperFinder()
37-
raw_result = client.find_papers(query, timeout=timeout, save_to_file=None)
49+
raw_result = client.find_papers(
50+
query, timeout=timeout, save_to_file=None, operation_mode=mode
51+
)
3852

3953
# Transform to literature search result format
4054
literature_result = LiteratureSearchResult(

src/asta/literature/models.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import Any
44

5-
from pydantic import BaseModel, ConfigDict, Field
5+
from pydantic import BaseModel, ConfigDict, Field, field_validator
66

77

88
class Author(BaseModel):
@@ -62,27 +62,53 @@ class CitationContext(BaseModel):
6262
class Paper(BaseModel):
6363
"""Paper search result with relevance judgements"""
6464

65-
corpusId: int
65+
model_config = ConfigDict(populate_by_name=True)
66+
67+
# Use validation_alias to accept snake_case from API
68+
corpusId: int = Field(validation_alias="corpus_id")
6669
title: str
6770
abstract: str | None = None
6871
year: int | None = None
6972
authors: list[Author] = Field(default_factory=list)
7073
venue: str | None = None
7174
journal: dict[str, Any] | None = None
7275
url: str | None = None
73-
publicationDate: str | None = None
74-
citationCount: int | None = None
76+
publicationDate: str | None = Field(
77+
default=None, validation_alias="publication_date"
78+
)
79+
citationCount: int | None = Field(default=None, validation_alias="citation_count")
7580
categories: list[str] = Field(default_factory=list)
7681

7782
# Asta Paper Finder specific fields
78-
relevanceScore: float
79-
relevanceJudgement: RelevanceJudgement | None = None
83+
relevanceScore: float = Field(validation_alias="relevance_score")
84+
relevanceJudgement: RelevanceJudgement | None = Field(
85+
default=None, validation_alias="relevance_judgement"
86+
)
8087
snippets: list[Snippet] = Field(default_factory=list)
81-
citationContexts: list[CitationContext] = Field(default_factory=list)
88+
citationContexts: list[CitationContext] = Field(
89+
default_factory=list, validation_alias="citation_contexts"
90+
)
8291

8392
# Legal/filtering fields
84-
legalToShow: bool = True
85-
numOfOmittedCitationContextsDueLegal: int = 0
93+
legalToShow: bool = Field(default=True, validation_alias="legal_to_show")
94+
numOfOmittedCitationContextsDueLegal: int = Field(
95+
default=0, validation_alias="num_of_omitted_citation_contexts_due_legal"
96+
)
97+
98+
@field_validator("authors", mode="before")
99+
@classmethod
100+
def convert_author_strings(cls, v):
101+
"""Convert author strings to Author objects if needed."""
102+
if not isinstance(v, list):
103+
return v
104+
result = []
105+
for author in v:
106+
if isinstance(author, str):
107+
# Convert string to Author dict
108+
result.append({"name": author, "id": ""})
109+
else:
110+
result.append(author)
111+
return result
86112

87113

88114
class LiteratureSearchResult(BaseModel):

tests/test_cli.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,40 @@ def test_find_custom_timeout(self, runner):
140140

141141
assert result.exit_code == 0
142142
mock_instance.find_papers.assert_called_once_with(
143-
"test query", timeout=60, save_to_file=None
143+
"test query", timeout=60, save_to_file=None, operation_mode="infer"
144+
)
145+
146+
def test_find_with_mode_option(self, runner):
147+
"""Test find command with different operation modes."""
148+
mock_result = {
149+
"query": "test query",
150+
"status": "completed",
151+
"paper_count": 1,
152+
"widget": {
153+
"results": [
154+
{
155+
"corpusId": 123,
156+
"title": "Test Paper",
157+
"relevanceScore": 0.9,
158+
"authors": [],
159+
}
160+
]
161+
},
162+
}
163+
164+
with patch("asta.literature.find.AstaPaperFinder") as MockFinder:
165+
mock_instance = MagicMock()
166+
mock_instance.find_papers.return_value = mock_result
167+
MockFinder.return_value = mock_instance
168+
169+
# Test fast mode
170+
result = runner.invoke(
171+
cli, ["literature", "find", "test query", "--mode", "fast"]
172+
)
173+
174+
assert result.exit_code == 0
175+
mock_instance.find_papers.assert_called_with(
176+
"test query", timeout=300, save_to_file=None, operation_mode="fast"
144177
)
145178

146179

0 commit comments

Comments
 (0)