diff --git a/.copyrightconfig b/.copyrightconfig
index 2bc4c917..659ca80b 100644
--- a/.copyrightconfig
+++ b/.copyrightconfig
@@ -11,4 +11,4 @@ startyear: 2023
 # - Dotfiles already skipped automatically
 # Enable by removing the leading '# ' from the next line and editing values.
 # filesexcluded: third_party/*, docs/generated/*.md, assets/*.png, scripts/temp_*.py, vendor/lib.js
-filesexcluded: .github/*, README.md, CONTRIBUTING.md, Jenkinsfile, gradle/*, docker-compose.yml, *.gradle, gradle.properties, gradlew, gradlew.bat, **/test/resources/**, docs/**, *.json, *.txt, CODEOWNERS
+filesexcluded: .github/*, README.md, CONTRIBUTING.md, Jenkinsfile, gradle/*, docker-compose.yml, *.gradle, gradle.properties, gradlew, gradlew.bat, **/test/resources/**, docs/**, *.json, *.txt, CODEOWNERS, *.properties
diff --git a/gradle.properties b/gradle.properties
index 59a6ce4c..898eccbe 100644
--- a/gradle.properties
+++ b/gradle.properties
@@ -1,4 +1,4 @@
-version=3.0-SNAPSHOT
+version=3.1-SNAPSHOT
 sparkVersion=4.1.1
 tikaVersion=3.2.3
 semaphoreVersion=5.10.0
diff --git a/marklogic-spark-connector/build.gradle b/marklogic-spark-connector/build.gradle
index f9ce2e91..e06e57a5 100644
--- a/marklogic-spark-connector/build.gradle
+++ b/marklogic-spark-connector/build.gradle
@@ -60,6 +60,9 @@ dependencies {
     // Only needs compileOnly, as the Java Client brings this as an implementation dependency.
     compileOnly 'com.squareup.okhttp3:okhttp:5.2.0'

+    // For Nuclia support
+    implementation 'com.squareup.okhttp3:okhttp-sse:5.2.1'
+
     // Automatic loading of test framework implementation dependencies is deprecated.
     // https://docs.gradle.org/current/userguide/upgrading_version_8.html#test_framework_implementation_dependencies
     // Without this, once using JUnit 5.12 or higher, Gradle will not find any tests and report an error of:
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/Options.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/Options.java
index f4368fb8..c2767801 100644
--- a/marklogic-spark-connector/src/main/java/com/marklogic/spark/Options.java
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/Options.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
  */
 package com.marklogic.spark;
@@ -496,6 +496,29 @@ public abstract class Options {
      */
    public static final String WRITE_EMBEDDER_BASE64_ENCODE = WRITE_EMBEDDER_PREFIX + "base64Encode";

+    private static final String WRITE_NUCLIA_PREFIX = "spark.marklogic.write.nuclia.";
+
+    /**
+     * Nuclia API key for authentication. Required if any Nuclia options are used.
+     *
+     * @since 3.1.0
+     */
+    public static final String WRITE_NUCLIA_API_KEY = WRITE_NUCLIA_PREFIX + "apikey";
+
+    /**
+     * Nuclia region (e.g., "aws-us-east-2-1"). Required if any Nuclia options are used.
+     *
+     * @since 3.1.0
+     */
+    public static final String WRITE_NUCLIA_REGION = WRITE_NUCLIA_PREFIX + "region";
+
+    /**
+     * Maximum number of seconds to wait for Nuclia processing to complete. Defaults to 120 seconds.
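+     * The value is applied as the HTTP read timeout and as the upper bound when polling for processing completion.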
+     *
+     * @since 3.1.0
+     */
+    public static final String WRITE_NUCLIA_TIMEOUT = WRITE_NUCLIA_PREFIX + "timeout";
+
     /**
      * Defines the host for classification requests
      *
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java
index b277d451..eda9f0a6 100644
--- a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentInputs.java
@@ -156,6 +156,26 @@ public void setChunks(List chunks) {
         }
     }

+    /**
+     * Adds a chunk with its embedding and model name. This is useful for workflows like Nuclia where
+     * chunks and embeddings are received together.
+     *
+     * @param text the chunk text
+     * @param embedding the embedding vector (can be null)
+     * @param modelName the model name (can be null)
+     */
+    public void addChunk(String text, float[] embedding, String modelName) {
+        if (chunkInputsList == null) {
+            chunkInputsList = new ArrayList<>();
+        }
+        ChunkInputs chunkInputs = new ChunkInputs(text);
+        if (embedding != null) {
+            chunkInputs.setEmbedding(embedding);
+            chunkInputs.setModelName(modelName);
+        }
+        chunkInputsList.add(chunkInputs);
+    }
+
     public byte[] getDocumentClassification() {
         return documentClassification;
     }
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipeline.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipeline.java
index 0291185a..89f32ad7 100644
--- a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipeline.java
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipeline.java
@@ -11,6 +11,8 @@
 import com.marklogic.spark.core.embedding.EmbeddingProducer;
 import com.marklogic.spark.core.extraction.ExtractionResult;
 import com.marklogic.spark.core.extraction.TextExtractor;
+import com.marklogic.spark.core.nuclia.NucliaClient;
+import com.marklogic.spark.core.nuclia.NucliaDocumentProcessor;
 import com.marklogic.spark.core.splitter.TextSplitter;

 import java.io.Closeable;
@@ -30,6 +32,8 @@ public class DocumentPipeline implements Closeable {
     private final TextClassifier textClassifier;
     private final EmbeddingProducer embeddingProducer;
     private final ChunkSelector chunkSelector;
+    private final NucliaClient nucliaClient;
+    private final NucliaDocumentProcessor nucliaProcessor;

     public DocumentPipeline(TextExtractor textExtractor, TextSplitter textSplitter,
                             TextClassifier textClassifier, EmbeddingProducer embeddingProducer, ChunkSelector chunkSelector) {
         this.textExtractor = textExtractor;
@@ -37,6 +41,25 @@ public DocumentPipeline(TextExtractor textExtractor, TextSplitter textSplitter,
         this.textClassifier = textClassifier;
         this.embeddingProducer = embeddingProducer;
         this.chunkSelector = chunkSelector;
+        this.nucliaClient = null;
+        this.nucliaProcessor = null;
+    }
+
+    /**
+     * Constructor for Nuclia-based pipeline. Nuclia handles extraction, splitting, and embedding generation.
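+     * The standard extraction, splitting, embedding, and chunk-selection components are set to null in this mode.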
+     *
+     * @param nucliaClient the Nuclia client for processing
+     * @param textClassifier optional text classifier (can be null)
+     * @since 3.1.0
+     */
+    public DocumentPipeline(NucliaClient nucliaClient, TextClassifier textClassifier) {
+        this.nucliaClient = nucliaClient;
+        this.nucliaProcessor = new NucliaDocumentProcessor(nucliaClient);
+        this.textClassifier = textClassifier;
+        this.textExtractor = null;
+        this.textSplitter = null;
+        this.embeddingProducer = null;
+        this.chunkSelector = null;
     }

     @Override
@@ -44,6 +67,34 @@ public void close() throws IOException {
         if (textClassifier != null) {
             textClassifier.close();
         }
+        if (nucliaClient != null) {
+            nucliaClient.close();
+        }
+    }
+
+    // Package-private getters for testing
+    NucliaClient getNucliaClient() {
+        return nucliaClient;
+    }
+
+    TextClassifier getTextClassifier() {
+        return textClassifier;
+    }
+
+    TextExtractor getTextExtractor() {
+        return textExtractor;
+    }
+
+    TextSplitter getTextSplitter() {
+        return textSplitter;
+    }
+
+    EmbeddingProducer getEmbeddingProducer() {
+        return embeddingProducer;
+    }
+
+    ChunkSelector getChunkSelector() {
+        return chunkSelector;
     }

     /**
@@ -51,6 +102,11 @@ public void close() throws IOException {
      * embedding generation.
      */
     public void processDocuments(List<DocumentInputs> inputs) {
+        if (nucliaProcessor != null) {
+            processWithNuclia(inputs);
+            return;
+        }
+
         if (textExtractor != null) {
             inputs.stream().forEach(this::extractText);
         }
@@ -68,6 +124,15 @@
         }
     }

+    private void processWithNuclia(List<DocumentInputs> inputs) {
+        nucliaProcessor.processDocuments(inputs);
+
+        // Optionally classify after Nuclia processing
+        if (textClassifier != null) {
+            classifyText(inputs);
+        }
+    }
+
     private void classifyText(List<DocumentInputs> inputs) {
         List contents = new ArrayList<>();
         for (DocumentInputs input : inputs) {
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipelineFactory.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipelineFactory.java
index 83736a4e..d4b5e7d7 100644
--- a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipelineFactory.java
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/DocumentPipelineFactory.java
@@ -1,12 +1,12 @@
 /*
- * Copyright (c) 2023-2025 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
  */
 package com.marklogic.spark.core;

+import com.marklogic.langchain4j.Langchain4jFactory;
 import com.marklogic.spark.ConnectorException;
 import com.marklogic.spark.Context;
 import com.marklogic.spark.Options;
-import com.marklogic.spark.Util;
 import com.marklogic.spark.core.classifier.TextClassifier;
 import com.marklogic.spark.core.classifier.TextClassifierFactory;
 import com.marklogic.spark.core.embedding.ChunkSelector;
@@ -15,19 +15,24 @@
 import com.marklogic.spark.core.embedding.EmbeddingProducerFactory;
 import com.marklogic.spark.core.extraction.TextExtractor;
 import com.marklogic.spark.core.extraction.TikaTextExtractor;
+import com.marklogic.spark.core.nuclia.NucliaClient;
 import com.marklogic.spark.core.splitter.TextSplitter;
 import com.marklogic.spark.core.splitter.TextSplitterFactory;

-import java.lang.reflect.InvocationTargetException;
-
 public abstract class DocumentPipelineFactory {

-    private static final String FACTORY_CLASS_NAME = "com.marklogic.langchain4j.Langchain4jFactory";
-
     // For some reason, Sonar thinks the check for four nulls always resolves to false, even though it's definitely
     // possible. So ignoring that warning.
     @SuppressWarnings("java:S2589")
     public static DocumentPipeline newDocumentPipeline(Context context) {
+        // Check for Nuclia configuration first
+        NucliaClient nucliaClient = newNucliaClient(context);
+        if (nucliaClient != null) {
+            TextClassifier textClassifier = TextClassifierFactory.newTextClassifier(context);
+            return new DocumentPipeline(nucliaClient, textClassifier);
+        }
+
+        // Standard pipeline with separate components
         final TextExtractor textExtractor = context.getBooleanOption(Options.WRITE_EXTRACTED_TEXT, false) ?
             new TikaTextExtractor() : null;
         final TextSplitter textSplitter = newTextSplitter(context);
@@ -52,6 +57,25 @@ public static DocumentPipeline newDocumentPipeline(Context context) {
             new DocumentPipeline(textExtractor, textSplitter, textClassifier, embeddingProducer, chunkSelector);
     }

+    private static NucliaClient newNucliaClient(Context context) {
+        String apiKey = context.getProperties().get(Options.WRITE_NUCLIA_API_KEY);
+        if (apiKey == null || apiKey.trim().isEmpty()) {
+            return null;
+        }
+
+        final String region = context.getProperties().get(Options.WRITE_NUCLIA_REGION);
+
+        if (region == null || region.trim().isEmpty()) {
+            throw new ConnectorException(String.format("When %s is specified, %s must also be specified.",
+                context.getOptionNameForMessage(Options.WRITE_NUCLIA_API_KEY),
+                context.getOptionNameForMessage(Options.WRITE_NUCLIA_REGION)));
+        }
+
+        int timeout = context.getIntOption(Options.WRITE_NUCLIA_TIMEOUT, 120, 1);
+
+        return new NucliaClient(apiKey, region, timeout);
+    }
+
     private static TextSplitter newTextSplitter(Context context) {
         boolean shouldSplit = context.getProperties().keySet().stream().anyMatch(key -> key.startsWith(Options.WRITE_SPLITTER_PREFIX));
         if (!shouldSplit) {
@@ -73,24 +97,7 @@ private static EmbeddingProducer newEmbeddingProducer(Context context) {
     }

     private static Object newLangchain4jProcessorFactory() {
-        try {
-            return Class.forName(FACTORY_CLASS_NAME).getDeclaredConstructor().newInstance();
-        } catch (UnsupportedClassVersionError e) {
-            throw new ConnectorException("Unable to configure support for splitting documents and/or generating embeddings. " +
-                "Please ensure you are using Java 17 or higher for these operations.", e);
-        }
-        // Catch every checked exception from trying to instantiate the class. Any exception from the factory class
-        // itself is expected to be a RuntimeException that should bubble up.
-        catch (ClassNotFoundException | InstantiationException | IllegalAccessException | NoSuchMethodException |
-               InvocationTargetException ex) {
-            if (Util.MAIN_LOGGER.isDebugEnabled()) {
-                Util.MAIN_LOGGER.debug(
-                    "Unable to instantiate factory class {}; this is expected when the marklogic-langchain4j module is not on the classpath. Cause: {}",
-                    FACTORY_CLASS_NAME, ex.getMessage()
-                );
-            }
-            return null;
-        }
+        return new Langchain4jFactory();
     }

     private DocumentPipelineFactory() {
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaClient.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaClient.java
new file mode 100644
index 00000000..b77bbcca
--- /dev/null
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaClient.java
@@ -0,0 +1,308 @@
+/*
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.spark.core.nuclia;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.marklogic.spark.Util;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSources;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.concurrent.TimeUnit;
+import java.util.stream.Stream;
+
+/**
+ * Client for interacting with the Nuclia RAG API.
+ * Handles text ingestion, processing, and retrieval of embeddings and chunks.
+ * Implements AutoCloseable to properly shut down HTTP client resources.
+ */
+public class NucliaClient implements AutoCloseable {
+
+    private final String apiKey;
+    private final String baseUrl;
+    private final OkHttpClient httpClient;
+    private final ObjectMapper objectMapper;
+    private final int timeoutSeconds;
+
+    /**
+     * Creates a new NucliaClient with a custom timeout.
+     *
+     * @param apiKey the Nuclia API key for authentication
+     * @param region the region (e.g., "aws-us-east-2-1")
+     * @param timeoutSeconds the maximum number of seconds to wait for processing to complete
+     */
+    public NucliaClient(String apiKey, String region, int timeoutSeconds) {
+        this.apiKey = apiKey;
+        this.baseUrl = "https://" + region + ".rag.progress.cloud/api/v1";
+        this.httpClient = new OkHttpClient.Builder()
+            .readTimeout(timeoutSeconds, TimeUnit.SECONDS)
+            .build();
+        this.objectMapper = new ObjectMapper();
+        this.timeoutSeconds = timeoutSeconds;
+
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Initialized NucliaClient: baseUrl={}, timeoutSeconds={}", baseUrl, timeoutSeconds);
+        }
+    }
+
+    /**
+     * Processes a file through Nuclia's pipeline synchronously.
+     * Uploads the file, submits for processing, waits for completion, and retrieves the results.
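+     * The first event in the returned stream contains the extracted full text; the remaining events are chunks with embeddings.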
+     *
+     * @param filename the name of the file
+     * @param content the binary content of the file
+     * @return a Stream of ObjectNode events containing chunks and embeddings
+     * @throws IOException if any request fails
+     * @throws InterruptedException if the thread is interrupted while waiting
+     */
+    public Stream<ObjectNode> processData(String filename, byte[] content) throws IOException, InterruptedException {
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Starting processData for file: {}, size: {} bytes", filename, content.length);
+        }
+
+        final String resourceId = uploadFile(content, filename);
+        final String processingId = submitFileForProcessing(filename, resourceId);
+        final boolean completed = waitForCompletion(processingId, timeoutSeconds);
+
+        if (!completed) {
+            throw new IOException("Processing timed out after " + timeoutSeconds + " seconds for processing ID: " + processingId);
+        }
+
+        return getProcessingResults(processingId);
+    }
+
+    /**
+     * Uploads a binary file to Nuclia and returns a resource ID.
+     *
+     * @param content the binary content of the file
+     * @param filename the name of the file
+     * @return the resource ID for the uploaded file
+     * @throws IOException if the upload fails
+     */
+    private String uploadFile(byte[] content, String filename) throws IOException {
+        String endpoint = baseUrl + "/processing/upload";
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Uploading file to Nuclia: filename={}, endpoint={}, size={} bytes", filename, endpoint, content.length);
+        }
+
+        // Encode filename in base64 for X-FILENAME header
+        String encodedFilename = Base64.getEncoder().encodeToString(filename.getBytes(StandardCharsets.UTF_8));
+
+        Request request = newAuthenticatedRequestBuilder(endpoint)
+            .header("X-FILENAME", encodedFilename)
+            .post(RequestBody.create(content, null))
+            .build();
+
+        try (Response response = httpClient.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                Util.MAIN_LOGGER.error("Failed to upload file: {} {}", response.code(), response.message());
+                throw new IOException("Failed to upload file: " + response.code() + " " + response.message());
+            }
+            return response.body().string();
+        }
+    }
+
+    /**
+     * Submits a file for processing using its resource ID.
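+     * Sends a JSON body of the form {"filefield": {"<filename>": "<resourceId>"}} to the processing push endpoint.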
+     *
+     * @param filename the name of the file
+     * @param resourceId the resource ID from uploadFile
+     * @return the processing ID for tracking the request
+     * @throws IOException if the request fails
+     */
+    private String submitFileForProcessing(String filename, String resourceId) throws IOException {
+        final String endpoint = baseUrl + "/processing/push";
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Submitting file for processing: filename={}, endpoint={}", filename, endpoint);
+        }
+
+        final String requestBody = String.format("""
+            {"filefield": {"%s": "%s"}}
+            """, escapeJson(filename), escapeJson(resourceId));
+
+        Request request = newAuthenticatedRequestBuilder(endpoint)
+            .header("Content-Type", "application/json")
+            .post(RequestBody.create(requestBody, MediaType.get("application/json")))
+            .build();
+
+        try (Response response = httpClient.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                Util.MAIN_LOGGER.error("Failed to submit file: {} {}", response.code(), response.message());
+                throw new IOException("Failed to submit file: " + response.code() + " " + response.message());
+            }
+
+            String responseBody = response.body().string();
+            if (Util.MAIN_LOGGER.isDebugEnabled()) {
+                Util.MAIN_LOGGER.debug("Submit file response: {}", responseBody);
+            }
+            JsonNode jsonResponse = objectMapper.readTree(responseBody);
+            String processingId = jsonResponse.get("processing_id").asText();
+            if (Util.MAIN_LOGGER.isDebugEnabled()) {
+                Util.MAIN_LOGGER.debug("File submission successful, received processingId: {}", processingId);
+            }
+            return processingId;
+        }
+    }
+
+    /**
+     * Checks the status of a processing request.
+     *
+     * @param processingId the processing ID returned from submitFileForProcessing
+     * @return true if processing is complete, false otherwise
+     * @throws IOException if the request fails
+     */
+    private boolean isProcessingComplete(String processingId) throws IOException {
+        String endpoint = baseUrl + "/processing/requests/" + processingId;
+
+        Request request = newAuthenticatedRequestBuilder(endpoint)
+            .get()
+            .build();
+
+        try (Response response = httpClient.newCall(request).execute()) {
+            if (!response.isSuccessful()) {
+                throw new IOException("Failed to check status: " + response.code() + " " + response.message());
+            }
+
+            String responseBody = response.body().string();
+            JsonNode jsonResponse = objectMapper.readTree(responseBody);
+            return jsonResponse.has("completed") && jsonResponse.get("completed").asBoolean();
+        }
+    }
+
+    /**
+     * Waits for processing to complete, polling at regular intervals.
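+     * The status endpoint is polled once per second until it reports completion or maxWaitSeconds elapses.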
+     *
+     * @param processingId the processing ID to wait for
+     * @param maxWaitSeconds maximum time to wait in seconds
+     * @return true if processing completed within the timeout, false otherwise
+     * @throws IOException if a request fails
+     * @throws InterruptedException if the thread is interrupted while waiting
+     */
+    private boolean waitForCompletion(String processingId, int maxWaitSeconds) throws IOException, InterruptedException {
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Polling for completion: processingId={}, maxWaitSeconds={}", processingId, maxWaitSeconds);
+        }
+
+        int attempts = 0;
+        int maxAttempts = maxWaitSeconds;
+
+        while (attempts < maxAttempts) {
+            if (isProcessingComplete(processingId)) {
+                if (Util.MAIN_LOGGER.isDebugEnabled()) {
+                    Util.MAIN_LOGGER.debug("Processing completed after {} attempts", attempts + 1);
+                }
+                return true;
+            }
+            TimeUnit.SECONDS.sleep(1);
+            attempts++;
+            if (Util.MAIN_LOGGER.isDebugEnabled() && attempts % 5 == 0) {
+                Util.MAIN_LOGGER.debug("Still waiting for completion: attempt {}/{}", attempts, maxAttempts);
+            }
+        }
+
+        Util.MAIN_LOGGER.warn("Processing did not complete within {} seconds", maxWaitSeconds);
+        return false;
+    }
+
+    /**
+     * Retrieves the SSE results stream from a completed processing request.
+     *
+     * @param processingId the processing ID
+     * @return a Stream of ObjectNode events
+     * @throws IOException if the request fails
+     */
+    private Stream<ObjectNode> getProcessingResults(String processingId) throws IOException {
+        final String endpoint = baseUrl + "/processing/requests/" + processingId + "/results";
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Retrieving SSE results: processingId={}, endpoint={}", processingId, endpoint);
+        }
+
+        Request request = newAuthenticatedRequestBuilder(endpoint)
+            .get()
+            .build();
+
+        NucliaEventCollector collector = new NucliaEventCollector(objectMapper);
+
+        EventSource eventSource = EventSources.createFactory(httpClient)
+            .newEventSource(request, collector);
+
+        try {
+            Stream<ObjectNode> results = collector.awaitCompletion(timeoutSeconds).stream();
+            if (Util.MAIN_LOGGER.isDebugEnabled()) {
+                Util.MAIN_LOGGER.debug("SSE stream completed for processingId: {}", processingId);
+            }
+            return results;
+        } catch (InterruptedException e) {
+            Util.MAIN_LOGGER.error("Interrupted while waiting for SSE stream: processingId={}", processingId, e);
+            eventSource.cancel();
+            Thread.currentThread().interrupt();
+            throw new IOException("Interrupted while waiting for SSE stream", e);
+        }
+    }
+
+    /**
+     * Creates a new Request.Builder with the endpoint URL and Authorization header already set.
+     *
+     * @param endpoint the full endpoint URL
+     * @return a Request.Builder with URL and auth header configured
+     */
+    private Request.Builder newAuthenticatedRequestBuilder(String endpoint) {
+        return new Request.Builder()
+            .url(endpoint)
+            .header("Authorization", "Bearer " + apiKey);
+    }
+
+    /**
+     * Escapes JSON special characters in a string.
+     */
+    private String escapeJson(String input) {
+        if (input == null) {
+            return "";
+        }
+        return input.replace("\\", "\\\\")
+            .replace("\"", "\\\"")
+            .replace("\n", "\\n")
+            .replace("\r", "\\r")
+            .replace("\t", "\\t");
+    }
+
+    /**
+     * Closes the HTTP client and releases all resources.
+     * Shuts down connection pools and executor services to allow the JVM to exit cleanly.
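+     * Any response cache configured on the client is closed as well.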
+     */
+    @Override
+    public void close() {
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("Closing NucliaClient and releasing HTTP client resources");
+        }
+        try {
+            httpClient.dispatcher().executorService().shutdownNow();
+            httpClient.connectionPool().evictAll();
+        } finally {
+            try {
+                if (httpClient.cache() != null) {
+                    httpClient.cache().close();
+                }
+            } catch (IOException e) {
+                // Ignore - we're shutting down anyway
+            }
+        }
+        if (Util.MAIN_LOGGER.isDebugEnabled()) {
+            Util.MAIN_LOGGER.debug("NucliaClient closed successfully");
+        }
+    }
+
+    public String getBaseUrl() {
+        return baseUrl;
+    }
+
+    public int getTimeoutSeconds() {
+        return timeoutSeconds;
+    }
+}
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaDocumentProcessor.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaDocumentProcessor.java
new file mode 100644
index 00000000..97115a93
--- /dev/null
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaDocumentProcessor.java
@@ -0,0 +1,156 @@
+/*
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.spark.core.nuclia;
+
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import com.marklogic.spark.ConnectorException;
+import com.marklogic.spark.core.DocumentInputs;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Processes documents through Nuclia's API for text extraction, chunking, and embedding generation.
+ *
+ * @since 3.1.0
+ */
+public class NucliaDocumentProcessor {
+
+    private final NucliaClient nucliaClient;
+
+    public NucliaDocumentProcessor(NucliaClient nucliaClient) {
+        this.nucliaClient = nucliaClient;
+    }
+
+    /**
+     * Processes a list of document inputs through Nuclia.
+     * Each document's binary content is uploaded to Nuclia, which returns extracted text, chunks, and embeddings.
+     *
+     * @param inputs the list of document inputs to process
+     */
+    public void processDocuments(List<DocumentInputs> inputs) {
+        for (DocumentInputs input : inputs) {
+            try {
+                // Get binary content for file upload to Nuclia
+                byte[] content = input.getContentAsBytes();
+                if (content == null || content.length == 0) {
+                    continue;
+                }
+
+                // Extract filename from URI (e.g., "/path/to/file.pdf" -> "file.pdf")
+                String filename = extractFilename(input.getInitialUri());
+
+                // Process through Nuclia using binary file upload workflow
+                // Collect to list since we need all the data anyway
+                List<ObjectNode> results = nucliaClient.processData(filename, content)
+                    .collect(Collectors.toList());
+
+                if (results.isEmpty()) {
+                    continue;
+                }
+
+                // First ObjectNode is the extracted text
+                ObjectNode firstNode = results.get(0);
+                String extractedText = extractTextFromNucliaNode(firstNode);
+                if (extractedText != null && !extractedText.isEmpty()) {
+                    input.setExtractedText(extractedText);
+                }
+
+                // Rest are chunks with embeddings - extract all data for each chunk at once
+                for (int i = 1; i < results.size(); i++) {
+                    ObjectNode chunkNode = results.get(i);
+                    addChunkFromNucliaNode(chunkNode, input);
+                }
+
+            } catch (InterruptedException e) {
+                // Restore the interrupt status before surfacing the failure
+                Thread.currentThread().interrupt();
+                throw new ConnectorException("Failed to process document with Nuclia: " + input.getInitialUri(), e);
+            } catch (IOException e) {
+                throw new ConnectorException("Failed to process document with Nuclia: " + input.getInitialUri(), e);
+            }
+        }
+    }
+
+    /**
+     * Extracts the full text from the first Nuclia SSE event node.
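+     * Returns null if the node has no "text" field.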
+     * The node structure is:
+     * {
+     *   "type": "FullText",
+     *   "field": "content",
+     *   "field_type": "TEXT",
+     *   "text": "extracted text content..."
+     * }
+     */
+    private String extractTextFromNucliaNode(ObjectNode node) {
+        if (node.has("text")) {
+            return node.get("text").asText();
+        }
+        return null;
+    }
+
+    /**
+     * Extracts chunk data from a Nuclia SSE chunk event node and adds it to the document input.
+     * The node structure is:
+     * {
+     *   "type": "Chunk",
+     *   "text": "chunk text content...",
+     *   "embeddings": [
+     *     {
+     *       "id": "multilingual-2024-05-06",
+     *       "embedding": [0.123, -0.456, ...]
+     *     }
+     *   ]
+     * }
+     * If multiple embeddings are present, creates a separate chunk for each embedding with the same text.
+     */
+    private void addChunkFromNucliaNode(ObjectNode node, DocumentInputs input) {
+        // Extract text
+        String text = node.has("text") ? node.get("text").asText() : null;
+        if (text == null || text.isEmpty()) {
+            return;
+        }
+
+        // Process each embedding in the array
+        if (node.has("embeddings") && node.get("embeddings").isArray()) {
+            var embeddingsArray = node.get("embeddings");
+
+            for (int i = 0; i < embeddingsArray.size(); i++) {
+                var embeddingObj = embeddingsArray.get(i);
+                float[] embedding = null;
+                String modelName = null;
+
+                if (embeddingObj.has("embedding") && embeddingObj.get("embedding").isArray()) {
+                    var embeddingArray = embeddingObj.get("embedding");
+                    int size = embeddingArray.size();
+                    embedding = new float[size];
+                    for (int j = 0; j < size; j++) {
+                        embedding[j] = (float) embeddingArray.get(j).asDouble();
+                    }
+                }
+
+                if (embeddingObj.has("id")) {
+                    modelName = embeddingObj.get("id").asText();
+                }
+
+                input.addChunk(text, embedding, modelName);
+            }
+        } else {
+            // No embeddings, still add the chunk with just text
+            input.addChunk(text, null, null);
+        }
+    }
+
+    /**
+     * Extracts the filename from a URI path.
+     * For example: "/path/to/file.pdf" returns "file.pdf"
+     */
+    private String extractFilename(String uri) {
+        if (uri == null || uri.isEmpty()) {
+            return "document";
+        }
+        int lastSlash = uri.lastIndexOf('/');
+        if (lastSlash >= 0 && lastSlash < uri.length() - 1) {
+            return uri.substring(lastSlash + 1);
+        }
+        return uri;
+    }
+}
diff --git a/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaEventCollector.java b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaEventCollector.java
new file mode 100644
index 00000000..67ebae4b
--- /dev/null
+++ b/marklogic-spark-connector/src/main/java/com/marklogic/spark/core/nuclia/NucliaEventCollector.java
@@ -0,0 +1,83 @@
+/*
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.spark.core.nuclia;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.node.ObjectNode;
+import okhttp3.Response;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * EventSourceListener implementation that collects SSE events from Nuclia's processing results endpoint.
+ * Parses each event as JSON and collects ObjectNode events into a list.
+ */
+class NucliaEventCollector extends EventSourceListener {
+
+    private final ObjectMapper objectMapper;
+    private final List<ObjectNode> events;
+    private final CountDownLatch latch;
+    private Exception error;
+
+    NucliaEventCollector(ObjectMapper objectMapper) {
+        this.objectMapper = objectMapper;
+        this.events = new ArrayList<>();
+        this.latch = new CountDownLatch(1);
+        this.error = null;
+    }
+
+    @Override
+    public void onEvent(EventSource eventSource, String id, String type, String data) {
+        try {
+            JsonNode node = objectMapper.readTree(data);
+            if (node.isObject()) {
+                events.add((ObjectNode) node);
+            }
+        } catch (Exception e) {
+            error = e;
+        }
+    }
+
+    @Override
+    public void onClosed(EventSource eventSource) {
+        latch.countDown();
+    }
+
+    @Override
+    public void onFailure(EventSource eventSource, Throwable t, Response response) {
+        error = new IOException("SSE connection failed: " +
+            (response != null ? response.code() + " " + response.message() : t.getMessage()), t);
+        latch.countDown();
+    }
+
+    /**
+     * Waits for the SSE stream to complete.
+     *
+     * @param timeoutSeconds maximum time to wait in seconds
+     * @return the collected list of events
+     * @throws IOException if the stream fails or times out
+     * @throws InterruptedException if the thread is interrupted while waiting
+     */
+    List<ObjectNode> awaitCompletion(int timeoutSeconds) throws IOException, InterruptedException {
+        if (!latch.await(timeoutSeconds, TimeUnit.SECONDS)) {
+            throw new IOException("SSE stream did not complete within " + timeoutSeconds + " seconds");
+        }
+
+        if (error != null) {
+            if (error instanceof IOException) {
+                throw (IOException) error;
+            }
+            throw new IOException("Failed to process SSE events; cause: " + error.getMessage(), error);
+        }
+
+        return events;
+    }
+}
diff --git a/marklogic-spark-connector/src/main/resources/marklogic-spark-messages.properties b/marklogic-spark-connector/src/main/resources/marklogic-spark-messages.properties
index 7f0ad14a..2731c9c5 100644
--- a/marklogic-spark-connector/src/main/resources/marklogic-spark-messages.properties
+++ b/marklogic-spark-connector/src/main/resources/marklogic-spark-messages.properties
@@ -22,4 +22,6 @@ spark.marklogic.write.splitter.sidecar.maxChunks=
 spark.marklogic.write.embedder.chunks.jsonPointer=
 spark.marklogic.write.embedder.chunks.xpath=
 spark.marklogic.write.embedder.batchSize=
-
+spark.marklogic.write.nuclia.apikey=
+spark.marklogic.write.nuclia.region=
+spark.marklogic.write.nuclia.timeout=
diff --git a/marklogic-spark-connector/src/test/java/com/marklogic/spark/core/DocumentPipelineFactoryTest.java b/marklogic-spark-connector/src/test/java/com/marklogic/spark/core/DocumentPipelineFactoryTest.java
new file mode 100644
index 00000000..fb8284bc
--- /dev/null
+++ b/marklogic-spark-connector/src/test/java/com/marklogic/spark/core/DocumentPipelineFactoryTest.java
@@ -0,0 +1,137 @@
+/*
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.spark.core;
+
+import com.marklogic.spark.ConnectorException;
+import com.marklogic.spark.Context;
+import com.marklogic.spark.Options;
+import org.junit.jupiter.api.Test;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+class DocumentPipelineFactoryTest {
+
+    @Test
+    void nucliaWithAllRequiredOptions() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_NUCLIA_API_KEY, "test-api-key");
+        options.put(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNotNull(pipeline.getNucliaClient(), "NucliaClient should be present");
+        assertEquals("https://aws-us-east-2-1.rag.progress.cloud/api/v1", pipeline.getNucliaClient().getBaseUrl());
+        assertEquals(120, pipeline.getNucliaClient().getTimeoutSeconds(), "Default timeout should be 120 seconds");
+
+        assertNull(pipeline.getTextExtractor(), "TextExtractor should not be present in Nuclia pipeline");
+        assertNull(pipeline.getTextSplitter(), "TextSplitter should not be present in Nuclia pipeline");
+        assertNull(pipeline.getEmbeddingProducer(), "EmbeddingProducer should not be present in Nuclia pipeline");
+        assertNull(pipeline.getChunkSelector(), "ChunkSelector should not be present in Nuclia pipeline");
+    }
+
+    @Test
+    void nucliaWithCustomTimeout() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_NUCLIA_API_KEY, "test-api-key");
+        options.put(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1");
+        options.put(Options.WRITE_NUCLIA_TIMEOUT, "300");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNotNull(pipeline.getNucliaClient());
+        assertEquals(300, pipeline.getNucliaClient().getTimeoutSeconds());
+    }
+
+    @Test
+    void nucliaWithMissingRegion() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_NUCLIA_API_KEY, "test-api-key");
+        Context context = new Context(options);
+
+        ConnectorException ex = assertThrows(ConnectorException.class, () -> {
+            DocumentPipelineFactory.newDocumentPipeline(context);
+        });
+
+        assertTrue(ex.getMessage().contains(Options.WRITE_NUCLIA_REGION),
+            "Error message should mention missing region option");
+    }
+
+    @Test
+    void nucliaWithEmptyApiKey() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_NUCLIA_API_KEY, " ");
+        options.put(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNull(pipeline, "Pipeline should be null when API key is empty/whitespace");
+    }
+
+    @Test
+    void noOptionsReturnsNull() {
+        Map<String, String> options = new HashMap<>();
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNull(pipeline, "Pipeline should be null when no processing options are provided");
+    }
+
+    @Test
+    void textExtractorOnly() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_EXTRACTED_TEXT, "true");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNotNull(pipeline, "Pipeline should be created with text extractor");
+        assertNotNull(pipeline.getTextExtractor(), "TextExtractor should be present");
+        assertNull(pipeline.getNucliaClient(), "NucliaClient should not be present");
+        assertNull(pipeline.getTextSplitter(), "TextSplitter should not be present");
+        assertNull(pipeline.getEmbeddingProducer(), "EmbeddingProducer should not be present");
+    }
+
+    @Test
+    void nucliaWithClassifier() {
+        Map<String, String> options = new HashMap<>();
+        options.put(Options.WRITE_NUCLIA_API_KEY, "test-api-key");
+        options.put(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1");
+        options.put(Options.WRITE_CLASSIFIER_HOST, "classifier-host");
+        options.put(Options.WRITE_CLASSIFIER_PORT, "8080");
+        options.put(Options.WRITE_CLASSIFIER_PATH, "/classify");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNotNull(pipeline, "Pipeline should be created with Nuclia and classifier");
+        assertNotNull(pipeline.getNucliaClient(), "NucliaClient should be present");
+        assertNotNull(pipeline.getTextClassifier(), "TextClassifier should be present even with Nuclia");
+    }
+
+    @Test
+    void nucliaHasPriorityOverStandardPipeline() {
+        Map<String, String> options = new HashMap<>();
+        // Nuclia options
+        options.put(Options.WRITE_NUCLIA_API_KEY, "test-api-key");
+        options.put(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1");
+        // Standard pipeline options (should be ignored)
+        options.put(Options.WRITE_EXTRACTED_TEXT, "true");
+        options.put(Options.WRITE_SPLITTER_MAX_CHUNK_SIZE, "1000");
+        Context context = new Context(options);
+
+        DocumentPipeline pipeline = DocumentPipelineFactory.newDocumentPipeline(context);
+
+        assertNotNull(pipeline, "Pipeline should be created");
+        assertNotNull(pipeline.getNucliaClient(), "NucliaClient should be present");
+        assertNull(pipeline.getTextExtractor(), "TextExtractor should be ignored when Nuclia is configured");
+        assertNull(pipeline.getTextSplitter(), "TextSplitter should be ignored when Nuclia is configured");
+    }
+}
diff --git a/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java
index 50968d1a..2a524f9a 100644
--- a/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java
+++ b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/embedding/AddEmbeddingsToJsonTest.java
@@ -17,7 +17,6 @@
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SaveMode;
-import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
diff --git a/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/nuclia/NucliaAdHocTest.java b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/nuclia/NucliaAdHocTest.java
new file mode 100644
index 00000000..7b4a1590
--- /dev/null
+++ b/marklogic-spark-connector/src/test/java/com/marklogic/spark/writer/nuclia/NucliaAdHocTest.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright (c) 2023-2026 Progress Software Corporation and/or its subsidiaries or affiliates. All Rights Reserved.
+ */
+package com.marklogic.spark.writer.nuclia;
+
+import com.marklogic.spark.AbstractIntegrationTest;
+import com.marklogic.spark.Options;
+import org.apache.spark.sql.SaveMode;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+
+/**
+ * For manual testing of Nuclia integration.
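+ * Runs only when the NUCLIA_API_KEY environment variable is set.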
+ */
+class NucliaAdHocTest extends AbstractIntegrationTest {
+
+    @EnabledIfEnvironmentVariable(
+        named = "NUCLIA_API_KEY", matches = ".*"
+    )
+    @Test
+    void nuclia() {
+        newSparkSession()
+            .read().format(CONNECTOR_IDENTIFIER)
+            .load("src/test/resources/extraction-files/armstrong_neil.pdf")
+            .write().format(CONNECTOR_IDENTIFIER)
+            .option(Options.CLIENT_URI, makeClientUri())
+            .option(Options.WRITE_NUCLIA_API_KEY, System.getenv("NUCLIA_API_KEY"))
+            .option(Options.WRITE_NUCLIA_REGION, "aws-us-east-2-1")
+            .option(Options.WRITE_URI_REPLACE, ".*extraction-files,'/aaa'")
+            .option(Options.WRITE_PERMISSIONS, DEFAULT_PERMISSIONS)
+            .mode(SaveMode.Append)
+            .save();
+    }
+}
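
Usage sketch (editor's note, not part of the diff): the snippet below shows how a Spark job could exercise the new Nuclia options end to end. It assumes the connector's short format name is "marklogic" and uses placeholder connection details and file paths; only the apikey and region options are required, and timeout is optional with a 120-second default.

    import org.apache.spark.sql.SaveMode;
    import org.apache.spark.sql.SparkSession;

    public class NucliaWriteExample {
        public static void main(String[] args) {
            SparkSession session = SparkSession.builder().master("local[*]").getOrCreate();
            session.read().format("marklogic")
                // Placeholder input path; any binary file the connector can read works here.
                .load("path/to/files/example.pdf")
                .write().format("marklogic")
                // Placeholder MarkLogic connection string.
                .option("spark.marklogic.client.uri", "spark-user:password@localhost:8003")
                // Required Nuclia options added by this diff.
                .option("spark.marklogic.write.nuclia.apikey", System.getenv("NUCLIA_API_KEY"))
                .option("spark.marklogic.write.nuclia.region", "aws-us-east-2-1")
                // Optional; defaults to 120 seconds.
                .option("spark.marklogic.write.nuclia.timeout", "300")
                .option("spark.marklogic.write.permissions", "rest-reader,read,rest-writer,update")
                .mode(SaveMode.Append)
                .save();
        }
    }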