diff --git a/build.gradle.kts b/build.gradle.kts index 3a96f5341dbc..387491371131 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -321,6 +321,8 @@ tasks.register("javaPreCommit") { dependsOn(":sdks:java:io:xml:build") dependsOn(":sdks:java:javadoc:allJavadoc") dependsOn(":sdks:java:managed:build") + dependsOn("sdks:java:ml:inference:remote:build") + dependsOn("sdks:java:ml:inference:openai:build") dependsOn(":sdks:java:testing:expansion-service:build") dependsOn(":sdks:java:testing:jpms-tests:build") dependsOn(":sdks:java:testing:junit:build") diff --git a/sdks/java/extensions/ml/build.gradle b/sdks/java/extensions/ml/build.gradle index 708a44402df5..cb4a9f577ad6 100644 --- a/sdks/java/extensions/ml/build.gradle +++ b/sdks/java/extensions/ml/build.gradle @@ -26,6 +26,7 @@ applyJavaNature( ) description = 'Apache Beam :: SDKs :: Java :: Extensions :: ML' +ext.summary = """beam-sdks-java-extensions-ml provides Apache Beam Java SDK machine learning integration with Google Cloud AI Video Intelligence service. For machine learning run inference modules, see beam-sdks-java-ml-reference-* artifacts.""" dependencies { implementation project(path: ":sdks:java:core", configuration: "shadow") diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle new file mode 100644 index 000000000000..96de0cbe52fd --- /dev/null +++ b/sdks/java/ml/inference/openai/build.gradle @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'org.apache.beam.module' + id 'java' +} + +description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" + +dependencies { + implementation project(":sdks:java:ml:inference:remote") + implementation "com.openai:openai-java-core:4.3.0" + implementation "com.openai:openai-java-client-okhttp:4.3.0" + implementation library.java.jackson_databind + implementation library.java.jackson_annotations + implementation library.java.jackson_core + + testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") + testImplementation project(path: ":sdks:java:core", configuration: "shadow") + testImplementation library.java.slf4j_api + testRuntimeOnly library.java.slf4j_simple + testImplementation library.java.junit +} diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java new file mode 100644 index 000000000000..a7ebb1ea02a5 --- /dev/null +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.openai; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.openai.client.OpenAIClient; +import com.openai.client.okhttp.OpenAIOkHttpClient; +import com.openai.core.JsonSchemaLocalValidation; +import com.openai.models.responses.ResponseCreateParams; +import com.openai.models.responses.StructuredResponseCreateParams; +import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; + +import java.util.List; +import java.util.stream.Collectors; + +/** + * Model handler for OpenAI API inference requests. + * + *
<p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
+ *     .build();
+ *
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
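For quick reference, the sketch below assembles the pieces added in this file into a runnable pipeline. It is illustrative rather than part of the patch: the model name, the environment-variable lookup, and the sample strings are assumptions, and the raw `RemoteInference.invoke()` call mirrors the unchecked style used by this PR's own javadoc and tests.

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelInput;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelParameters;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelResponse;
import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;

public class SentimentPipeline {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Assumed: the key comes from an environment variable rather than being hard-coded.
    OpenAIModelParameters params =
        OpenAIModelParameters.builder()
            .apiKey(System.getenv("OPENAI_API_KEY"))
            .modelName("gpt-4o-mini") // illustrative model choice
            .instructionPrompt("Classify the sentiment as 'positive' or 'negative'. Return one word.")
            .build();

    PCollection<OpenAIModelInput> inputs =
        pipeline
            .apply("CreateInputs", Create.of("I love this product!", "Terrible support experience."))
            .apply("WrapInputs",
                MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
                    .via(OpenAIModelInput::create));

    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
        inputs.apply("OpenAIInference",
            RemoteInference.invoke()
                .handler(OpenAIModelHandler.class)
                .withParameters(params));

    pipeline.run().waitUntilFinish();
  }
}
```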
+ * + */ +public class OpenAIModelHandler + implements BaseModelHandler { + + private transient OpenAIClient client; + private OpenAIModelParameters modelParameters; + private transient ObjectMapper objectMapper; + + /** + * Initializes the OpenAI client with the provided parameters. + * + *
<p>This method is called once during setup. It creates an authenticated
+   * OpenAI client using the API key from the parameters.
+   *
+   * @param parameters the configuration parameters including API key and model name
+   */
+  @Override
+  public void createClient(OpenAIModelParameters parameters) {
+    this.modelParameters = parameters;
+    this.client = OpenAIOkHttpClient.builder()
+        .apiKey(this.modelParameters.getApiKey())
+        .build();
+    this.objectMapper = new ObjectMapper();
+  }
+
+  /**
+   * Performs inference on a batch of inputs using the OpenAI client.
+   *
+   * <p>This method serializes the input batch to a JSON string, sends it to OpenAI with structured
+   * output requirements, and parses the response into {@link PredictionResult} objects
+   * that pair each input with its corresponding output.
+   *
+   * @param input the list of inputs to process
+   * @return an iterable of model results and input pairs
+   */
+  @Override
+  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(
+      List<OpenAIModelInput> input) {
+    try {
+      // Convert input list to JSON string
+      String inputBatch =
+          objectMapper.writeValueAsString(
+              input.stream()
+                  .map(OpenAIModelInput::getModelInput)
+                  .collect(Collectors.toList()));
+      // Build structured response parameters
+      StructuredResponseCreateParams<StructuredInputOutput> clientParams =
+          ResponseCreateParams.builder()
+              .model(modelParameters.getModelName())
+              .input(inputBatch)
+              .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
+              .instructions(modelParameters.getInstructionPrompt())
+              .build();
+
+      // Get structured output from the model
+      StructuredInputOutput structuredOutput =
+          client.responses()
+              .create(clientParams)
+              .output()
+              .stream()
+              .flatMap(item -> item.message().stream())
+              .flatMap(message -> message.content().stream())
+              .flatMap(content -> content.outputText().stream())
+              .findFirst()
+              .orElse(null);
+
+      if (structuredOutput == null || structuredOutput.responses == null) {
+        throw new RuntimeException("Model returned no structured responses");
+      }
+
+      // Return PredictionResults pairing each input with its output
+      return structuredOutput.responses.stream()
+          .map(response -> PredictionResult.create(
+              OpenAIModelInput.create(response.input),
+              OpenAIModelResponse.create(response.output)))
+          .collect(Collectors.toList());
+
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException("Failed to serialize input batch", e);
+    }
+  }
+
+  /**
+   * Schema class for structured output response.
+   *
+   * <p>Represents a single input-output pair returned by the OpenAI API.
+   */
+  public static class Response {
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("The input string")
+    public String input;
+
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("The output string")
+    public String output;
+  }
+
+  /**
+   * Schema class for structured output containing multiple responses.
+   *
+   * <p>This class defines the expected JSON structure for OpenAI's structured output,
+   * ensuring reliable parsing of batched inference results.
+   */
+  public static class StructuredInputOutput {
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("Array of input-output pairs")
+    public List<Response> responses;
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
new file mode 100644
index 000000000000..65160a4548a4
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import org.apache.beam.sdk.ml.inference.remote.BaseInput;
+
+/**
+ * Input for OpenAI model inference requests.
+ *
+ * <p>This class encapsulates text input to be sent to OpenAI models.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
+ * String text = input.getModelInput(); // "Translate to French: Hello"
+ * }</pre>
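Pipelines usually start from a `PCollection<String>`; the integration tests later in this diff wrap such collections with `MapElements`, and the same pattern can be factored into a small helper. A sketch (the class and method names are invented):

```java
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelInput;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptor;

class InputAdapters {
  /** Wraps an existing PCollection of strings into the handler's input type. */
  static PCollection<OpenAIModelInput> toModelInputs(PCollection<String> lines) {
    return lines.apply(
        "MapToInputs",
        MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
            .via(OpenAIModelInput::create));
  }
}
```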
+ * + * @see OpenAIModelHandler + * @see OpenAIModelResponse + */ +public class OpenAIModelInput implements BaseInput { + + private final String input; + + private OpenAIModelInput(String input) { + + this.input = input; + } + + /** + * Returns the text input for the model. + * + * @return the input text string + */ + public String getModelInput() { + return input; + } + + /** + * Creates a new input instance with the specified text. + * + * @param input the text to send to the model + * @return a new {@link OpenAIModelInput} instance + */ + public static OpenAIModelInput create(String input) { + return new OpenAIModelInput(input); + } + +} diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java new file mode 100644 index 000000000000..2b2b04dfa94b --- /dev/null +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.openai; + +import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters; + +/** + * Configuration parameters required for OpenAI model inference. + * + *
<p>This class encapsulates all configuration needed to initialize and communicate with
+ * OpenAI's API, including authentication credentials, model selection, and inference instructions.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Translate the following text to French:")
+ *     .build();
+ * }</pre>
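Because the parameters object is serialized into the inference `DoFn`, API keys should come from configuration rather than source code. A sketch of wiring the key through a custom `PipelineOptions` interface (the options interface and the defaults here are hypothetical, not part of the patch):

```java
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelParameters;
import org.apache.beam.sdk.options.Description;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.options.PipelineOptionsFactory;

public class ParamsFromOptions {

  /** Hypothetical options interface; not part of this patch. */
  public interface OpenAIOptions extends PipelineOptions {
    @Description("OpenAI API key, e.g. passed as --openAIApiKey=sk-...")
    String getOpenAIApiKey();

    void setOpenAIApiKey(String value);
  }

  public static OpenAIModelParameters fromArgs(String[] args) {
    OpenAIOptions options = PipelineOptionsFactory.fromArgs(args).as(OpenAIOptions.class);
    // The parameters object is shipped to workers with the DoFn, so keep it free
    // of anything that cannot (or should not) be serialized in plain text.
    return OpenAIModelParameters.builder()
        .apiKey(options.getOpenAIApiKey())
        .modelName("gpt-4o-mini") // illustrative default
        .instructionPrompt("Summarize the input in one sentence.")
        .build();
  }
}
```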
+ * + * @see OpenAIModelHandler + */ +public class OpenAIModelParameters implements BaseModelParameters { + + private final String apiKey; + private final String modelName; + private final String instructionPrompt; + + private OpenAIModelParameters(Builder builder) { + this.apiKey = builder.apiKey; + this.modelName = builder.modelName; + this.instructionPrompt = builder.instructionPrompt; + } + + public String getApiKey() { + return apiKey; + } + + public String getModelName() { + return modelName; + } + + public String getInstructionPrompt() { + return instructionPrompt; + } + + public static Builder builder() { + return new Builder(); + } + + + public static class Builder { + private String apiKey; + private String modelName; + private String instructionPrompt; + + private Builder() { + } + + /** + * Sets the OpenAI API key for authentication. + * + * @param apiKey the API key (required) + */ + public Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + /** + * Sets the name of the OpenAI model to use. + * + * @param modelName the model name, e.g., "gpt-4" (required) + */ + public Builder modelName(String modelName) { + this.modelName = modelName; + return this; + } + /** + * Sets the instruction prompt for the model. + * This prompt provides context or instructions to the model about how to process + * the input text. + * + * @param prompt the instruction text (required) + */ + public Builder instructionPrompt(String prompt) { + this.instructionPrompt = prompt; + return this; + } + + /** + * Builds the {@link OpenAIModelParameters} instance. + */ + public OpenAIModelParameters build() { + return new OpenAIModelParameters(this); + } + } +} diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java new file mode 100644 index 000000000000..f1c92bc765f8 --- /dev/null +++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.openai; + +import org.apache.beam.sdk.ml.inference.remote.BaseResponse; + +/** + * Response from OpenAI model inference results. + *
<p>This class encapsulates the text output returned from OpenAI models.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
+ * String output = response.getModelResponse(); // "Bonjour"
+ * }</pre>
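Downstream stages receive one `Iterable` of `PredictionResult`s per processed element. A sketch of flattening those batches into printable strings (the helper names are invented):

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelInput;
import org.apache.beam.sdk.ml.inference.openai.OpenAIModelResponse;
import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
import org.apache.beam.sdk.transforms.FlatMapElements;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;

class ResultFormatting {
  /** Flattens per-batch results into one "input -> output" line per prediction. */
  static PCollection<String> format(
      PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results) {
    return results.apply(
        "FormatResults",
        FlatMapElements.into(TypeDescriptors.strings())
            .via(batch -> {
              List<String> lines = new ArrayList<>();
              for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> r : batch) {
                lines.add(r.getInput().getModelInput() + " -> " + r.getOutput().getModelResponse());
              }
              return lines;
            }));
  }
}
```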
+ * + * @see OpenAIModelHandler + * @see OpenAIModelInput + */ +public class OpenAIModelResponse implements BaseResponse { + + private final String output; + + private OpenAIModelResponse(String output) { + this.output = output; + } + + /** + * Returns the text output from the model. + * + * @return the output text string + */ + public String getModelResponse() { + return output; + } + + /** + * Creates a new response instance with the specified output text. + * + * @param output the text returned by the model + * @return a new {@link OpenAIModelResponse} instance + */ + public static OpenAIModelResponse create(String output) { + return new OpenAIModelResponse(output); + } +} diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java new file mode 100644 index 000000000000..ba03bce86988 --- /dev/null +++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java @@ -0,0 +1,402 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.openai; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.MapElements; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.TypeDescriptor; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.beam.sdk.ml.inference.remote.RemoteInference; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; + +public class OpenAIModelHandlerIT { + private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class); + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + private String apiKey; + private static final String API_KEY_ENV = "OPENAI_API_KEY"; + private static final String DEFAULT_MODEL = "gpt-4o-mini"; + + + @Before + public void setUp() { + // Get API key + apiKey = System.getenv(API_KEY_ENV); + + // Skip tests if API key is not provided + assumeNotNull( + "OpenAI API key not found. 
Set " + API_KEY_ENV + + " environment variable to run integration tests.", + apiKey); + assumeTrue("OpenAI API key is empty. Set " + API_KEY_ENV + + " environment variable to run integration tests.", + !apiKey.trim().isEmpty()); + } + + @Test + public void testSentimentAnalysisWithSingleInput() { + String input = "This product is absolutely amazing! I love it!"; + + PCollection inputs = pipeline + .apply("CreateSingleInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze the sentiment as 'positive' or 'negative'. Return only one word.") + .build())); + + // Verify results + PAssert.that(results).satisfies(batches -> { + int count = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + count++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertNotNull("Output text should not be null", + result.getOutput().getModelResponse()); + + String sentiment = result.getOutput().getModelResponse().toLowerCase(); + assertTrue("Sentiment should be positive or negative, got: " + sentiment, + sentiment.contains("positive") + || sentiment.contains("negative")); + } + } + assertEquals("Should have exactly 1 result", 1, count); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testSentimentAnalysisWithMultipleInputs() { + List inputs = Arrays.asList( + "An excellent B2B SaaS solution that streamlines business processes efficiently.", + "The customer support is terrible. I've been waiting for days without any response.", + "The application works as expected. Installation was straightforward.", + "Really impressed with the innovative features! The AI capabilities are groundbreaking!", + "Mediocre product with occasional glitches. 
Documentation could be better."); + + PCollection inputCollection = pipeline + .apply("CreateMultipleInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("SentimentInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Analyze sentiment as positive or negative") + .build())); + + // Verify we get results for all inputs + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + assertEquals("Should have results for all 5 inputs", 5, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testTextClassification() { + List inputs = Arrays.asList( + "How do I reset my password?", + "Your product is broken and I want a refund!", + "Thank you for the excellent service!"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("ClassificationInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Classify each text into one category: 'question', 'complaint', or 'praise'. 
Return only the category.") + .build())); + + PAssert.that(results).satisfies(batches -> { + List categories = new ArrayList<>(); + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String category = result.getOutput().getModelResponse().toLowerCase(); + categories.add(category); + } + } + + assertEquals("Should have 3 categories", 3, categories.size()); + + // Verify expected categories + boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question")); + boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint")); + boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise")); + + assertTrue("Should have at least one recognized category", + hasQuestion || hasComplaint || hasPraise); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInputOutputMapping() { + List inputs = Arrays.asList("apple", "banana", "cherry"); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs)) + .apply("MapToInputs", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputCollection + .apply("MappingInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Return the input word in uppercase") + .build())); + + // Verify input-output pairing is preserved + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String input = result.getInput().getModelInput(); + String output = result.getOutput().getModelResponse().toLowerCase(); + + // Verify the output relates to the input + assertTrue("Output should relate to input '" + input + "', got: " + output, + output.contains(input.toLowerCase())); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithDifferentModel() { + // Test with a different model + String input = "Explain quantum computing in one sentence."; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("DifferentModelInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("gpt-5") + .instructionPrompt("Respond concisely") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + assertNotNull("Output should not be null", + result.getOutput().getModelResponse()); + assertFalse("Output should not be empty", + result.getOutput().getModelResponse().trim().isEmpty()); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithInvalidApiKey() { + String input = "Test input"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply("InvalidKeyInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey("invalid-api-key-12345") + .modelName(DEFAULT_MODEL) + .instructionPrompt("Test") + .build())); + + 
try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline failure due to invalid API key"); + } catch (Exception e) { + String msg = e.toString().toLowerCase(); + + assertTrue( + "Expected retry exhaustion or API key issue. Got: " + msg, + msg.contains("exhaust") || + msg.contains("max retries") || + msg.contains("401") || + msg.contains("api key") || + msg.contains("incorrect api key") + ); + } + } + + /** + * Test with custom instruction formats + */ + @Test + public void testWithJsonOutputFormat() { + String input = "Paris is the capital of France"; + + PCollection inputs = pipeline + .apply("CreateInput", Create.of(input)) + .apply("MapToInput", MapElements + .into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + PCollection>> results = inputs + .apply("JsonFormatInference", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters(OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName(DEFAULT_MODEL) + .instructionPrompt( + "Extract the city and country. Return as: City: [city], Country: [country]") + .build())); + + PAssert.that(results).satisfies(batches -> { + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + String output = result.getOutput().getModelResponse(); + LOG.info("Structured output: " + output); + + // Verify output contains expected information + assertTrue("Output should mention Paris: " + output, + output.toLowerCase().contains("paris")); + assertTrue("Output should mention France: " + output, + output.toLowerCase().contains("france")); + } + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testRetryWithInvalidModel() { + + PCollection inputs = + pipeline + .apply("CreateInput", Create.of("Test input")) + .apply("MapToInput", + MapElements.into(TypeDescriptor.of(OpenAIModelInput.class)) + .via(OpenAIModelInput::create)); + + inputs.apply( + "FailingOpenAIRequest", + RemoteInference.invoke() + .handler(OpenAIModelHandler.class) + .withParameters( + OpenAIModelParameters.builder() + .apiKey(apiKey) + .modelName("fake-model") + .instructionPrompt("test retry") + .build())); + + try { + pipeline.run().waitUntilFinish(); + fail("Pipeline should fail after retry exhaustion."); + } catch (Exception e) { + String message = e.getMessage().toLowerCase(); + + assertTrue( + "Expected retry-exhaustion error. Actual: " + message, + message.contains("exhaust") || + message.contains("retry") || + message.contains("max retries") || + message.contains("request failed") || + message.contains("fake-model")); + } + } + +} diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java new file mode 100644 index 000000000000..0250c559fe65 --- /dev/null +++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java @@ -0,0 +1,450 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.openai; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput; +import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response; +import org.apache.beam.sdk.ml.inference.remote.PredictionResult; + + + +@RunWith(JUnit4.class) +public class OpenAIModelHandlerTest { + private OpenAIModelParameters testParameters; + + @Before + public void setUp() { + testParameters = OpenAIModelParameters.builder() + .apiKey("test-api-key") + .modelName("gpt-4") + .instructionPrompt("Test instruction") + .build(); + } + + /** + * Fake OpenAiModelHandler for testing. + */ + static class FakeOpenAiModelHandler extends OpenAIModelHandler { + + private boolean clientCreated = false; + private OpenAIModelParameters storedParameters; + private List responsesToReturn; + private RuntimeException exceptionToThrow; + private boolean shouldReturnNull = false; + + public void setResponsesToReturn(List responses) { + this.responsesToReturn = responses; + } + + public void setExceptionToThrow(RuntimeException exception) { + this.exceptionToThrow = exception; + } + + public void setShouldReturnNull(boolean shouldReturnNull) { + this.shouldReturnNull = shouldReturnNull; + } + + public boolean isClientCreated() { + return clientCreated; + } + + public OpenAIModelParameters getStoredParameters() { + return storedParameters; + } + + @Override + public void createClient(OpenAIModelParameters parameters) { + this.storedParameters = parameters; + this.clientCreated = true; + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + } + + @Override + public Iterable> request( + List input) { + + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + + if (shouldReturnNull || responsesToReturn == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + StructuredInputOutput structuredOutput = responsesToReturn.get(0); + + if (structuredOutput == null || structuredOutput.responses == null) { + throw new RuntimeException("Model returned no structured responses"); + } + + return structuredOutput.responses.stream() + .map(response -> PredictionResult.create( + OpenAIModelInput.create(response.input), + OpenAIModelResponse.create(response.output))) + .collect(Collectors.toList()); + } + } + + @Test + public void testCreateClient() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("test-key") + .modelName("gpt-4") + .instructionPrompt("test prompt") + .build(); + + handler.createClient(params); + + 
assertTrue("Client should be created", handler.isClientCreated()); + assertNotNull("Parameters should be stored", handler.getStoredParameters()); + assertEquals("test-key", handler.getStoredParameters().getApiKey()); + assertEquals("gpt-4", handler.getStoredParameters().getModelName()); + } + + @Test + public void testRequestWithSingleInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + assertNotNull("Results should not be null", results); + + List> resultList = iterableToList(results); + + assertEquals("Should have 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test input", result.getInput().getModelInput()); + assertEquals("test output", result.getOutput().getModelResponse()); + } + + @Test + public void testRequestWithMultipleInputs() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("input1"), + OpenAIModelInput.create("input2"), + OpenAIModelInput.create("input3")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "input1"; + response1.output = "output1"; + + Response response2 = new Response(); + response2.input = "input2"; + response2.output = "output2"; + + Response response3 = new Response(); + response3.input = "input3"; + response3.output = "output3"; + + structuredOutput.responses = Arrays.asList(response1, response2, response3); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals("Should have 3 results", 3, resultList.size()); + + for (int i = 0; i < 3; i++) { + PredictionResult result = resultList.get(i); + assertEquals("input" + (i + 1), result.getInput().getModelInput()); + assertEquals("output" + (i + 1), result.getOutput().getModelResponse()); + } + } + + @Test + public void testRequestWithEmptyInput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.emptyList(); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = Collections.emptyList(); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + assertEquals("Should have 0 results", 0, resultList.size()); + } + + @Test + public void testRequestWithNullStructuredOutput() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.setShouldReturnNull(true); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when structured output is null"); + } catch (RuntimeException e) { + assertTrue("Exception 
message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testRequestWithNullResponsesList() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + structuredOutput.responses = null; + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + try { + handler.request(inputs); + fail("Expected RuntimeException when responses list is null"); + } catch (RuntimeException e) { + assertTrue("Exception message should mention no structured responses", + e.getMessage().contains("Model returned no structured responses")); + } + } + + @Test + public void testCreateClientFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.setExceptionToThrow(new RuntimeException("Setup failed")); + + try { + handler.createClient(testParameters); + fail("Expected RuntimeException during client creation"); + } catch (RuntimeException e) { + assertEquals("Setup failed", e.getMessage()); + } + } + + @Test + public void testRequestApiFailure() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + handler.createClient(testParameters); + handler.setExceptionToThrow(new RuntimeException("API Error")); + + try { + handler.request(inputs); + fail("Expected RuntimeException when API fails"); + } catch (RuntimeException e) { + assertEquals("API Error", e.getMessage()); + } + } + + @Test + public void testRequestWithoutClientInitialization() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Collections.singletonList( + OpenAIModelInput.create("test input")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + Response response = new Response(); + response.input = "test input"; + response.output = "test output"; + structuredOutput.responses = Collections.singletonList(response); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + + // Don't call createClient + try { + handler.request(inputs); + fail("Expected IllegalStateException when client not initialized"); + } catch (IllegalStateException e) { + assertTrue("Exception should mention client not initialized", + e.getMessage().contains("Client not initialized")); + } + } + + @Test + public void testInputOutputMapping() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + + List inputs = Arrays.asList( + OpenAIModelInput.create("alpha"), + OpenAIModelInput.create("beta")); + + StructuredInputOutput structuredOutput = new StructuredInputOutput(); + + Response response1 = new Response(); + response1.input = "alpha"; + response1.output = "ALPHA"; + + Response response2 = new Response(); + response2.input = "beta"; + response2.output = "BETA"; + + structuredOutput.responses = Arrays.asList(response1, response2); + + handler.setResponsesToReturn(Collections.singletonList(structuredOutput)); + handler.createClient(testParameters); + + Iterable> results = handler.request(inputs); + + List> resultList = iterableToList(results); + + assertEquals(2, resultList.size()); + assertEquals("alpha", resultList.get(0).getInput().getModelInput()); + assertEquals("ALPHA", resultList.get(0).getOutput().getModelResponse()); + + 
assertEquals("beta", resultList.get(1).getInput().getModelInput()); + assertEquals("BETA", resultList.get(1).getOutput().getModelResponse()); + } + + @Test + public void testParametersBuilder() { + OpenAIModelParameters params = OpenAIModelParameters.builder() + .apiKey("my-api-key") + .modelName("gpt-4-turbo") + .instructionPrompt("Custom instruction") + .build(); + + assertEquals("my-api-key", params.getApiKey()); + assertEquals("gpt-4-turbo", params.getModelName()); + assertEquals("Custom instruction", params.getInstructionPrompt()); + } + + @Test + public void testOpenAIModelInputCreate() { + OpenAIModelInput input = OpenAIModelInput.create("test value"); + + assertNotNull("Input should not be null", input); + assertEquals("test value", input.getModelInput()); + } + + @Test + public void testOpenAIModelResponseCreate() { + OpenAIModelResponse response = OpenAIModelResponse.create("test output"); + + assertNotNull("Response should not be null", response); + assertEquals("test output", response.getModelResponse()); + } + + @Test + public void testStructuredInputOutputStructure() { + Response response = new Response(); + response.input = "test-input"; + response.output = "test-output"; + + assertEquals("test-input", response.input); + assertEquals("test-output", response.output); + + StructuredInputOutput structured = new StructuredInputOutput(); + structured.responses = Collections.singletonList(response); + + assertNotNull("Responses should not be null", structured.responses); + assertEquals("Should have 1 response", 1, structured.responses.size()); + assertEquals("test-input", structured.responses.get(0).input); + } + + @Test + public void testMultipleRequestsWithSameHandler() { + FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler(); + handler.createClient(testParameters); + + // First request + StructuredInputOutput output1 = new StructuredInputOutput(); + Response response1 = new Response(); + response1.input = "first"; + response1.output = "FIRST"; + output1.responses = Collections.singletonList(response1); + handler.setResponsesToReturn(Collections.singletonList(output1)); + + List inputs1 = Collections.singletonList( + OpenAIModelInput.create("first")); + Iterable> results1 = handler.request(inputs1); + + List> resultList1 = iterableToList(results1); + assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse()); + + // Second request with different data + StructuredInputOutput output2 = new StructuredInputOutput(); + Response response2 = new Response(); + response2.input = "second"; + response2.output = "SECOND"; + output2.responses = Collections.singletonList(response2); + handler.setResponsesToReturn(Collections.singletonList(output2)); + + List inputs2 = Collections.singletonList( + OpenAIModelInput.create("second")); + Iterable> results2 = handler.request(inputs2); + + List> resultList2 = iterableToList(results2); + assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse()); + } + + // Helper method to convert Iterable to List + private List iterableToList(Iterable iterable) { + List list = new java.util.ArrayList<>(); + iterable.forEach(list::add); + return list; + } +} diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle new file mode 100644 index 000000000000..7cbea0c594d2 --- /dev/null +++ b/sdks/java/ml/inference/remote/build.gradle @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +plugins { + id 'org.apache.beam.module' + id 'java-library' +} + +description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote" + +dependencies { + // Core Beam SDK + implementation project(path: ":sdks:java:core", configuration: "shadow") + + compileOnly "com.google.auto.value:auto-value-annotations:1.11.0" + annotationProcessor "com.google.auto.value:auto-value:1.11.0" + implementation library.java.checker_qual; + implementation library.java.vendored_guava_32_1_2_jre + implementation library.java.slf4j_api + implementation library.java.joda_time + + // testing + testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") + testImplementation library.java.junit + testRuntimeOnly library.java.hamcrest + testRuntimeOnly library.java.slf4j_simple +} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java new file mode 100644 index 000000000000..73bc43684a94 --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.io.Serializable; + +/** + * Base class for defining input types used with remote inference transforms. + *Implementations holds the data needed for inference (text, images, etc.) + */ +public interface BaseInput extends Serializable { + +} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java new file mode 100644 index 000000000000..314aec34cf9b --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.util.List; + +/** + * Interface for model-specific handlers that perform remote inference operations. + * + *
<p>Implementations of this interface encapsulate all logic for communicating with a
+ * specific remote inference service. Each handler is responsible for:
+ *
+ * <ul>
+ *   <li>Initializing and managing client connections</li>
+ *   <li>Converting Beam inputs to service-specific request formats</li>
+ *   <li>Making inference API calls</li>
+ *   <li>Converting service responses to Beam output types</li>
+ *   <li>Handling errors and retries if applicable</li>
+ * </ul>
+ *
+ * <h3>Lifecycle</h3>
+ *
+ * <p>Handler instances follow this lifecycle:
+ *
+ * <ol>
+ *   <li>Instantiation via no-argument constructor</li>
+ *   <li>{@link #createClient} called with parameters during setup</li>
+ *   <li>{@link #request} called for each batch of inputs</li>
+ * </ol>
+ *
+ * <p>Handlers typically contain non-serializable client objects.
+ * Mark client fields as {@code transient} and initialize them in {@link #createClient}.
+ *
+ * <h3>Batching Considerations</h3>
+ *
+ * <p>The {@link #request} method receives a list of inputs. Implementations should:
+ *
+ * <ul>
+ *   <li>Batch inputs efficiently if the service supports batch inference</li>
+ *   <li>Return results in the same order as inputs</li>
+ *   <li>Maintain input-output correspondence in {@link PredictionResult}</li>
+ * </ul>
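To make this contract concrete, here is a hypothetical handler for an imaginary echo service. Every name below is invented, and the sketch assumes the interface's type parameters are ordered `<InputT, ParamT, OutputT>`; it keeps the non-serializable client transient, initializes it once in `createClient`, and returns results in input order.

```java
import java.net.http.HttpClient;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.beam.sdk.ml.inference.remote.BaseInput;
import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler;
import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters;
import org.apache.beam.sdk.ml.inference.remote.BaseResponse;
import org.apache.beam.sdk.ml.inference.remote.PredictionResult;

/** Invented example types; none of this is part of the patch. */
class EchoInput implements BaseInput {
  final String text;
  EchoInput(String text) { this.text = text; }
}

class EchoResponse implements BaseResponse {
  final String text;
  EchoResponse(String text) { this.text = text; }
}

class EchoParameters implements BaseModelParameters {
  final String endpoint;
  EchoParameters(String endpoint) { this.endpoint = endpoint; }
}

/** Minimal handler: transient client, created once, order-preserving results. */
class EchoModelHandler implements BaseModelHandler<EchoInput, EchoParameters, EchoResponse> {

  private transient HttpClient client; // non-serializable; never shipped with the DoFn
  private EchoParameters parameters;

  @Override
  public void createClient(EchoParameters parameters) {
    this.parameters = parameters;
    this.client = HttpClient.newHttpClient(); // one client per handler instance
  }

  @Override
  public Iterable<PredictionResult<EchoInput, EchoResponse>> request(List<EchoInput> input) {
    // A real implementation would call the remote service here (batched if supported).
    // Results are produced in input order, pairing every input with its output.
    return input.stream()
        .map(in -> PredictionResult.create(in, new EchoResponse(in.text)))
        .collect(Collectors.toList());
  }
}
```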
+ * + */ +public interface BaseModelHandler { + /** + * Initializes the remote model client with the provided parameters. + */ + public void createClient(ParamT parameters); + + /** + * Performs inference on a batch of inputs and returns the results. + */ + public Iterable> request(List input); + +} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java new file mode 100644 index 000000000000..f285377da977 --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.io.Serializable; + +/** + * Base interface for defining model-specific parameters used to configure remote inference clients. + * + *
<p>Implementations of this interface encapsulate all configuration needed to initialize
+ * and communicate with a remote model inference service. This typically includes:
+ *
+ * <ul>
+ *   <li>Authentication credentials (API keys, tokens)</li>
+ *   <li>Model identifiers or names</li>
+ *   <li>Endpoint URLs or connection settings</li>
+ *   <li>Inference configuration (temperature, max tokens, timeout values, etc.)</li>
+ * </ul>
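A hypothetical parameters class following the serializability requirement and the builder pattern suggested just below; the service, fields, and defaults are invented for illustration:

```java
import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters;

/** Hypothetical parameters class; not part of the patch. */
public class MyServiceParameters implements BaseModelParameters {

  private final String endpoint;
  private final String apiToken;
  private final int timeoutSeconds;

  private MyServiceParameters(Builder builder) {
    this.endpoint = builder.endpoint;
    this.apiToken = builder.apiToken;
    this.timeoutSeconds = builder.timeoutSeconds;
  }

  public String getEndpoint() { return endpoint; }
  public String getApiToken() { return apiToken; }
  public int getTimeoutSeconds() { return timeoutSeconds; }

  public static Builder builder() { return new Builder(); }

  public static class Builder {
    private String endpoint;
    private String apiToken;
    private int timeoutSeconds = 30; // illustrative default

    public Builder endpoint(String endpoint) { this.endpoint = endpoint; return this; }
    public Builder apiToken(String apiToken) { this.apiToken = apiToken; return this; }
    public Builder timeoutSeconds(int timeoutSeconds) { this.timeoutSeconds = timeoutSeconds; return this; }

    public MyServiceParameters build() { return new MyServiceParameters(this); }
  }
}
```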
+ * + *

Parameters must be serializable. Consider using + * the builder pattern for complex parameter objects. + * + */ +public interface BaseModelParameters extends Serializable { + +} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java new file mode 100644 index 000000000000..b92a8e2d4228 --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.io.Serializable; + +/** + * Base class for defining response types returned from remote inference operations. + + *
<p>Implementations:
+ *
+ * <ul>
+ *   <li>Contain the inference results (predictions, classifications, generated text, etc.)</li>
+ *   <li>Include any relevant metadata</li>
+ * </ul>
+ * + */ +public interface BaseResponse extends Serializable { + +} diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java new file mode 100644 index 000000000000..bf1ae66127cf --- /dev/null +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * License); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an AS IS BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.io.Serializable; + +/** + * Pairs an input with its corresponding inference output. + * + *
<p>This class maintains the association between input data and the model's results
+ * for downstream processing.
+ */
+public class PredictionResult<InputT, OutputT> implements Serializable {
+
+  private final InputT input;
+  private final OutputT output;
+
+  private PredictionResult(InputT input, OutputT output) {
+    this.input = input;
+    this.output = output;
+  }
+
+  /** Returns the input sent to the handler. */
+  public InputT getInput() {
+    return input;
+  }
+
+  /** Returns the model handler's response. */
+  public OutputT getOutput() {
+    return output;
+  }
+
+  /** Creates a {@link PredictionResult} from the provided input and output. */
+  public static <InputT, OutputT> PredictionResult<InputT, OutputT> create(
+      InputT input, OutputT output) {
+    return new PredictionResult<>(input, output);
+  }
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
new file mode 100644
index 000000000000..da9217bfd52e
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import org.apache.beam.sdk.transforms.*;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import org.apache.beam.sdk.values.PCollection;
+import com.google.auto.value.AutoValue;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A {@link PTransform} for making remote inference calls to external machine learning services.
+ *

{@code RemoteInference} provides a framework for integrating remote ML model + * inference into Apache Beam pipelines and handles the communication between pipelines + * and external inference APIs. + * + *

Example: OpenAI Model Inference

+ * + *
{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
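+ *
+ * <p>Each output element is an {@code Iterable} of {@link PredictionResult}s pairing every input
+ * in the request batch with its response. A minimal sketch of unpacking the raw responses
+ * downstream (it assumes the response type exposes a {@code getModelResponse()} accessor, as the
+ * handlers in this module's tests do):
+ *
+ * <pre>{@code
+ * PCollection<String> responses = results.apply(
+ *     FlatMapElements.into(TypeDescriptors.strings())
+ *         .via(batch -> StreamSupport.stream(batch.spliterator(), false)
+ *             .map(r -> r.getOutput().getModelResponse())
+ *             .collect(Collectors.toList())));
+ * }</pre>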
+ */
+@SuppressWarnings({"rawtypes", "unchecked"})
+public class RemoteInference {
+
+  /** Invokes the model handler with the configured model parameters. */
+  public static <InputT, OutputT> Invoke<InputT, OutputT> invoke() {
+    return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null).build();
+  }
+
+  private RemoteInference() {}
+
+  @AutoValue
+  public abstract static class Invoke<InputT, OutputT>
+      extends PTransform<
+          PCollection<InputT>, PCollection<Iterable<PredictionResult<InputT, OutputT>>>> {
+
+    abstract @Nullable Class<? extends BaseModelHandler> handler();
+
+    abstract @Nullable BaseModelParameters parameters();
+
+    abstract Builder<InputT, OutputT> builder();
+
+    @AutoValue.Builder
+    abstract static class Builder<InputT, OutputT> {
+
+      abstract Builder<InputT, OutputT> setHandler(Class<? extends BaseModelHandler> modelHandler);
+
+      abstract Builder<InputT, OutputT> setParameters(BaseModelParameters modelParameters);
+
+      abstract Invoke<InputT, OutputT> build();
+    }
+
+    /** Sets the model handler class used for inference. */
+    public Invoke<InputT, OutputT> handler(Class<? extends BaseModelHandler> modelHandler) {
+      return builder().setHandler(modelHandler).build();
+    }
+
+    /** Configures the parameters used to initialize the model client. */
+    public Invoke<InputT, OutputT> withParameters(BaseModelParameters modelParameters) {
+      return builder().setParameters(modelParameters).build();
+    }
+
+    @Override
+    public PCollection<Iterable<PredictionResult<InputT, OutputT>>> expand(
+        PCollection<InputT> input) {
+      checkArgument(handler() != null, "handler() is required");
+      checkArgument(parameters() != null, "withParameters() is required");
+      return input
+          // Wrap each element in a singleton list until batching is supported.
+          .apply("WrapInputInList", MapElements.via(new SimpleFunction<InputT, List<InputT>>() {
+            @Override
+            public List<InputT> apply(InputT element) {
+              return Collections.singletonList(element);
+            }
+          }))
+          // Pass the list to the inference function
+          .apply("RemoteInference", ParDo.of(new RemoteInferenceFn<>(this)));
+    }
+
+    /**
+     * A {@link DoFn} that performs the remote inference operation.
+     *
+     * <p>This function manages the lifecycle of the model handler:
+     *
+     * <ul>
+     *   <li>Instantiates the handler during {@link Setup}
+     *   <li>Initializes the remote client via {@link BaseModelHandler#createClient}
+     *   <li>Processes elements by calling {@link BaseModelHandler#request}
+     * </ul>
+     */
+    static class RemoteInferenceFn<InputT, OutputT>
+        extends DoFn<List<InputT>, Iterable<PredictionResult<InputT, OutputT>>> {
+
+      private final Class<? extends BaseModelHandler> handlerClass;
+      private final BaseModelParameters parameters;
+      private transient BaseModelHandler modelHandler;
+      private final RetryHandler retryHandler;
+
+      RemoteInferenceFn(Invoke<InputT, OutputT> spec) {
+        this.handlerClass = spec.handler();
+        this.parameters = spec.parameters();
+        this.retryHandler = RetryHandler.withDefaults();
+      }
+
+      /** Instantiates the model handler and its remote client. */
+      @Setup
+      public void setupHandler() {
+        try {
+          this.modelHandler = handlerClass.getDeclaredConstructor().newInstance();
+          this.modelHandler.createClient(parameters);
+        } catch (Exception e) {
+          throw new RuntimeException("Failed to instantiate handler: " + handlerClass.getName(), e);
+        }
+      }
+
+      /** Performs the inference request, retrying transient failures. */
+      @ProcessElement
+      public void processElement(ProcessContext c) throws Exception {
+        Iterable<PredictionResult<InputT, OutputT>> response =
+            retryHandler.execute(() -> modelHandler.request(c.element()));
+        c.output(response);
+      }
+    }
+  }
+}
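Between the transform above and the retry utility below sits the one piece user code supplies: a BaseModelHandler implementation. A minimal sketch, modeled on the mock handlers in RemoteInferenceTest at the end of this patch; EchoInput, EchoResponse, EchoParameters, and EchoClient are hypothetical types, and the type-parameter order on BaseModelHandler is assumed (its declaration lives in BaseModelHandler.java, outside this hunk):

    // Sketch only: handlers are created reflectively in @Setup, so a
    // no-arg constructor is required.
    public class EchoModelHandler
        implements BaseModelHandler<EchoInput, EchoResponse, EchoParameters> {

      private transient EchoClient client; // hypothetical remote client

      @Override
      public void createClient(EchoParameters parameters) {
        // Runs once per DoFn instance; build the remote client here.
        this.client = new EchoClient(parameters.getEndpoint());
      }

      @Override
      public Iterable<PredictionResult<EchoInput, EchoResponse>> request(List<EchoInput> input) {
        // Called per element (currently a singleton list); exceptions thrown
        // here are retried by RetryHandler before failing the bundle.
        return input.stream()
            .map(i -> PredictionResult.create(i, client.echo(i)))
            .collect(Collectors.toList());
      }
    }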
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
new file mode 100644
index 000000000000..27041d8cb237
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.io.Serializable;
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.joda.time.Duration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A utility for executing requests, handling failures by retrying with backoff.
+ */
+public class RetryHandler implements Serializable {
+
+  private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class);
+
+  private final int maxRetries;
+  private final Duration initialBackoff;
+  private final Duration maxBackoff;
+  private final Duration maxCumulativeBackoff;
+
+  private RetryHandler(
+      int maxRetries,
+      Duration initialBackoff,
+      Duration maxBackoff,
+      Duration maxCumulativeBackoff) {
+    this.maxRetries = maxRetries;
+    this.initialBackoff = initialBackoff;
+    this.maxBackoff = maxBackoff;
+    this.maxCumulativeBackoff = maxCumulativeBackoff;
+  }
+
+  public static RetryHandler withDefaults() {
+    return new RetryHandler(
+        3, // maxRetries
+        Duration.standardSeconds(1), // initialBackoff
+        Duration.standardSeconds(10), // maxBackoff per retry
+        Duration.standardMinutes(1) // maxCumulativeBackoff
+    );
+  }
+
+  public <T> T execute(RetryableRequest<T> request) throws Exception {
+    BackOff backoff =
+        FluentBackoff.DEFAULT
+            .withMaxRetries(maxRetries)
+            .withInitialBackoff(initialBackoff)
+            .withMaxBackoff(maxBackoff)
+            .withMaxCumulativeBackoff(maxCumulativeBackoff)
+            .backoff();
+
+    Sleeper sleeper = Sleeper.DEFAULT;
+    int attempt = 0;
+
+    while (true) {
+      try {
+        return request.call();
+      } catch (Exception e) {
+        long backoffMillis = backoff.nextBackOffMillis();
+
+        if (backoffMillis == BackOff.STOP) {
+          LOG.error("Request failed after {} retry attempts.", attempt);
+          throw new RuntimeException(
+              "Request failed after exhausting retries. Max retries: " + maxRetries, e);
+        }
+
+        attempt++;
+        LOG.warn(
+            "Request attempt {} failed with: {}. Retrying in {} ms",
+            attempt, e.getMessage(), backoffMillis);
+
+        sleeper.sleep(backoffMillis);
+      }
+    }
+  }
+
+  @FunctionalInterface
+  public interface RetryableRequest<T> {
+
+    T call() throws Exception;
+  }
+}
diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
new file mode 100644
index 000000000000..41e4be2dcb33
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
@@ -0,0 +1,598 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package org.apache.beam.sdk.ml.inference.remote; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; + +import org.apache.beam.sdk.coders.SerializableCoder; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + + +@RunWith(JUnit4.class) +public class RemoteInferenceTest { + + @Rule + public final transient TestPipeline pipeline = TestPipeline.create(); + + // Test input class + public static class TestInput implements BaseInput { + private final String value; + + private TestInput(String value) { + this.value = value; + } + + public static TestInput create(String value) { + return new TestInput(value); + } + + public String getModelInput() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestInput)) + return false; + TestInput testInput = (TestInput) o; + return value.equals(testInput.value); + } + + @Override + public int hashCode() { + return value.hashCode(); + } + + @Override + public String toString() { + return "TestInput{value='" + value + "'}"; + } + } + + // Test output class + public static class TestOutput implements BaseResponse { + private final String result; + + private TestOutput(String result) { + this.result = result; + } + + public static TestOutput create(String result) { + return new TestOutput(result); + } + + public String getModelResponse() { + return result; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestOutput)) + return false; + TestOutput that = (TestOutput) o; + return result.equals(that.result); + } + + @Override + public int hashCode() { + return result.hashCode(); + } + + @Override + public String toString() { + return "TestOutput{result='" + result + "'}"; + } + } + + // Test parameters class + public static class TestParameters implements BaseModelParameters { + private final String config; + + private TestParameters(Builder builder) { + this.config = builder.config; + } + + public String getConfig() { + return config; + } + + @Override + public String toString() { + return "TestParameters{config='" + config + "'}"; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof TestParameters)) + return false; + TestParameters that = (TestParameters) o; + return config.equals(that.config); + } + + @Override + public int hashCode() { + return config.hashCode(); + } + + // Builder + public static class Builder { + private String config; + + public Builder setConfig(String config) { + this.config = config; + return this; + } + + public TestParameters build() { + return new TestParameters(this); + } + } + + public static Builder builder() { + return new Builder(); + } + } + + // Mock handler for successful inference + public static class MockSuccessHandler + implements BaseModelHandler { + + private TestParameters parameters; + private boolean clientCreated = false; + + @Override + public void 
createClient(TestParameters parameters) { + this.parameters = parameters; + this.clientCreated = true; + } + + @Override + public Iterable> request(List input) { + if (!clientCreated) { + throw new IllegalStateException("Client not initialized"); + } + return input.stream() + .map(i -> PredictionResult.create( + i, + new TestOutput("processed-" + i.getModelInput()))) + .collect(Collectors.toList()); + } + } + + // Mock handler that returns empty results + public static class MockEmptyResultHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during setup + public static class MockFailingSetupHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + throw new RuntimeException("Setup failed intentionally"); + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + // Mock handler that throws exception during request + public static class MockFailingRequestHandler + implements BaseModelHandler { + + @Override + public void createClient(TestParameters parameters) { + // Setup succeeds + } + + @Override + public Iterable> request(List input) { + throw new RuntimeException("Request failed intentionally"); + } + } + + // Mock handler without default constructor (to test error handling) + public static class MockNoDefaultConstructorHandler + implements BaseModelHandler { + + private final String required; + + public MockNoDefaultConstructorHandler(String required) { + this.required = required; + } + + @Override + public void createClient(TestParameters parameters) { + } + + @Override + public Iterable> request(List input) { + return Collections.emptyList(); + } + } + + private static boolean containsMessage(Throwable e, String message) { + Throwable current = e; + while (current != null) { + if (current.getMessage() != null && current.getMessage().contains(message)) { + return true; + } + current = current.getCause(); + } + return false; + } + + @Test + public void testInvokeWithSingleElement() { + TestInput input = TestInput.create("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline.apply(Create.of(input)); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Verify the output contains expected predictions + PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + + assertEquals("Expected exactly 1 result", 1, resultList.size()); + + PredictionResult result = resultList.get(0); + assertEquals("test-value", result.getInput().getModelInput()); + assertEquals("processed-test-value", result.getOutput().getModelResponse()); + + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithMultipleElements() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2"), + new TestInput("input3")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", 
Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // Count total results across all batches + PAssert.that(results).satisfies(batches -> { + int totalCount = 0; + for (Iterable> batch : batches) { + for (PredictionResult result : batch) { + totalCount++; + assertTrue("Output should start with 'processed-'", + result.getOutput().getModelResponse().startsWith("processed-")); + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + } + } + assertEquals("Expected 3 total results", 3, totalCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testInvokeWithEmptyCollection() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + // assertion for empty PCollection + PAssert.that(results).empty(); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerReturnsEmptyResults() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockEmptyResultHandler.class) + .withParameters(params)); + + // Verify we still get a result, but it's empty + PAssert.thatSingleton(results).satisfies(batch -> { + List> resultList = StreamSupport.stream(batch.spliterator(), false) + .collect(Collectors.toList()); + assertEquals("Expected empty result list", 0, resultList.size()); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testHandlerSetupFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockFailingSetupHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to handler setup failure"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention setup failure or handler instantiation failure", + message != null && (message.contains("Setup failed intentionally") || + message.contains("Failed to instantiate handler"))); + } + } + + @Test + public void testHandlerRequestFailure() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + 
.handler(MockFailingRequestHandler.class) + .withParameters(params)); + + // Verify pipeline fails with expected error + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to request failure"); + } catch (Exception e) { + + assertTrue( + "Expected 'Request failed intentionally' in exception chain", + containsMessage(e, "Request failed intentionally")); + } + } + + @Test + public void testHandlerWithoutDefaultConstructor() { + TestInput input = new TestInput("test-value"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockNoDefaultConstructorHandler.class) + .withParameters(params)); + + // Verify pipeline fails when handler cannot be instantiated + try { + pipeline.run().waitUntilFinish(); + fail("Expected pipeline to fail due to missing default constructor"); + } catch (Exception e) { + String message = e.getMessage(); + assertTrue("Exception should mention handler instantiation failure", + message != null && message.contains("Failed to instantiate handler")); + } + } + + @Test + public void testBuilderPattern() { + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + RemoteInference.Invoke transform = RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params); + + assertNotNull("Transform should not be null", transform); + } + + @Test + public void testPredictionResultMapping() { + TestInput input = new TestInput("mapping-test"); + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.thatSingleton(results).satisfies(batch -> { + for (PredictionResult result : batch) { + // Verify that input is preserved in the result + assertNotNull("Input should not be null", result.getInput()); + assertNotNull("Output should not be null", result.getOutput()); + assertEquals("mapping-test", result.getInput().getModelInput()); + assertTrue("Output should contain input value", + result.getOutput().getModelResponse().contains("mapping-test")); + } + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + // Temporary behaviour until we introduce java BatchElements transform + // to batch elements in RemoteInference + @Test + public void testMultipleInputsProduceSeparateBatches() { + List inputs = Arrays.asList( + new TestInput("input1"), + new TestInput("input2")); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + PCollection inputCollection = pipeline + .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class))); + + PCollection>> results = inputCollection + .apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class) + .withParameters(params)); + + PAssert.that(results).satisfies(batches -> { + int batchCount = 0; + for (Iterable> batch : batches) { + batchCount++; + int elementCount = 0; + elementCount += StreamSupport.stream(batch.spliterator(), false).count(); + // Each batch should contain 
exactly 1 element + assertEquals("Each batch should contain 1 element", 1, elementCount); + } + assertEquals("Expected 2 batches", 2, batchCount); + return null; + }); + + pipeline.run().waitUntilFinish(); + } + + @Test + public void testWithEmptyParameters() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .handler(MockSuccessHandler.class))); + + assertTrue( + "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("withParameters() is required")); + } + + @Test + public void testWithEmptyHandler() { + + pipeline.enableAbandonedNodeEnforcement(false); + + TestParameters params = TestParameters.builder() + .setConfig("test-config") + .build(); + + TestInput input = TestInput.create("test-value"); + PCollection inputCollection = pipeline.apply(Create.of(input)); + + IllegalArgumentException thrown = assertThrows( + IllegalArgumentException.class, + () -> inputCollection.apply("RemoteInference", + RemoteInference.invoke() + .withParameters(params))); + + assertTrue( + "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(), + thrown.getMessage().contains("handler() is required")); + } +} diff --git a/settings.gradle.kts b/settings.gradle.kts index 72c5194ec93d..d3a57ed6b1a4 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -383,3 +383,6 @@ include("sdks:java:extensions:sql:iceberg") findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg" include("examples:java:iceberg") findProject(":examples:java:iceberg")?.name = "iceberg" + +include("sdks:java:ml:inference:remote") +include("sdks:java:ml:inference:openai")
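As a closing note on the retry semantics: RetryHandler can also be exercised on its own, which makes the backoff behavior easy to see in isolation. A minimal sketch (callFlakyEndpoint is a hypothetical remote call):

    RetryHandler retry = RetryHandler.withDefaults();
    try {
      // Retries up to 3 times with growing backoff (1s initial, 10s per-retry cap,
      // 1 minute cumulative cap) before rethrowing.
      String response = retry.execute(() -> callFlakyEndpoint("prompt"));
    } catch (Exception e) {
      // Reached only after all retries are exhausted.
    }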