Commit b91405d

refactor sinequa_api wrapper, test suites, and full_text import
1 parent d7620ad commit b91405d

4 files changed: +269 additions, -229 deletions

sde_collections/sinequa_api.py

Lines changed: 106 additions & 73 deletions

@@ -1,12 +1,10 @@
 import json
+from collections.abc import Iterator
 from typing import Any
 
 import requests
 import urllib3
 from django.conf import settings
-from django.db import transaction
-
-from .models.delta_url import DumpUrl
 
 urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
 
@@ -138,85 +136,99 @@ def query(self, page: int, collection_config_folder: str | None = None, source:
 
         return self.process_response(url, payload)
 
-    def sql_query(self, sql: str, collection) -> Any:
+    def _execute_sql_query(self, sql: str) -> dict:
+        """
+        Executes a SQL query against the Sinequa API.
+
+        Args:
+            sql (str): The SQL query to execute
+
+        Returns:
+            dict: The JSON response from the API containing 'Rows' and 'TotalRowCount'
+
+        Raises:
+            ValueError: If no token is available for authentication
+        """
         token = self._get_token()
         if not token:
             raise ValueError("Authentication error: Token is required for SQL endpoint access")
 
-        page = 0
-        page_size = 5000  # Number of records per page
-        skip_records = 0
+        url = f"{self.base_url}/api/v1/engine.sql"
+        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
+        raw_payload = json.dumps(
+            {
+                "method": "engine.sql",
+                "sql": sql,
+                "pretty": True,
+            }
+        )
 
-        while True:
-            paginated_sql = f"{sql} SKIP {skip_records} COUNT {page_size}"
-            url = f"{self.base_url}/api/v1/engine.sql"
-            headers = {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
-            raw_payload = json.dumps(
-                {
-                    "method": "engine.sql",
-                    "sql": paginated_sql,
-                    "pretty": True,
-                }
-            )
+        return self.process_response(url, headers=headers, raw_data=raw_payload)
 
-            response = self.process_response(url, headers=headers, raw_data=raw_payload)
-            batch_data = response.get("Rows", [])
-            total_row_count = response.get("TotalRowCount", 0)
-            processed_response = self._process_full_text_response(response)
-            self.process_and_update_data(processed_response, collection)
-            print(f"Batch {page + 1} has been processed and updated")
+    def _process_rows_to_records(self, rows: list) -> list[dict]:
+        """
+        Converts raw SQL row data into structured record dictionaries.
 
-            # Check if all rows have been fetched
-            if len(batch_data) == 0 or (skip_records + page_size) >= total_row_count:
-                break
+        Args:
+            rows (list): List of rows, where each row is [url, full_text, title]
 
-            page += 1
-            skip_records += page_size
-
-        return f"All {total_row_count} records have been processed and updated."
-
-    def process_and_update_data(self, batch_data, collection):
-        for record in batch_data:
-            try:
-                with transaction.atomic():
-                    url = record["url"]
-                    scraped_text = record.get("full_text", "")
-                    scraped_title = record.get("title", "")
-                    DumpUrl.objects.update_or_create(
-                        url=url,
-                        defaults={
-                            "scraped_text": scraped_text,
-                            "scraped_title": scraped_title,
-                            "collection": collection,
-                        },
-                    )
-            except KeyError as e:
-                print(f"Missing key in data: {str(e)}")
-            except Exception as e:
-                print(f"Error processing record: {str(e)}")
-
-    def get_full_texts(self, collection_config_folder: str, source: str = None, collection=None) -> Any:
+        Returns:
+            list[dict]: List of processed records with url, full_text, and title keys
+
+        Raises:
+            ValueError: If any row doesn't contain exactly 3 elements
+        """
+        processed_records = []
+        for idx, row in enumerate(rows):
+            if len(row) != 3:
+                raise ValueError(
+                    f"Invalid row format at index {idx}: Expected exactly three elements (url, full_text, title). "
+                    f"Received {len(row)} elements."
+                )
+            processed_records.append({"url": row[0], "full_text": row[1], "title": row[2]})
+        return processed_records
+
+    def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
         """
-        Retrieves the full texts, URLs, and titles for a specified collection.
+        Retrieves and yields batches of text records from the SQL database for a given collection.
+        Uses pagination to handle large datasets efficiently.
 
-        Returns:
-            dict: A JSON response containing the results of the SQL query,
-                  where each item has 'url', 'text', and 'title'.
-
-        Example:
-            Calling get_full_texts("example_collection") might return:
-            [
-                {
-                    'url': 'http://example.com/article1',
-                    'text': 'Here is the full text of the first article...',
-                    'title': 'Article One Title'
-                },
-                {
-                    'url': 'http://example.com/article2',
-                    'text': 'Here is the full text of the second article...',
-                    'title': 'Article Two Title'
-                }
-            ]
+        Args:
+            collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "SMD")
+            source (str, optional): The source to query. If None, defaults to "scrapers" for dev servers
+                or "SDE" for other servers.
+
+        Yields:
+            list[dict]: Batches of records, where each record is a dictionary containing:
+                {
+                    "url": str,        # The URL of the document
+                    "full_text": str,  # The full text content of the document
+                    "title": str       # The title of the document
+                }
+
+        Raises:
+            ValueError: If the server's index is not defined in its configuration
+
+        Example batch:
+            [
+                {
+                    "url": "https://example.nasa.gov/doc1",
+                    "full_text": "This is the content of doc1...",
+                    "title": "Document 1 Title"
+                },
+                {
+                    "url": "https://example.nasa.gov/doc2",
+                    "full_text": "This is the content of doc2...",
+                    "title": "Document 2 Title"
+                }
+            ]
+
+        Note:
+            - Results are paginated in batches of 5000 records
+            - Each batch is processed into clean dictionaries before being yielded
+            - The iterator will stop when either:
+                1. No more rows are returned from the query
+                2. The total count of records has been reached
         """
 
        if not source:
@@ -229,7 +241,28 @@ def get_full_texts(self, collection_config_folder: str, source: str = None, coll
             )
 
         sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
-        return self.sql_query(sql, collection)
+
+        page = 0
+        page_size = 5000
+        total_processed = 0
+
+        while True:
+            paginated_sql = f"{sql} SKIP {total_processed} COUNT {page_size}"
+            response = self._execute_sql_query(paginated_sql)
+
+            rows = response.get("Rows", [])
+            if not rows:  # Stop if we get an empty batch
+                break
+
+            yield self._process_rows_to_records(rows)
+
+            total_processed += len(rows)
+            total_count = response.get("TotalRowCount", 0)
+
+            if total_processed >= total_count:  # Stop if we've processed all records
+                break
+
+            page += 1
 
     @staticmethod
     def _process_full_text_response(batch_data: dict):
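
With this refactor, get_full_texts no longer writes to the database: it is a generator that yields one list of {url, full_text, title} dicts per 5000-row page fetched through engine.sql, with the pagination loop kept inside the wrapper. A minimal consumption sketch follows; the server name, collection folder, and printed fields are illustrative placeholders rather than values defined by this commit, while the Api class and its server_name constructor argument mirror how tasks.py uses it.

# Sketch only: consuming the refactored generator outside of the Celery task.
# "example_server" and "EARTHDATA" are placeholder arguments.
from sde_collections.sinequa_api import Api

api = Api("example_server")  # server_name, as constructed in tasks.py
for batch in api.get_full_texts("EARTHDATA"):
    # Each batch is a list of up to 5000 dicts with url/full_text/title keys.
    for record in batch:
        print(record["url"], record["title"], len(record["full_text"]))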

sde_collections/tasks.py

Lines changed: 34 additions & 16 deletions

@@ -7,6 +7,7 @@
 from django.conf import settings
 from django.core import management
 from django.core.management.commands import loaddata
+from django.db import transaction
 
 from config import celery_app
 
@@ -147,26 +148,43 @@ def resolve_title_pattern(title_pattern_id):
 @celery_app.task(soft_time_limit=600)
 def fetch_and_replace_full_text(collection_id, server_name):
     """
-    Task to initiate fetching and replacing full text and metadata for all URLs associated with a specified collection
-    from a given server.
-    Args:
-        collection_id (int): The identifier for the collection in the database.
-        server_name (str): The name of the server.
-
-    Returns:
-        str: A message indicating the result of the operation, including the number of URLs processed.
+    Task to fetch and replace full text and metadata for a collection.
+    Handles data in batches to manage memory usage.
     """
     collection = Collection.objects.get(id=collection_id)
     api = Api(server_name)
 
-    # Step 1: Delete all existing DumpUrl entries for the collection
+    # Step 1: Delete existing DumpUrl entries
     deleted_count, _ = DumpUrl.objects.filter(collection=collection).delete()
     print(f"Deleted {deleted_count} old records.")
 
-    # Step 2: Fetch and process new data
-    result_message = api.get_full_texts(collection.config_folder, collection=collection)
-
-    # Step 3: Migrate DumpUrl to DeltaUrl
-    collection.migrate_dump_to_delta()
-
-    return result_message
+    # Step 2: Process data in batches
+    total_processed = 0
+
+    try:
+        for batch in api.get_full_texts(collection.config_folder):
+            # Use bulk_create for efficiency, with a transaction per batch
+            with transaction.atomic():
+                DumpUrl.objects.bulk_create(
+                    [
+                        DumpUrl(
+                            url=record["url"],
+                            collection=collection,
+                            scraped_text=record["full_text"],
+                            scraped_title=record["title"],
+                        )
+                        for record in batch
+                    ]
+                )
+
+            total_processed += len(batch)
+            print(f"Processed batch of {len(batch)} records. Total: {total_processed}")
+
+        # Step 3: Migrate dump URLs to delta URLs
+        collection.migrate_dump_to_delta()
+
+        return f"Successfully processed {total_processed} records and updated the database."
+
+    except Exception as e:
+        print(f"Error processing records: {str(e)}")
+        raise
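
Wrapping each bulk_create in its own transaction.atomic() block bounds the work per commit: only one 5000-record batch is held in memory at a time, and a failure rolls back just the current batch while earlier batches stay committed before the exception is re-raised. As a minimal sketch (the collection id and server name below are placeholders), the task would be queued through Celery's standard delay call:

# Placeholder arguments; delay() enqueues the task on the Celery worker.
from sde_collections.tasks import fetch_and_replace_full_text

fetch_and_replace_full_text.delay(collection_id=123, server_name="dev")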
