Commit 2b811b6

add automatic batch size reduction to sinequa_api
1 parent 1b71c2d commit 2b811b6
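
In effect, get_full_texts becomes a fault-tolerant generator: it pages through SQL results with SKIP/COUNT and halves the COUNT whenever a query fails. A minimal sketch of how a caller might consume it — the Api constructor arguments here are hypothetical, since this diff does not show them:

    from sde_collections.sinequa_api import Api

    # Hypothetical construction; the real Api() arguments are not shown in this commit.
    api = Api(server_name="dev_server")

    # Each yielded batch is a list of {"url", "full_text", "title"} dicts.
    # On a failed query, the same offset is retried with a halved COUNT.
    for batch in api.get_full_texts("EARTHDATA", batch_size=500, min_batch_size=1):
        for record in batch:
            print(record["url"], record["title"])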

File tree

2 files changed: +110 −34 lines changed

sde_collections/sinequa_api.py

Lines changed: 45 additions & 34 deletions
@@ -188,15 +188,26 @@ def _process_rows_to_records(self, rows: list) -> list[dict]:
             processed_records.append({"url": row[0], "full_text": row[1], "title": row[2]})
         return processed_records
 
-    def get_full_texts(self, collection_config_folder: str, source: str = None) -> Iterator[dict]:
+    def get_full_texts(
+        self,
+        collection_config_folder: str,
+        source: str = None,
+        start_at: int = 0,
+        batch_size: int = 500,
+        min_batch_size: int = 1,
+    ) -> Iterator[dict]:
         """
         Retrieves and yields batches of text records from the SQL database for a given collection.
-        Uses pagination to handle large datasets efficiently.
+        Uses pagination to handle large datasets efficiently. If a query fails, it automatically
+        halves the batch size and retries until the query succeeds or min_batch_size is reached.
 
         Args:
-            collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "SMD")
+            collection_config_folder (str): The collection folder to query (e.g., "EARTHDATA", "CASEI")
             source (str, optional): The source to query. If None, defaults to "scrapers" for dev servers
                 or "SDE" for other servers.
+            start_at (int, optional): Starting offset for records. Defaults to 0.
+            batch_size (int, optional): Initial number of records per batch. Defaults to 500.
+            min_batch_size (int, optional): Minimum batch size before giving up. Defaults to 1.
 
         Yields:
             list[dict]: Batches of records, where each record is a dictionary containing:
@@ -208,29 +219,16 @@ def get_full_texts(self, collection_config_folder: str, source: str = None) -> I
 
         Raises:
             ValueError: If the server's index is not defined in its configuration
-
-        Example batch:
-            [
-                {
-                    "url": "https://example.nasa.gov/doc1",
-                    "full_text": "This is the content of doc1...",
-                    "title": "Document 1 Title"
-                },
-                {
-                    "url": "https://example.nasa.gov/doc2",
-                    "full_text": "This is the content of doc2...",
-                    "title": "Document 2 Title"
-                }
-            ]
+            ValueError: If the batch size reaches the minimum without success
 
         Note:
-            - Results are paginated in batches of 5000 records
+            - Results are paginated with adaptive batch sizing
             - Each batch is processed into clean dictionaries before being yielded
             - The iterator will stop when either:
                 1. No more rows are returned from the query
                 2. The total count of records has been reached
+            - Batch size is halved on each failure, down to min_batch_size
         """
-
         if not source:
             source = self._get_source_name()
 
@@ -240,29 +238,42 @@ def get_full_texts(self, collection_config_folder: str, source: str = None) -> I
                 "Please update server configuration with the required index."
             )
 
-        sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
+        base_sql = f"SELECT url1, text, title FROM {index} WHERE collection = '/{source}/{collection_config_folder}/'"
 
-        page = 0
-        page_size = 5000
-        total_processed = 0
+        current_offset = start_at
+        current_batch_size = batch_size
+        total_count = None
 
         while True:
-            paginated_sql = f"{sql} SKIP {total_processed} COUNT {page_size}"
-            response = self._execute_sql_query(paginated_sql)
+            sql = f"{base_sql} SKIP {current_offset} COUNT {current_batch_size}"
+
+            try:
+                response = self._execute_sql_query(sql)
+                rows = response.get("Rows", [])
+
+                if not rows:  # Stop if we get an empty batch
+                    break
+
+                if total_count is None:
+                    total_count = response.get("TotalRowCount", 0)
 
-            rows = response.get("Rows", [])
-            if not rows:  # Stop if we get an empty batch
-                break
+                yield self._process_rows_to_records(rows)
 
-            yield self._process_rows_to_records(rows)
+                current_offset += len(rows)
 
-            total_processed += len(rows)
-            total_count = response.get("TotalRowCount", 0)
+                if total_count and current_offset >= total_count:  # Stop if we've processed all records
+                    break
 
-            if total_processed >= total_count:  # Stop if we've processed all records
-                break
+            except (requests.RequestException, ValueError) as e:
+                if current_batch_size <= min_batch_size:
+                    raise ValueError(
+                        f"Failed to process batch even at minimum size {min_batch_size}. " f"Last error: {str(e)}"
+                    )
 
-            page += 1
+                # Halve the batch size and retry
+                current_batch_size = max(current_batch_size // 2, min_batch_size)
+                print(f"Reducing batch size to {current_batch_size} and retrying...")
+                continue
 
     @staticmethod
     def _process_full_text_response(batch_data: dict):

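A note on the retry policy before the tests: integer halving converges quickly, so even a large initial batch bottoms out after a handful of failures. A self-contained sketch of the progression (the helper name is invented for illustration):

    def batch_size_progression(batch_size: int, min_batch_size: int = 1) -> list[int]:
        """Sequence of COUNT values get_full_texts would try if every query failed."""
        sizes = [batch_size]
        while sizes[-1] > min_batch_size:
            sizes.append(max(sizes[-1] // 2, min_batch_size))
        return sizes

    print(batch_size_progression(500))  # [500, 250, 125, 62, 31, 15, 7, 3, 1] -- at most 9 attempts
    print(batch_size_progression(4))    # [4, 2, 1] -- the exact sequence exercised by the tests below
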
sde_collections/tests/api_tests.py

Lines changed: 65 additions & 0 deletions
@@ -2,6 +2,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+import requests
 from django.utils import timezone
 
 from sde_collections.models.collection import WorkflowStatusChoices
@@ -160,3 +161,67 @@ def test_query_dev_server_missing_credentials(self, mock_post, api_instance):
 
         with pytest.raises(ValueError, match="Authentication error: Missing credentials for dev server"):
             api_instance.query(page=1)
+
+    @patch("sde_collections.sinequa_api.Api._execute_sql_query")
+    def test_get_full_texts_batch_size_reduction(self, mock_execute_sql, api_instance):
+        """Test that batch size reduces appropriately on failure and continues processing."""
+        # Mock first query to fail, then succeed with smaller batch
+        mock_execute_sql.side_effect = [
+            requests.RequestException("Query too large"),  # First attempt fails
+            {
+                "Rows": [["http://example.com/1", "Text 1", "Title 1"]],
+                "TotalRowCount": 1,
+            },  # Succeeds with smaller batch
+        ]
+
+        batches = list(api_instance.get_full_texts("test_folder", batch_size=100, min_batch_size=1))
+
+        # Verify the batches were processed correctly after size reduction
+        assert len(batches) == 1
+        assert len(batches[0]) == 1
+        assert batches[0][0]["url"] == "http://example.com/1"
+
+        # Verify the calls made - first with original size, then with reduced size
+        assert mock_execute_sql.call_count == 2
+        first_call = mock_execute_sql.call_args_list[0][0][0]
+        second_call = mock_execute_sql.call_args_list[1][0][0]
+        assert "COUNT 100" in first_call
+        assert "COUNT 50" in second_call  # Should be halved from 100
+
+    @patch("sde_collections.sinequa_api.Api._execute_sql_query")
+    def test_get_full_texts_minimum_batch_size(self, mock_execute_sql, api_instance):
+        """Test behavior when reaching minimum batch size."""
+        mock_execute_sql.side_effect = requests.RequestException("Query failed")
+
+        # Start with batch_size=4, min_batch_size=1
+        # Should try: 4 -> 2 -> 1 -> raise error
+        with pytest.raises(ValueError, match="Failed to process batch even at minimum size 1"):
+            list(api_instance.get_full_texts("test_folder", batch_size=4, min_batch_size=1))
+
+        # Should have tried 3 times before giving up
+        assert mock_execute_sql.call_count == 3
+        calls = mock_execute_sql.call_args_list
+        assert "COUNT 4" in calls[0][0][0]  # First try with 4
+        assert "COUNT 2" in calls[1][0][0]  # Second try with 2
+        assert "COUNT 1" in calls[2][0][0]  # Final try with 1
+
+    @patch("sde_collections.sinequa_api.Api._execute_sql_query")
+    def test_get_full_texts_batch_size_progression(self, mock_execute_sql, api_instance):
+        """Test multiple batch size reductions followed by successful query."""
+        mock_execute_sql.side_effect = [
+            requests.RequestException("First failure"),
+            requests.RequestException("Second failure"),
+            {"Rows": [["http://example.com/1", "Text 1", "Title 1"]], "TotalRowCount": 1},
+        ]
+
+        # Start with batch_size=100, should reduce to 25 before succeeding
+        batches = list(api_instance.get_full_texts("test_folder", batch_size=100, min_batch_size=1))
+
+        assert len(batches) == 1  # Should get one successful batch
+        assert mock_execute_sql.call_count == 3
+
+        calls = mock_execute_sql.call_args_list
+        # Verify the progression of batch sizes
+        assert "COUNT 100" in calls[0][0][0]  # First attempt
+        assert "COUNT 50" in calls[1][0][0]  # After first failure
+        assert "COUNT 25" in calls[2][0][0]  # After second failure
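
These tests lean on the side_effect contract from unittest.mock: given a list, each call to the mock consumes the next element, and any exception instance in that list is raised instead of returned. A standalone demonstration of that behavior, independent of the repo's fixtures:

    from unittest.mock import MagicMock

    import requests

    mock = MagicMock(
        side_effect=[
            requests.RequestException("boom"),  # first call raises
            {"Rows": [["http://example.com/1", "Text 1", "Title 1"]], "TotalRowCount": 1},  # second call returns
        ]
    )

    try:
        mock("SELECT ... COUNT 100")
    except requests.RequestException:
        pass  # get_full_texts would halve the batch size at this point

    print(mock("SELECT ... COUNT 50"))  # the dict above
    print(mock.call_count)  # 2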
