diff --git a/nodestream/pipeline/extractors/stores/__init__.py b/nodestream/pipeline/extractors/stores/__init__.py index e69de29b..e3e6edc8 100644 --- a/nodestream/pipeline/extractors/stores/__init__.py +++ b/nodestream/pipeline/extractors/stores/__init__.py @@ -0,0 +1,3 @@ +from .splunk_extractor import SplunkExtractor + +__all__ = ("SplunkExtractor",) diff --git a/nodestream/pipeline/extractors/stores/splunk_extractor.py b/nodestream/pipeline/extractors/stores/splunk_extractor.py new file mode 100644 index 00000000..42f1fdf3 --- /dev/null +++ b/nodestream/pipeline/extractors/stores/splunk_extractor.py @@ -0,0 +1,400 @@ +import asyncio +import json +from logging import getLogger +from typing import Any, AsyncGenerator, Dict, Optional + +from httpx import AsyncClient, BasicAuth, HTTPStatusError + +from ..extractor import Extractor + + +class SplunkRecord: + """ + Wrapper for Splunk search result records. + """ + + @classmethod + def from_raw_splunk_result(cls, raw_result: Dict[str, Any]): + """Create a SplunkRecord from raw Splunk API response.""" + return cls(raw_result) + + def __init__(self, record_data: Dict[str, Any]): + self.record_data = record_data + + +class SplunkExtractor(Extractor): + @classmethod + def from_file_data( + cls, + base_url: str, + query: str, + auth_token: Optional[str] = None, + username: Optional[str] = None, + password: Optional[str] = None, + earliest_time: str = "-24h", + latest_time: str = "now", + verify_ssl: bool = True, + request_timeout_seconds: int = 300, + max_count: int = 10000, + app: str = "search", + user: Optional[str] = None, + chunk_size: int = 1000, + ) -> "SplunkExtractor": + return cls( + base_url=base_url.rstrip("/"), + query=query, + auth_token=auth_token, + username=username, + password=password, + earliest_time=earliest_time, + latest_time=latest_time, + verify_ssl=verify_ssl, + request_timeout_seconds=request_timeout_seconds, + max_count=max_count, + app=app, + user=user or username or "admin", + chunk_size=chunk_size, + ) + + def __init__( + self, + base_url: str, + query: str, + auth_token: Optional[str], + username: Optional[str], + password: Optional[str], + earliest_time: str, + latest_time: str, + verify_ssl: bool, + request_timeout_seconds: int, + max_count: int, + app: str, + user: str, + chunk_size: int, + ) -> None: + self.base_url = base_url + self.query = query + self.auth_token = auth_token + self.username = username + self.password = password + self.earliest_time = earliest_time + self.latest_time = latest_time + self.verify_ssl = verify_ssl + self.request_timeout_seconds = request_timeout_seconds + self.max_count = max_count + self.app = app + self.user = user + self.chunk_size = chunk_size + + # State for pagination and checkpointing + self.search_id = None + self.offset = 0 + self.is_done = False + + self.logger = getLogger(self.__class__.__name__) + + @property + def _auth(self): + if self.auth_token: + return None # token goes in header, not in BasicAuth + if self.username is not None and self.password is not None: + return BasicAuth(self.username, self.password) + return None + + @property + def _headers(self) -> Dict[str, str]: + headers = {"Accept": "application/json"} + if self.auth_token: + headers["Authorization"] = f"Splunk {self.auth_token}" + return headers + + @property + def _normalized_query(self) -> str: + """Ensure query starts with 'search' keyword.""" + query = self.query.strip() + if not query.lower().startswith("search "): + return f"search {query}" + return query + + def get_jobs_endpoint(self) -> 
str: + """Get the Splunk jobs endpoint.""" + return f"{self.base_url}/servicesNS/{self.user}/{self.app}/search/jobs" + + def get_results_endpoint(self, search_id: str) -> str: + """Get the results endpoint for a specific search job.""" + return f"{self.base_url}/servicesNS/{self.user}/{self.app}/search/jobs/{search_id}/results" + + async def _create_search_job(self, client: AsyncClient) -> str: + """Create a search job using the official Splunk REST API pattern.""" + job_data = { + "search": self._normalized_query, + "earliest_time": self.earliest_time, + "latest_time": self.latest_time, + "max_count": str(self.max_count), + "output_mode": "json", # Request JSON format explicitly + } + + self.logger.info( + "Creating Splunk search job", + extra={ + "query": self._normalized_query, + "earliest_time": self.earliest_time, + "latest_time": self.latest_time, + "max_count": self.max_count, + }, + ) + + response = await client.post( + self.get_jobs_endpoint(), + data=job_data, + headers={ + **self._headers, + "Content-Type": "application/x-www-form-urlencoded", + }, + auth=self._auth, + ) + + if response.status_code != 201: + raise HTTPStatusError( + f"Failed to create search job: {response.status_code}", + request=response.request, + response=response, + ) + + # Extract SID from response - try JSON first since we requested it + search_id = None + + try: + result = response.json() + search_id = result.get("sid") + if search_id: + self.logger.info( + "Created search job successfully", extra={"search_id": search_id} + ) + return search_id + except json.JSONDecodeError: + self.logger.debug("Job creation response not JSON, trying XML parsing") + + # Fallback to XML parsing + import xml.etree.ElementTree as ET + + try: + root = ET.fromstring(response.text) + sid_elem = root.find(".//sid") + if sid_elem is not None: + search_id = sid_elem.text + except Exception as e: + self.logger.error( + "Failed to parse job creation response", + extra={"error": str(e), "response": response.text[:500]}, + ) + + if not search_id: + raise RuntimeError( + f"Failed to extract search ID from response: {response.text[:500]}" + ) + + self.logger.info( + "Created search job successfully", extra={"search_id": search_id} + ) + return search_id + + async def _wait_for_job_completion( + self, client: AsyncClient, search_id: str, max_wait_seconds: int = 300 + ): + """Wait for a search job to complete.""" + wait_count = 0 + + while wait_count < max_wait_seconds: + response = await client.get( + f"{self.get_jobs_endpoint()}/{search_id}", + params={"output_mode": "json"}, + headers=self._headers, + auth=self._auth, + ) + + if response.status_code != 200: + raise HTTPStatusError( + f"Failed to check job status: {response.status_code}", + request=response.request, + response=response, + ) + + dispatch_state = "UNKNOWN" + + try: + job_status = response.json() + dispatch_state = ( + job_status.get("entry", [{}])[0] + .get("content", {}) + .get("dispatchState") + ) + except (json.JSONDecodeError, KeyError, IndexError): + # Try XML parsing + import xml.etree.ElementTree as ET + + try: + root = ET.fromstring(response.text) + for elem in root.iter(): + if ( + "dispatchState" in elem.tag + or elem.get("name") == "dispatchState" + ): + dispatch_state = elem.text + break + except Exception as e: + self.logger.warning( + "Failed to parse job status", + extra={"error": str(e), "search_id": search_id}, + ) + + self.logger.debug( + "Job status check", + extra={ + "search_id": search_id, + "dispatch_state": dispatch_state, + "wait_time": 
wait_count, + }, + ) + + if dispatch_state == "DONE": + self.logger.info("Search job completed", extra={"search_id": search_id}) + return + elif dispatch_state == "FAILED": + raise RuntimeError(f"Search job failed: {search_id}") + + await asyncio.sleep(2) # Wait 2 seconds before checking again + wait_count += 2 + + raise RuntimeError( + f"Search job timed out after {max_wait_seconds} seconds: {search_id}" + ) + + async def _get_job_results( + self, client: AsyncClient, search_id: str + ) -> AsyncGenerator[Dict[str, Any], None]: + """Get results from a completed search job with pagination.""" + while not self.is_done: + params = { + "output_mode": "json", + "count": str(self.chunk_size), + "offset": str(self.offset), + } + + response = await client.get( + self.get_results_endpoint(search_id), + params=params, + headers=self._headers, + auth=self._auth, + ) + + if response.status_code != 200: + raise HTTPStatusError( + f"Failed to get job results: {response.status_code}", + request=response.request, + response=response, + ) + + try: + results = response.json() + results_list = results.get("results", []) + except json.JSONDecodeError: + self.logger.error( + "Failed to parse results as JSON", + extra={"search_id": search_id, "response": response.text[:500]}, + ) + results_list = [] + + if not results_list: + self.is_done = True + break + + self.logger.debug( + "Retrieved results chunk", + extra={ + "search_id": search_id, + "chunk_size": len(results_list), + "offset": self.offset, + "total_offset": self.offset + len(results_list), + }, + ) + + for result in results_list: + yield result + + self.offset += len(results_list) + + if len(results_list) < self.chunk_size: + self.is_done = True + + async def extract_records(self) -> AsyncGenerator[Any, Any]: + """Extract records using the official Splunk REST API job-based approach.""" + async with AsyncClient( + verify=self.verify_ssl, timeout=self.request_timeout_seconds + ) as client: + try: + # Step 1: Create search job if we don't have one + if not self.search_id: + self.search_id = await self._create_search_job(client) + await self._wait_for_job_completion(client, self.search_id) + + record_count = 0 + async for result in self._get_job_results(client, self.search_id): + record_count += 1 + yield SplunkRecord.from_raw_splunk_result(result).record_data + + self.logger.info( + "Extraction completed", + extra={"search_id": self.search_id, "total_records": record_count}, + ) + + except HTTPStatusError as exc: + self.logger.error( + "Splunk request failed", + extra={ + "status_code": exc.response.status_code, + "content": ( + exc.response.text[:500] + if hasattr(exc.response, "text") + else str(exc.response) + ), + "url": str(exc.request.url), + "query": self._normalized_query, + }, + ) + raise + except Exception as exc: + self.logger.error( + "Unexpected error during extraction", + extra={ + "error": str(exc), + "error_type": type(exc).__name__, + "query": self._normalized_query, + "search_id": self.search_id, + }, + ) + raise + + async def make_checkpoint(self): + """Create a checkpoint for resuming extraction.""" + return { + "search_id": self.search_id, + "offset": self.offset, + "is_done": self.is_done, + } + + async def resume_from_checkpoint(self, checkpoint_object): + """Resume extraction from a checkpoint.""" + if checkpoint_object: + self.search_id = checkpoint_object.get("search_id") + self.offset = checkpoint_object.get("offset", 0) + self.is_done = checkpoint_object.get("is_done", False) + + self.logger.info( + "Resuming from checkpoint", 
+            extra={
+                "search_id": self.search_id,
+                "offset": self.offset,
+                "is_done": self.is_done,
+            },
+        )
diff --git a/tests/unit/pipeline/extractors/stores/test_splunk_extractor.py b/tests/unit/pipeline/extractors/stores/test_splunk_extractor.py
new file mode 100644
index 00000000..a1c49d88
--- /dev/null
+++ b/tests/unit/pipeline/extractors/stores/test_splunk_extractor.py
@@ -0,0 +1,657 @@
+import json
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from hamcrest import assert_that, equal_to, has_entries, has_length
+from httpx import HTTPStatusError
+
+from nodestream.pipeline.extractors.stores.splunk_extractor import SplunkExtractor
+
+
+@pytest.fixture
+def mock_job_creation_response():
+    """Mock XML job-creation response; Splunk embeds the SID in an <sid> element."""
+    return "<response><sid>1234567890.123</sid></response>"
+
+
+@pytest.fixture
+def mock_job_results():
+    """Sample Splunk search results in the format returned by the results API."""
+    return {
+        "results": [
+            {
+                "_time": "2023-01-01T10:00:00",
+                "host": "server1",
+                "message": "Login successful",
+            },
+            {
+                "_time": "2023-01-01T10:01:00",
+                "host": "server2",
+                "message": "Error occurred",
+            },
+            {"_time": "2023-01-01T10:02:00", "host": "server1", "message": "Logout"},
+        ]
+    }
+
+
+@pytest.fixture
+def splunk_extractor():
+    """Create a SplunkExtractor instance for testing."""
+    return SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=test | head 10",
+        auth_token="test-token-123",
+        earliest_time="-1h",
+        latest_time="now",
+        chunk_size=100,
+    )
+
+
+@pytest.fixture
+def splunk_extractor_basic_auth():
+    """Create a SplunkExtractor instance with basic auth for testing."""
+    return SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=test",
+        username="testuser",
+        password="testpass",
+    )
+
+
+def test_splunk_extractor_from_file_data_with_token_auth():
+    extractor = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089/",  # trailing slash should be stripped
+        query="index=main",
+        auth_token="my-token",
+        earliest_time="-24h",
+        chunk_size=500,
+    )
+
+    assert_that(extractor.base_url, equal_to("https://splunk.example.com:8089"))
+    assert_that(extractor.query, equal_to("index=main"))
+    assert_that(extractor.auth_token, equal_to("my-token"))
+    assert_that(extractor.username, equal_to(None))
+    assert_that(extractor.password, equal_to(None))
+    assert_that(extractor.earliest_time, equal_to("-24h"))
+    assert_that(extractor.chunk_size, equal_to(500))
+
+
+def test_splunk_extractor_from_file_data_with_basic_auth():
+    extractor = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=security",
+        username="admin",
+        password="secret",
+    )
+
+    assert_that(extractor.auth_token, equal_to(None))
+    assert_that(extractor.username, equal_to("admin"))
+    assert_that(extractor.password, equal_to("secret"))
+
+
+def test_splunk_extractor_from_file_data_defaults():
+    extractor = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=main",
+    )
+
+    assert_that(extractor.verify_ssl, equal_to(True))
+    assert_that(extractor.request_timeout_seconds, equal_to(300))
+    assert_that(extractor.chunk_size, equal_to(1000))
+
+
+# Helper property tests
+def test_splunk_extractor_auth_property_with_token(splunk_extractor):
+    assert_that(splunk_extractor._auth, equal_to(None))  # Token goes in header
+
+
+def test_splunk_extractor_auth_property_with_basic_auth(splunk_extractor_basic_auth):
+    auth = splunk_extractor_basic_auth._auth
+    assert_that(auth is not None, equal_to(True))
+
+
+def test_splunk_extractor_auth_property_returns_none_when_no_credentials():
+    """Test _auth property returns None when no credentials are provided."""
+    extractor = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=main",
+        # No auth_token, username, or password provided
+    )
+    assert_that(extractor._auth, equal_to(None))
+
+
+def test_splunk_extractor_headers_with_token(splunk_extractor):
+    headers = splunk_extractor._headers
+    assert_that(
+        headers,
+        has_entries(
+            {"Accept": "application/json", "Authorization": "Splunk test-token-123"}
+        ),
+    )
+
+
+def test_splunk_extractor_headers_without_token(splunk_extractor_basic_auth):
+    headers = splunk_extractor_basic_auth._headers
+    assert_that(headers, equal_to({"Accept": "application/json"}))
+
+
+def test_splunk_extractor_normalized_query_adds_search_prefix():
+    # Query already has search prefix
+    extractor = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="search index=main",
+    )
+    assert_that(extractor._normalized_query, equal_to("search index=main"))
+
+    # Query without search prefix
+    extractor2 = SplunkExtractor.from_file_data(
+        base_url="https://splunk.example.com:8089",
+        query="index=main | head 10",
+    )
+    assert_that(extractor2._normalized_query, equal_to("search index=main | head 10"))
+
+
+def test_splunk_extractor_endpoint_urls(splunk_extractor):
+    jobs_endpoint = splunk_extractor.get_jobs_endpoint()
+    assert_that(
+        jobs_endpoint,
+        equal_to("https://splunk.example.com:8089/servicesNS/admin/search/search/jobs"),
+    )
+
+    results_endpoint = splunk_extractor.get_results_endpoint("test123")
+    assert_that(
+        results_endpoint,
+        equal_to(
+            "https://splunk.example.com:8089/servicesNS/admin/search/search/jobs/test123/results"
+        ),
+    )
+
+
+@pytest.mark.asyncio
+async def test_splunk_extractor_create_search_job_json_response(
+    splunk_extractor, mocker
+):
+    """Test job creation with JSON response (preferred format)."""
+    mock_response = mocker.MagicMock()
+    mock_response.status_code = 201
+    mock_response.json.return_value = {"sid": "json123"}
+
+    mock_client = mocker.MagicMock()
+    mock_client.post = AsyncMock(return_value=mock_response)
+
+    search_id = await splunk_extractor._create_search_job(mock_client)
+
+    assert_that(search_id, equal_to("json123"))
+    # Verify we requested JSON format
+    call_args = mock_client.post.call_args
+    assert_that(call_args[1]["data"]["output_mode"], equal_to("json"))
+
+
+@pytest.mark.asyncio
+async def test_splunk_extractor_create_search_job_xml_fallback(
+    splunk_extractor, mock_job_creation_response, mocker
+):
+    """Test job creation with XML response (fallback when JSON fails)."""
+    mock_response = mocker.MagicMock()
+    mock_response.status_code = 201
+    mock_response.text = mock_job_creation_response
+    mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
+
+    mock_client = mocker.MagicMock()
+    mock_client.post = AsyncMock(return_value=mock_response)
+
+    search_id = await splunk_extractor._create_search_job(mock_client)
+
+    assert_that(search_id, equal_to("1234567890.123"))
+    mock_client.post.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_splunk_extractor_create_search_job_malformed_xml(
+    splunk_extractor, mocker
+):
+    """Test job creation when the response is neither valid JSON nor valid XML."""
+    mock_response = mocker.MagicMock()
+    mock_response.status_code = 201
+    mock_response.text = "Invalid XML <unclosed"
+    mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0)
+
+    mock_client = mocker.MagicMock()
+    mock_client.post = AsyncMock(return_value=mock_response)
+
+    with pytest.raises(RuntimeError, match="Failed to extract search ID"):
+        await splunk_extractor._create_search_job(mock_client)
+
+
+# Job status checking tests
+@pytest.mark.asyncio
+async def test_splunk_extractor_wait_for_job_completion_xml_response(
+    splunk_extractor, mocker
+):
+    """Test job status checking with an XML status response (JSON parsing fails)."""
+    xml_response = """
+    <entry xmlns:s="http://dev.splunk.com/ns/rest">
+        <content>
+            <s:dict>
+                <s:key name="dispatchState">DONE</s:key>
+            </s:dict>
+        </content>
+    </entry>
+    """
+
+    mock_response = mocker.MagicMock()
+ mock_response.status_code = 200 + mock_response.text = xml_response + mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0) + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + # Should complete without raising (finds DONE in XML) + await splunk_extractor._wait_for_job_completion( + mock_client, "test123", max_wait_seconds=10 + ) + + +@pytest.mark.asyncio +async def test_splunk_extractor_wait_for_job_completion_http_error( + splunk_extractor, mocker +): + """Test job status checking when Splunk returns non-200 status code.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 403 + mock_response.request = mocker.MagicMock() + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with pytest.raises(HTTPStatusError, match="Failed to check job status: 403"): + await splunk_extractor._wait_for_job_completion(mock_client, "test123") + + +@pytest.mark.asyncio +async def test_splunk_extractor_wait_for_job_completion_parse_error_warning( + splunk_extractor, mocker +): + """Test job status checking when both JSON and XML parsing fail.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.text = "Invalid response format" + mock_response.json.side_effect = json.JSONDecodeError("Not JSON", "", 0) + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + # Mock the logger to verify warning is logged + with patch.object(splunk_extractor, "logger") as mock_logger: + with pytest.raises(RuntimeError, match="Search job timed out"): + await splunk_extractor._wait_for_job_completion( + mock_client, "test123", max_wait_seconds=1 + ) + + # Verify warning was logged for parse failure + mock_logger.warning.assert_called() + warning_call = mock_logger.warning.call_args + assert "Failed to parse job status" in warning_call[0][0] + + +@pytest.mark.asyncio +async def test_splunk_extractor_wait_for_job_completion_timeout( + splunk_extractor, mocker +): + """Test job status checking timeout.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "entry": [{"content": {"dispatchState": "RUNNING"}}] + } + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with pytest.raises( + RuntimeError, match="Search job timed out after 1 seconds: test123" + ): + await splunk_extractor._wait_for_job_completion( + mock_client, "test123", max_wait_seconds=1 + ) + + +@pytest.mark.asyncio +async def test_splunk_extractor_wait_for_job_completion_failure( + splunk_extractor, mocker +): + """Test job status checking when job fails.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "entry": [{"content": {"dispatchState": "FAILED"}}] + } + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with pytest.raises(RuntimeError, match="Search job failed"): + await splunk_extractor._wait_for_job_completion(mock_client, "test123") + + +# Results retrieval tests +@pytest.mark.asyncio +async def test_splunk_extractor_get_job_results_single_chunk( + splunk_extractor, mock_job_results, mocker +): + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_job_results + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + results = [] + async for 
result in splunk_extractor._get_job_results(mock_client, "test123"): + results.append(result) + + assert_that(results, has_length(3)) + assert_that( + results[0], + has_entries( + { + "_time": "2023-01-01T10:00:00", + "host": "server1", + "message": "Login successful", + } + ), + ) + + +@pytest.mark.asyncio +async def test_splunk_extractor_get_job_results_json_parse_error( + splunk_extractor, mocker +): + """Test results retrieval when JSON parsing fails.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.text = "Invalid JSON response" + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + results = [] + async for result in splunk_extractor._get_job_results(mock_client, "test123"): + results.append(result) + + # Should handle gracefully and return empty results + assert_that(results, has_length(0)) + assert_that(splunk_extractor.is_done, equal_to(True)) + + +@pytest.mark.asyncio +async def test_splunk_extractor_get_job_results_pagination(splunk_extractor, mocker): + # Set small chunk size for testing + splunk_extractor.chunk_size = 2 + + # Mock responses for pagination + first_response = mocker.MagicMock() + first_response.status_code = 200 + first_response.json.return_value = { + "results": [ + {"_time": "2023-01-01T10:00:00", "host": "server1"}, + {"_time": "2023-01-01T10:01:00", "host": "server2"}, + ] + } + + second_response = mocker.MagicMock() + second_response.status_code = 200 + second_response.json.return_value = { + "results": [ + {"_time": "2023-01-01T10:02:00", "host": "server3"}, + ] + } + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(side_effect=[first_response, second_response]) + + results = [] + async for result in splunk_extractor._get_job_results(mock_client, "test123"): + results.append(result) + + assert_that(results, has_length(3)) + assert_that(splunk_extractor.offset, equal_to(3)) + assert_that(splunk_extractor.is_done, equal_to(True)) + + +@pytest.mark.asyncio +async def test_splunk_extractor_get_job_results_http_error(splunk_extractor, mocker): + """Test results retrieval when Splunk returns non-200 status code.""" + mock_response = mocker.MagicMock() + mock_response.status_code = 404 + mock_response.request = mocker.MagicMock() + + mock_client = mocker.MagicMock() + mock_client.get = AsyncMock(return_value=mock_response) + + with pytest.raises(HTTPStatusError, match="Failed to get job results: 404"): + async for _ in splunk_extractor._get_job_results(mock_client, "test123"): + pass + + +# Full extraction test +@pytest.mark.asyncio +async def test_splunk_extractor_extract_records_full_flow(splunk_extractor, mocker): + with patch( + "nodestream.pipeline.extractors.stores.splunk_extractor.AsyncClient" + ) as mock_client_class: + mock_client = mocker.MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + # Mock job creation + job_response = mocker.MagicMock() + job_response.status_code = 201 + job_response.json.return_value = {"sid": "test123"} + mock_client.post = AsyncMock(return_value=job_response) + + # Mock job completion + status_response = mocker.MagicMock() + status_response.status_code = 200 + status_response.json.return_value = { + "entry": [{"content": {"dispatchState": "DONE"}}] + } + + # Mock results + results_response = mocker.MagicMock() + results_response.status_code = 200 + 
results_response.json.return_value = { + "results": [ + {"_time": "2023-01-01T10:00:00", "host": "server1"}, + ] + } + + # Setup get calls: first for status check, then for results + mock_client.get = AsyncMock(side_effect=[status_response, results_response]) + + records = [] + async for record in splunk_extractor.extract_records(): + records.append(record) + + assert_that(records, has_length(1)) + assert_that( + records[0], + has_entries({"_time": "2023-01-01T10:00:00", "host": "server1"}), + ) + + +# Checkpointing tests +@pytest.mark.asyncio +async def test_splunk_extractor_make_checkpoint(splunk_extractor): + splunk_extractor.search_id = "test123" + splunk_extractor.offset = 100 + splunk_extractor.is_done = False + + checkpoint = await splunk_extractor.make_checkpoint() + + assert_that( + checkpoint, + equal_to({"search_id": "test123", "offset": 100, "is_done": False}), + ) + + +@pytest.mark.asyncio +async def test_splunk_extractor_resume_from_checkpoint(splunk_extractor): + checkpoint = {"search_id": "restored123", "offset": 50, "is_done": False} + + await splunk_extractor.resume_from_checkpoint(checkpoint) + + assert_that(splunk_extractor.search_id, equal_to("restored123")) + assert_that(splunk_extractor.offset, equal_to(50)) + assert_that(splunk_extractor.is_done, equal_to(False)) + + +@pytest.mark.asyncio +async def test_splunk_extractor_extract_records_http_error_handling( + splunk_extractor, mocker +): + """Test extract_records handles HTTPStatusError and logs properly.""" + with patch( + "nodestream.pipeline.extractors.stores.splunk_extractor.AsyncClient" + ) as mock_client_class: + mock_client = mocker.MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + # Mock HTTPStatusError during job creation + mock_response = mocker.MagicMock() + mock_response.status_code = 401 + mock_response.text = "Unauthorized" + mock_response.request = mocker.MagicMock() + mock_response.request.url = "https://splunk.example.com/jobs" + + http_error = HTTPStatusError( + "Unauthorized", request=mock_response.request, response=mock_response + ) + mock_client.post = AsyncMock(side_effect=http_error) + + # Mock the logger to verify error logging + with patch.object(splunk_extractor, "logger") as mock_logger: + with pytest.raises(HTTPStatusError): + async for _ in splunk_extractor.extract_records(): + pass + + # Verify HTTPStatusError was logged with proper details + mock_logger.error.assert_called() + error_call = mock_logger.error.call_args + assert "Splunk request failed" in error_call[0][0] + assert error_call[1]["extra"]["status_code"] == 401 + assert error_call[1]["extra"]["content"] == "Unauthorized" + + +@pytest.mark.asyncio +async def test_splunk_extractor_extract_records_generic_exception_handling( + splunk_extractor, mocker +): + """Test extract_records handles generic exceptions and logs properly.""" + with patch( + "nodestream.pipeline.extractors.stores.splunk_extractor.AsyncClient" + ) as mock_client_class: + mock_client = mocker.MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + # Mock a generic exception during job creation + generic_error = ValueError("Something went wrong") + mock_client.post = AsyncMock(side_effect=generic_error) + + # Mock the logger to verify error logging + with patch.object(splunk_extractor, "logger") as mock_logger: + with pytest.raises(ValueError): + async for _ in 
splunk_extractor.extract_records(): + pass + + # Verify generic exception was logged with proper details + mock_logger.error.assert_called() + error_call = mock_logger.error.call_args + assert "Unexpected error during extraction" in error_call[0][0] + assert error_call[1]["extra"]["error"] == "Something went wrong" + assert error_call[1]["extra"]["error_type"] == "ValueError" + + +@pytest.mark.asyncio +async def test_splunk_extractor_extract_records_http_error_without_text_attribute( + splunk_extractor, mocker +): + """Test HTTPStatusError handling when response doesn't have text attribute.""" + with patch( + "nodestream.pipeline.extractors.stores.splunk_extractor.AsyncClient" + ) as mock_client_class: + mock_client = mocker.MagicMock() + mock_client_class.return_value.__aenter__.return_value = mock_client + mock_client_class.return_value.__aexit__.return_value = None + + # Mock HTTPStatusError with response that doesn't have text attribute + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + # Remove text attribute to test fallback + del mock_response.text + mock_response.request = mocker.MagicMock() + mock_response.request.url = "https://splunk.example.com/jobs" + + http_error = HTTPStatusError( + "Server Error", request=mock_response.request, response=mock_response + ) + mock_client.post = AsyncMock(side_effect=http_error) + + # Mock the logger to verify error logging + with patch.object(splunk_extractor, "logger") as mock_logger: + with pytest.raises(HTTPStatusError): + async for _ in splunk_extractor.extract_records(): + pass + + # Verify HTTPStatusError was logged with str(response) fallback + mock_logger.error.assert_called() + error_call = mock_logger.error.call_args + assert "Splunk request failed" in error_call[0][0] + assert error_call[1]["extra"]["status_code"] == 500 + # Should use str(response) when text attribute is missing + assert str(mock_response) in error_call[1]["extra"]["content"]
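
Usage note (not part of the patch): a minimal sketch of driving the new extractor directly, assuming a reachable Splunk management endpoint. The host, token, and SPL query below are placeholders; in a pipeline, nodestream calls `from_file_data` with the declared arguments and then iterates `extract_records()` the same way.

```python
import asyncio

from nodestream.pipeline.extractors.stores import SplunkExtractor


async def main():
    # Placeholder connection details; from_file_data strips the trailing slash
    # from base_url and prefers token auth (Authorization: Splunk <token>)
    # over basic auth when both are supplied.
    extractor = SplunkExtractor.from_file_data(
        base_url="https://splunk.example.com:8089",
        query="index=main error",  # the "search " prefix is added automatically
        auth_token="my-token",
        earliest_time="-1h",
        chunk_size=500,
    )
    # Creates the search job, polls until dispatchState is DONE, then pages
    # through the results endpoint in chunk_size batches.
    async for record in extractor.extract_records():
        print(record.get("_time"), record.get("host"))


asyncio.run(main())
```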
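The checkpoint hooks make the extraction resumable mid-pagination. Below is a sketch of the round-trip, with a hypothetical driver standing in for the pipeline runtime (which normally persists and replays checkpoints itself):

```python
from nodestream.pipeline.extractors.stores import SplunkExtractor


async def resume_after_interruption(old: SplunkExtractor, new: SplunkExtractor):
    # Capture the pagination state of a partially consumed extraction.
    checkpoint = await old.make_checkpoint()
    # e.g. {"search_id": "1234567890.123", "offset": 2000, "is_done": False}

    # A fresh instance reuses the already-dispatched search job:
    # extract_records() sees search_id is set, skips job creation and the
    # completion poll, and resumes fetching chunks at the saved offset.
    await new.resume_from_checkpoint(checkpoint)
    async for record in new.extract_records():
        ...
```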