diff --git a/build.gradle.kts b/build.gradle.kts
index 3a96f5341dbc..387491371131 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -321,6 +321,8 @@ tasks.register("javaPreCommit") {
dependsOn(":sdks:java:io:xml:build")
dependsOn(":sdks:java:javadoc:allJavadoc")
dependsOn(":sdks:java:managed:build")
+ dependsOn(":sdks:java:ml:inference:remote:build")
+ dependsOn(":sdks:java:ml:inference:openai:build")
dependsOn(":sdks:java:testing:expansion-service:build")
dependsOn(":sdks:java:testing:jpms-tests:build")
dependsOn(":sdks:java:testing:junit:build")
diff --git a/sdks/java/extensions/ml/build.gradle b/sdks/java/extensions/ml/build.gradle
index 708a44402df5..cb4a9f577ad6 100644
--- a/sdks/java/extensions/ml/build.gradle
+++ b/sdks/java/extensions/ml/build.gradle
@@ -26,6 +26,7 @@ applyJavaNature(
)
description = 'Apache Beam :: SDKs :: Java :: Extensions :: ML'
+ext.summary = """beam-sdks-java-extensions-ml provides Apache Beam Java SDK machine learning integration with Google Cloud AI Video Intelligence service. For machine learning run inference modules, see beam-sdks-java-ml-inference-* artifacts."""
dependencies {
implementation project(path: ":sdks:java:core", configuration: "shadow")
diff --git a/sdks/java/ml/inference/openai/build.gradle b/sdks/java/ml/inference/openai/build.gradle
new file mode 100644
index 000000000000..96de0cbe52fd
--- /dev/null
+++ b/sdks/java/ml/inference/openai/build.gradle
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+plugins {
+  id 'org.apache.beam.module'
+  id 'java'
+}
+
+description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI"
+
+// Single place to bump the OpenAI Java SDK: the core and OkHttp client
+// artifacts below must always be kept on the same version.
+def openai_version = "4.3.0"
+
+dependencies {
+  implementation project(":sdks:java:ml:inference:remote")
+  implementation "com.openai:openai-java-core:$openai_version"
+  implementation "com.openai:openai-java-client-okhttp:$openai_version"
+  implementation library.java.jackson_databind
+  implementation library.java.jackson_annotations
+  implementation library.java.jackson_core
+
+  // Test-only: the direct runner executes the pipelines built by the tests.
+  testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
+  testImplementation project(path: ":sdks:java:core", configuration: "shadow")
+  testImplementation library.java.slf4j_api
+  testRuntimeOnly library.java.slf4j_simple
+  testImplementation library.java.junit
+}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java
new file mode 100644
index 000000000000..a7ebb1ea02a5
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.annotation.JsonPropertyDescription;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.openai.client.OpenAIClient;
+import com.openai.client.okhttp.OpenAIOkHttpClient;
+import com.openai.core.JsonSchemaLocalValidation;
+import com.openai.models.responses.ResponseCreateParams;
+import com.openai.models.responses.StructuredResponseCreateParams;
+import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler;
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Model handler for OpenAI API inference requests.
+ *
+ * <p>This handler manages communication with OpenAI's API, including client initialization,
+ * request formatting, and response parsing. It uses OpenAI's structured output feature to
+ * ensure reliable input-output pairing.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * PCollection<OpenAIModelInput> inputs = ...;
+ * inputs.apply(
+ *     RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *         .handler(OpenAIModelHandler.class)
+ *         .withParameters(
+ *             OpenAIModelParameters.builder()
+ *                 .apiKey("sk-...")
+ *                 .modelName("gpt-4")
+ *                 .instructionPrompt("Translate the following text to French:")
+ *                 .build()));
+ * }</pre>
+ *
+ * @see OpenAIModelParameters
+ * @see OpenAIModelInput
+ * @see OpenAIModelResponse
+ */
+public class OpenAIModelHandler
+    implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
+
+  // Both the client and the mapper are created in createClient(...) and are marked transient,
+  // so they are never part of this handler's serialized form.
+  private transient OpenAIClient client;
+  private OpenAIModelParameters modelParameters;
+  private transient ObjectMapper objectMapper;
+
+  /**
+   * Initializes the OpenAI client with the provided parameters.
+   *
+   * <p>Stores the parameters for later requests, creates an OpenAI client authenticated with
+   * the API key from the parameters, and creates the JSON mapper used to serialize input
+   * batches.
+   *
+   * @param parameters the configuration parameters including API key and model name
+   */
+  @Override
+  public void createClient(OpenAIModelParameters parameters) {
+    this.modelParameters = parameters;
+    this.client = OpenAIOkHttpClient.builder().apiKey(this.modelParameters.getApiKey()).build();
+    this.objectMapper = new ObjectMapper();
+  }
+
+  /**
+   * Performs inference on a batch of inputs using the OpenAI client.
+   *
+   * <p>This method serializes the input batch to a JSON string, sends it to OpenAI with
+   * structured output requirements, and parses the response into {@link PredictionResult}
+   * objects that pair each input with its corresponding output.
+   *
+   * <p>Note that both halves of every returned pair are taken from the model's structured
+   * response (the model echoes the input next to its output); the original input objects are
+   * not reused.
+   *
+   * @param input the list of inputs to process
+   * @return an iterable of model results and input pairs
+   * @throws RuntimeException if the batch cannot be serialized to JSON or the model returns no
+   *     structured responses
+   */
+  @Override
+  public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(
+      List<OpenAIModelInput> input) {
+
+    try {
+      // Convert input list to JSON string
+      String inputBatch =
+          objectMapper.writeValueAsString(
+              input.stream().map(OpenAIModelInput::getModelInput).collect(Collectors.toList()));
+
+      // Build structured response parameters
+      StructuredResponseCreateParams<StructuredInputOutput> clientParams =
+          ResponseCreateParams.builder()
+              .model(modelParameters.getModelName())
+              .input(inputBatch)
+              .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
+              .instructions(modelParameters.getInstructionPrompt())
+              .build();
+
+      // Get structured output from the model: the first output-text content item found
+      // across all output messages, or null when there is none.
+      StructuredInputOutput structuredOutput =
+          client.responses().create(clientParams).output().stream()
+              .flatMap(item -> item.message().stream())
+              .flatMap(message -> message.content().stream())
+              .flatMap(content -> content.outputText().stream())
+              .findFirst()
+              .orElse(null);
+
+      if (structuredOutput == null || structuredOutput.responses == null) {
+        throw new RuntimeException("Model returned no structured responses");
+      }
+
+      // Wrap every echoed input/output pair in a PredictionResult
+      return structuredOutput.responses.stream()
+          .map(
+              response ->
+                  PredictionResult.create(
+                      OpenAIModelInput.create(response.input),
+                      OpenAIModelResponse.create(response.output)))
+          .collect(Collectors.toList());
+
+    } catch (JsonProcessingException e) {
+      throw new RuntimeException("Failed to serialize input batch", e);
+    }
+  }
+
+  /**
+   * Schema class for structured output response.
+   *
+   * <p>Represents a single input-output pair returned by the OpenAI API.
+   */
+  public static class Response {
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("The input string")
+    public String input;
+
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("The output string")
+    public String output;
+  }
+
+  /**
+   * Schema class for structured output containing multiple responses.
+   *
+   * <p>This class defines the expected JSON structure for OpenAI's structured output,
+   * ensuring reliable parsing of batched inference results.
+   */
+  public static class StructuredInputOutput {
+    @JsonProperty(required = true)
+    @JsonPropertyDescription("Array of input-output pairs")
+    public List<Response> responses;
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
new file mode 100644
index 000000000000..65160a4548a4
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import org.apache.beam.sdk.ml.inference.remote.BaseInput;
+/**
+ * Text input for an OpenAI model inference request.
+ *
+ * <p>An instance wraps the single string that is sent to the model. Instances are obtained
+ * through {@link #create(String)}.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
+ * String text = input.getModelInput(); // "Translate to French: Hello"
+ * }</pre>
+ *
+ * @see OpenAIModelHandler
+ * @see OpenAIModelResponse
+ */
+public class OpenAIModelInput implements BaseInput {
+
+  private final String input;
+
+  private OpenAIModelInput(String text) {
+    this.input = text;
+  }
+
+  /**
+   * Wraps the given text in a new input instance.
+   *
+   * @param input the text to send to the model
+   * @return a new {@link OpenAIModelInput} holding {@code input}
+   */
+  public static OpenAIModelInput create(String input) {
+    return new OpenAIModelInput(input);
+  }
+
+  /**
+   * Gives back the text this instance was created with.
+   *
+   * @return the input text string
+   */
+  public String getModelInput() {
+    return this.input;
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
new file mode 100644
index 000000000000..2b2b04dfa94b
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters;
+
+/**
+ * Configuration parameters required for OpenAI model inference.
+ *
+ * <p>Bundles the three values used when talking to OpenAI's API: the API key used for
+ * authentication, the name of the model to call, and the instruction prompt sent along with
+ * the input.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("sk-...")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Translate the following text to French:")
+ *     .build();
+ * }</pre>
+ *
+ * @see OpenAIModelHandler
+ */
+public class OpenAIModelParameters implements BaseModelParameters {
+
+  private final String apiKey;
+  private final String modelName;
+  private final String instructionPrompt;
+
+  private OpenAIModelParameters(String apiKey, String modelName, String instructionPrompt) {
+    this.apiKey = apiKey;
+    this.modelName = modelName;
+    this.instructionPrompt = instructionPrompt;
+  }
+
+  /** Starts a new, empty {@link Builder}. */
+  public static Builder builder() {
+    return new Builder();
+  }
+
+  /** Returns the API key given to the builder, or {@code null} if none was set. */
+  public String getApiKey() {
+    return this.apiKey;
+  }
+
+  /** Returns the model name given to the builder, or {@code null} if none was set. */
+  public String getModelName() {
+    return this.modelName;
+  }
+
+  /** Returns the instruction prompt given to the builder, or {@code null} if none was set. */
+  public String getInstructionPrompt() {
+    return this.instructionPrompt;
+  }
+
+  /**
+   * Fluent builder for {@link OpenAIModelParameters}.
+   *
+   * <p>The builder performs no validation; any value that is never set stays {@code null}.
+   */
+  public static class Builder {
+    private String apiKey;
+    private String modelName;
+    private String instructionPrompt;
+
+    private Builder() {}
+
+    /**
+     * Sets the OpenAI API key for authentication.
+     *
+     * @param apiKey the API key (required)
+     * @return this builder
+     */
+    public Builder apiKey(String apiKey) {
+      this.apiKey = apiKey;
+      return this;
+    }
+
+    /**
+     * Sets the name of the OpenAI model to use.
+     *
+     * @param modelName the model name, e.g., "gpt-4" (required)
+     * @return this builder
+     */
+    public Builder modelName(String modelName) {
+      this.modelName = modelName;
+      return this;
+    }
+
+    /**
+     * Sets the instruction prompt for the model.
+     *
+     * <p>The prompt tells the model how the input text should be processed.
+     *
+     * @param prompt the instruction text (required)
+     * @return this builder
+     */
+    public Builder instructionPrompt(String prompt) {
+      this.instructionPrompt = prompt;
+      return this;
+    }
+
+    /**
+     * Builds the {@link OpenAIModelParameters} instance from the values set so far.
+     *
+     * @return a new parameters instance
+     */
+    public OpenAIModelParameters build() {
+      return new OpenAIModelParameters(apiKey, modelName, instructionPrompt);
+    }
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
new file mode 100644
index 000000000000..f1c92bc765f8
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelResponse.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import org.apache.beam.sdk.ml.inference.remote.BaseResponse;
+
+/**
+ * Response from OpenAI model inference results.
+ *
+ * <p>An instance wraps the text output returned from an OpenAI model. Instances are obtained
+ * through {@link #create(String)}.
+ *
+ * <h3>Example Usage</h3>
+ *
+ * <pre>{@code
+ * OpenAIModelResponse response = OpenAIModelResponse.create("Bonjour");
+ * String output = response.getModelResponse(); // "Bonjour"
+ * }</pre>
+ *
+ * @see OpenAIModelHandler
+ * @see OpenAIModelInput
+ */
+public class OpenAIModelResponse implements BaseResponse {
+
+  private final String output;
+
+  private OpenAIModelResponse(String text) {
+    this.output = text;
+  }
+
+  /**
+   * Wraps the given output text in a new response instance.
+   *
+   * @param output the text returned by the model
+   * @return a new {@link OpenAIModelResponse} holding {@code output}
+   */
+  public static OpenAIModelResponse create(String output) {
+    return new OpenAIModelResponse(output);
+  }
+
+  /**
+   * Gives back the text this instance was created with.
+   *
+   * @return the output text string
+   */
+  public String getModelResponse() {
+    return this.output;
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
new file mode 100644
index 000000000000..ba03bce86988
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerIT.java
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.transforms.MapElements;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.TypeDescriptor;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.beam.sdk.ml.inference.remote.RemoteInference;
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+
+/**
+ * Integration tests that run {@link OpenAIModelHandler} through {@link RemoteInference}.
+ *
+ * <p>Every test is skipped unless the {@code OPENAI_API_KEY} environment variable holds a
+ * non-blank key (see {@link #setUp()}).
+ */
+public class OpenAIModelHandlerIT {
+  private static final Logger LOG = LoggerFactory.getLogger(OpenAIModelHandlerIT.class);
+
+  private static final String API_KEY_ENV = "OPENAI_API_KEY";
+  private static final String DEFAULT_MODEL = "gpt-4o-mini";
+
+  @Rule public final transient TestPipeline pipeline = TestPipeline.create();
+
+  private String apiKey;
+
+  /** Reads the API key from the environment and skips the test when it is missing or blank. */
+  @Before
+  public void setUp() {
+    // Get API key
+    apiKey = System.getenv(API_KEY_ENV);
+
+    // Skip tests if API key is not provided.
+    // NOTE(review): JUnit 4's assumeNotNull takes Object... values, so the message string below
+    // is presumably just checked for null rather than reported - confirm, and consider
+    // assumeTrue(message, apiKey != null) instead.
+    assumeNotNull(
+        "OpenAI API key not found. Set "
+            + API_KEY_ENV
+            + " environment variable to run integration tests.",
+        apiKey);
+    assumeTrue(
+        "OpenAI API key is empty. Set "
+            + API_KEY_ENV
+            + " environment variable to run integration tests.",
+        !apiKey.trim().isEmpty());
+  }
+
+  /**
+   * Creates the test input collection.
+   *
+   * @param createName name of the {@code Create} step
+   * @param mapName name of the step that wraps each string in an {@link OpenAIModelInput}
+   * @param texts the raw input strings
+   * @return the inputs as a {@link PCollection} on the test pipeline
+   */
+  private PCollection<OpenAIModelInput> createInputs(
+      String createName, String mapName, List<String> texts) {
+    return pipeline
+        .apply(createName, Create.of(texts))
+        .apply(
+            mapName,
+            MapElements.into(TypeDescriptor.of(OpenAIModelInput.class))
+                .via(OpenAIModelInput::create));
+  }
+
+  /**
+   * Applies a {@link RemoteInference} transform backed by {@link OpenAIModelHandler}.
+   *
+   * @param inputs the inputs to run inference on
+   * @param transformName name of the inference step
+   * @param key the OpenAI API key to use
+   * @param modelName the OpenAI model to call
+   * @param instructionPrompt the instruction prompt for the model
+   * @return the batched prediction results
+   */
+  private PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>>
+      runInference(
+          PCollection<OpenAIModelInput> inputs,
+          String transformName,
+          String key,
+          String modelName,
+          String instructionPrompt) {
+    return inputs.apply(
+        transformName,
+        RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+            .handler(OpenAIModelHandler.class)
+            .withParameters(
+                OpenAIModelParameters.builder()
+                    .apiKey(key)
+                    .modelName(modelName)
+                    .instructionPrompt(instructionPrompt)
+                    .build()));
+  }
+
+  @Test
+  public void testSentimentAnalysisWithSingleInput() {
+    String input = "This product is absolutely amazing! I love it!";
+
+    PCollection<OpenAIModelInput> inputs =
+        createInputs("CreateSingleInput", "MapToInput", Arrays.asList(input));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(
+            inputs,
+            "SentimentInference",
+            apiKey,
+            DEFAULT_MODEL,
+            "Analyze the sentiment as 'positive' or 'negative'. Return only one word.");
+
+    // Verify results
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              int count = 0;
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  count++;
+                  assertNotNull("Input should not be null", result.getInput());
+                  assertNotNull("Output should not be null", result.getOutput());
+                  assertNotNull(
+                      "Output text should not be null", result.getOutput().getModelResponse());
+
+                  String sentiment = result.getOutput().getModelResponse().toLowerCase();
+                  assertTrue(
+                      "Sentiment should be positive or negative, got: " + sentiment,
+                      sentiment.contains("positive") || sentiment.contains("negative"));
+                }
+              }
+              assertEquals("Should have exactly 1 result", 1, count);
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testSentimentAnalysisWithMultipleInputs() {
+    List<String> inputs =
+        Arrays.asList(
+            "An excellent B2B SaaS solution that streamlines business processes efficiently.",
+            "The customer support is terrible. I've been waiting for days without any response.",
+            "The application works as expected. Installation was straightforward.",
+            "Really impressed with the innovative features! The AI capabilities are groundbreaking!",
+            "Mediocre product with occasional glitches. Documentation could be better.");
+
+    PCollection<OpenAIModelInput> inputCollection =
+        createInputs("CreateMultipleInputs", "MapToInputs", inputs);
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(
+            inputCollection,
+            "SentimentInference",
+            apiKey,
+            DEFAULT_MODEL,
+            "Analyze sentiment as positive or negative");
+
+    // Verify we get results for all inputs
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              int totalCount = 0;
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  totalCount++;
+                  assertNotNull("Input should not be null", result.getInput());
+                  assertNotNull("Output should not be null", result.getOutput());
+                  assertFalse(
+                      "Output should not be empty",
+                      result.getOutput().getModelResponse().trim().isEmpty());
+                }
+              }
+              assertEquals("Should have results for all 5 inputs", 5, totalCount);
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testTextClassification() {
+    List<String> inputs =
+        Arrays.asList(
+            "How do I reset my password?",
+            "Your product is broken and I want a refund!",
+            "Thank you for the excellent service!");
+
+    PCollection<OpenAIModelInput> inputCollection =
+        createInputs("CreateInputs", "MapToInputs", inputs);
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(
+            inputCollection,
+            "ClassificationInference",
+            apiKey,
+            DEFAULT_MODEL,
+            "Classify each text into one category: 'question', 'complaint', or 'praise'. Return only the category.");
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              List<String> categories = new ArrayList<>();
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  String category = result.getOutput().getModelResponse().toLowerCase();
+                  categories.add(category);
+                }
+              }
+
+              assertEquals("Should have 3 categories", 3, categories.size());
+
+              // Verify expected categories
+              boolean hasQuestion = categories.stream().anyMatch(c -> c.contains("question"));
+              boolean hasComplaint = categories.stream().anyMatch(c -> c.contains("complaint"));
+              boolean hasPraise = categories.stream().anyMatch(c -> c.contains("praise"));
+
+              assertTrue(
+                  "Should have at least one recognized category",
+                  hasQuestion || hasComplaint || hasPraise);
+
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testInputOutputMapping() {
+    List<String> inputs = Arrays.asList("apple", "banana", "cherry");
+
+    PCollection<OpenAIModelInput> inputCollection =
+        createInputs("CreateInputs", "MapToInputs", inputs);
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(
+            inputCollection,
+            "MappingInference",
+            apiKey,
+            DEFAULT_MODEL,
+            "Return the input word in uppercase");
+
+    // Verify input-output pairing is preserved
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  String input = result.getInput().getModelInput();
+                  String output = result.getOutput().getModelResponse().toLowerCase();
+
+                  // Verify the output relates to the input
+                  assertTrue(
+                      "Output should relate to input '" + input + "', got: " + output,
+                      output.contains(input.toLowerCase()));
+                }
+              }
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testWithDifferentModel() {
+    // Test with a different model
+    String input = "Explain quantum computing in one sentence.";
+
+    PCollection<OpenAIModelInput> inputs =
+        createInputs("CreateInput", "MapToInput", Arrays.asList(input));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(inputs, "DifferentModelInference", apiKey, "gpt-5", "Respond concisely");
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  assertNotNull(
+                      "Output should not be null", result.getOutput().getModelResponse());
+                  assertFalse(
+                      "Output should not be empty",
+                      result.getOutput().getModelResponse().trim().isEmpty());
+                }
+              }
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testWithInvalidApiKey() {
+    String input = "Test input";
+
+    PCollection<OpenAIModelInput> inputs =
+        createInputs("CreateInput", "MapToInput", Arrays.asList(input));
+
+    runInference(inputs, "InvalidKeyInference", "invalid-api-key-12345", DEFAULT_MODEL, "Test");
+
+    try {
+      pipeline.run().waitUntilFinish();
+      fail("Expected pipeline failure due to invalid API key");
+    } catch (Exception e) {
+      String msg = e.toString().toLowerCase();
+
+      assertTrue(
+          "Expected retry exhaustion or API key issue. Got: " + msg,
+          msg.contains("exhaust")
+              || msg.contains("max retries")
+              || msg.contains("401")
+              || msg.contains("api key")
+              || msg.contains("incorrect api key"));
+    }
+  }
+
+  /** Tests an instruction prompt that asks for a custom "City: ..., Country: ..." format. */
+  @Test
+  public void testWithJsonOutputFormat() {
+    String input = "Paris is the capital of France";
+
+    PCollection<OpenAIModelInput> inputs =
+        createInputs("CreateInput", "MapToInput", Arrays.asList(input));
+
+    PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+        runInference(
+            inputs,
+            "JsonFormatInference",
+            apiKey,
+            DEFAULT_MODEL,
+            "Extract the city and country. Return as: City: [city], Country: [country]");
+
+    PAssert.that(results)
+        .satisfies(
+            batches -> {
+              for (Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> batch :
+                  batches) {
+                for (PredictionResult<OpenAIModelInput, OpenAIModelResponse> result : batch) {
+                  String output = result.getOutput().getModelResponse();
+                  LOG.info("Structured output: {}", output);
+
+                  // Verify output contains expected information
+                  assertTrue(
+                      "Output should mention Paris: " + output,
+                      output.toLowerCase().contains("paris"));
+                  assertTrue(
+                      "Output should mention France: " + output,
+                      output.toLowerCase().contains("france"));
+                }
+              }
+              return null;
+            });
+
+    pipeline.run().waitUntilFinish();
+  }
+
+  @Test
+  public void testRetryWithInvalidModel() {
+
+    PCollection<OpenAIModelInput> inputs =
+        createInputs("CreateInput", "MapToInput", Arrays.asList("Test input"));
+
+    runInference(inputs, "FailingOpenAIRequest", apiKey, "fake-model", "test retry");
+
+    try {
+      pipeline.run().waitUntilFinish();
+      fail("Pipeline should fail after retry exhaustion.");
+    } catch (Exception e) {
+      // toString() rather than getMessage(): the message may be null, and toString() also
+      // includes the exception class name, matching testWithInvalidApiKey.
+      String message = e.toString().toLowerCase();
+
+      assertTrue(
+          "Expected retry-exhaustion error. Actual: " + message,
+          message.contains("exhaust")
+              || message.contains("retry")
+              || message.contains("max retries")
+              || message.contains("request failed")
+              || message.contains("fake-model"));
+    }
+  }
+}
diff --git a/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
new file mode 100644
index 000000000000..0250c559fe65
--- /dev/null
+++ b/sdks/java/ml/inference/openai/src/test/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandlerTest.java
@@ -0,0 +1,450 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.openai;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.StructuredInputOutput;
+import org.apache.beam.sdk.ml.inference.openai.OpenAIModelHandler.Response;
+import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
+
+
+
+@RunWith(JUnit4.class)
+public class OpenAIModelHandlerTest {
+ private OpenAIModelParameters testParameters;
+
+ @Before
+ public void setUp() {
+ testParameters = OpenAIModelParameters.builder()
+ .apiKey("test-api-key")
+ .modelName("gpt-4")
+ .instructionPrompt("Test instruction")
+ .build();
+ }
+
+ /**
+ * Fake OpenAiModelHandler for testing.
+ */
+ static class FakeOpenAiModelHandler extends OpenAIModelHandler {
+
+ private boolean clientCreated = false;
+ private OpenAIModelParameters storedParameters;
+ private List responsesToReturn;
+ private RuntimeException exceptionToThrow;
+ private boolean shouldReturnNull = false;
+
+ public void setResponsesToReturn(List responses) {
+ this.responsesToReturn = responses;
+ }
+
+ public void setExceptionToThrow(RuntimeException exception) {
+ this.exceptionToThrow = exception;
+ }
+
+ public void setShouldReturnNull(boolean shouldReturnNull) {
+ this.shouldReturnNull = shouldReturnNull;
+ }
+
+ public boolean isClientCreated() {
+ return clientCreated;
+ }
+
+ public OpenAIModelParameters getStoredParameters() {
+ return storedParameters;
+ }
+
+ @Override
+ public void createClient(OpenAIModelParameters parameters) {
+ this.storedParameters = parameters;
+ this.clientCreated = true;
+
+ if (exceptionToThrow != null) {
+ throw exceptionToThrow;
+ }
+ }
+
+ @Override
+ public Iterable> request(
+ List input) {
+ // Mirrors the handler contract: the client has to be created before any request.
+ if (!clientCreated) {
+ throw new IllegalStateException("Client not initialized");
+ }
+ // A configured exception simulates a failing API call.
+ if (exceptionToThrow != null) {
+ throw exceptionToThrow;
+ }
+ // No canned output at all (unset, empty, or forced null) is reported as a missing response.
+ if (shouldReturnNull || responsesToReturn == null || responsesToReturn.isEmpty()) {
+ throw new RuntimeException("Model returned no structured responses");
+ }
+ // Only the first canned structured output is used for a request.
+ StructuredInputOutput structuredOutput = responsesToReturn.get(0);
+
+ if (structuredOutput == null || structuredOutput.responses == null) {
+ throw new RuntimeException("Model returned no structured responses");
+ }
+ // Turn each canned response into an input/output prediction pair, keeping list order.
+ return structuredOutput.responses.stream()
+ .map(response -> PredictionResult.create(
+ OpenAIModelInput.create(response.input),
+ OpenAIModelResponse.create(response.output)))
+ .collect(Collectors.toList());
+ }
+ }
+
+ @Test
+ public void testCreateClient() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ OpenAIModelParameters params = OpenAIModelParameters.builder()
+ .apiKey("test-key")
+ .modelName("gpt-4")
+ .instructionPrompt("test prompt")
+ .build();
+
+ handler.createClient(params);
+
+ assertTrue("Client should be created", handler.isClientCreated());
+ assertNotNull("Parameters should be stored", handler.getStoredParameters());
+ assertEquals("test-key", handler.getStoredParameters().getApiKey());
+ assertEquals("gpt-4", handler.getStoredParameters().getModelName());
+ }
+
+ @Test
+ public void testRequestWithSingleInput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ Response response = new Response();
+ response.input = "test input";
+ response.output = "test output";
+ structuredOutput.responses = Collections.singletonList(response);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ assertNotNull("Results should not be null", results);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals("Should have 1 result", 1, resultList.size());
+
+ PredictionResult result = resultList.get(0);
+ assertEquals("test input", result.getInput().getModelInput());
+ assertEquals("test output", result.getOutput().getModelResponse());
+ }
+
+ @Test
+ public void testRequestWithMultipleInputs() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Arrays.asList(
+ OpenAIModelInput.create("input1"),
+ OpenAIModelInput.create("input2"),
+ OpenAIModelInput.create("input3"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+
+ Response response1 = new Response();
+ response1.input = "input1";
+ response1.output = "output1";
+
+ Response response2 = new Response();
+ response2.input = "input2";
+ response2.output = "output2";
+
+ Response response3 = new Response();
+ response3.input = "input3";
+ response3.output = "output3";
+
+ structuredOutput.responses = Arrays.asList(response1, response2, response3);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals("Should have 3 results", 3, resultList.size());
+
+ for (int i = 0; i < 3; i++) {
+ PredictionResult result = resultList.get(i);
+ assertEquals("input" + (i + 1), result.getInput().getModelInput());
+ assertEquals("output" + (i + 1), result.getOutput().getModelResponse());
+ }
+ }
+
+ @Test
+ public void testRequestWithEmptyInput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.emptyList();
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ structuredOutput.responses = Collections.emptyList();
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+ assertEquals("Should have 0 results", 0, resultList.size());
+ }
+
+ @Test
+ public void testRequestWithNullStructuredOutput() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ handler.setShouldReturnNull(true);
+ handler.createClient(testParameters);
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when structured output is null");
+ } catch (RuntimeException e) {
+ assertTrue("Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
+ }
+ }
+
+ @Test
+ public void testRequestWithNullResponsesList() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ structuredOutput.responses = null;
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when responses list is null");
+ } catch (RuntimeException e) {
+ assertTrue("Exception message should mention no structured responses",
+ e.getMessage().contains("Model returned no structured responses"));
+ }
+ }
+
+ @Test
+ public void testCreateClientFailure() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+ handler.setExceptionToThrow(new RuntimeException("Setup failed"));
+
+ try {
+ handler.createClient(testParameters);
+ fail("Expected RuntimeException during client creation");
+ } catch (RuntimeException e) {
+ assertEquals("Setup failed", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testRequestApiFailure() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ handler.createClient(testParameters);
+ handler.setExceptionToThrow(new RuntimeException("API Error"));
+
+ try {
+ handler.request(inputs);
+ fail("Expected RuntimeException when API fails");
+ } catch (RuntimeException e) {
+ assertEquals("API Error", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testRequestWithoutClientInitialization() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Collections.singletonList(
+ OpenAIModelInput.create("test input"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+ Response response = new Response();
+ response.input = "test input";
+ response.output = "test output";
+ structuredOutput.responses = Collections.singletonList(response);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+
+ // Don't call createClient
+ try {
+ handler.request(inputs);
+ fail("Expected IllegalStateException when client not initialized");
+ } catch (IllegalStateException e) {
+ assertTrue("Exception should mention client not initialized",
+ e.getMessage().contains("Client not initialized"));
+ }
+ }
+
+ @Test
+ public void testInputOutputMapping() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+
+ List inputs = Arrays.asList(
+ OpenAIModelInput.create("alpha"),
+ OpenAIModelInput.create("beta"));
+
+ StructuredInputOutput structuredOutput = new StructuredInputOutput();
+
+ Response response1 = new Response();
+ response1.input = "alpha";
+ response1.output = "ALPHA";
+
+ Response response2 = new Response();
+ response2.input = "beta";
+ response2.output = "BETA";
+
+ structuredOutput.responses = Arrays.asList(response1, response2);
+
+ handler.setResponsesToReturn(Collections.singletonList(structuredOutput));
+ handler.createClient(testParameters);
+
+ Iterable> results = handler.request(inputs);
+
+ List> resultList = iterableToList(results);
+
+ assertEquals(2, resultList.size());
+ assertEquals("alpha", resultList.get(0).getInput().getModelInput());
+ assertEquals("ALPHA", resultList.get(0).getOutput().getModelResponse());
+
+ assertEquals("beta", resultList.get(1).getInput().getModelInput());
+ assertEquals("BETA", resultList.get(1).getOutput().getModelResponse());
+ }
+
+ @Test
+ public void testParametersBuilder() {
+ OpenAIModelParameters params = OpenAIModelParameters.builder()
+ .apiKey("my-api-key")
+ .modelName("gpt-4-turbo")
+ .instructionPrompt("Custom instruction")
+ .build();
+
+ assertEquals("my-api-key", params.getApiKey());
+ assertEquals("gpt-4-turbo", params.getModelName());
+ assertEquals("Custom instruction", params.getInstructionPrompt());
+ }
+
+ @Test
+ public void testOpenAIModelInputCreate() {
+ OpenAIModelInput input = OpenAIModelInput.create("test value");
+
+ assertNotNull("Input should not be null", input);
+ assertEquals("test value", input.getModelInput());
+ }
+
+ @Test
+ public void testOpenAIModelResponseCreate() {
+ OpenAIModelResponse response = OpenAIModelResponse.create("test output");
+
+ assertNotNull("Response should not be null", response);
+ assertEquals("test output", response.getModelResponse());
+ }
+
+ @Test
+ public void testStructuredInputOutputStructure() {
+ Response response = new Response();
+ response.input = "test-input";
+ response.output = "test-output";
+
+ assertEquals("test-input", response.input);
+ assertEquals("test-output", response.output);
+
+ StructuredInputOutput structured = new StructuredInputOutput();
+ structured.responses = Collections.singletonList(response);
+
+ assertNotNull("Responses should not be null", structured.responses);
+ assertEquals("Should have 1 response", 1, structured.responses.size());
+ assertEquals("test-input", structured.responses.get(0).input);
+ }
+
+ @Test
+ public void testMultipleRequestsWithSameHandler() {
+ FakeOpenAiModelHandler handler = new FakeOpenAiModelHandler();
+ handler.createClient(testParameters);
+
+ // First request
+ StructuredInputOutput output1 = new StructuredInputOutput();
+ Response response1 = new Response();
+ response1.input = "first";
+ response1.output = "FIRST";
+ output1.responses = Collections.singletonList(response1);
+ handler.setResponsesToReturn(Collections.singletonList(output1));
+
+ List inputs1 = Collections.singletonList(
+ OpenAIModelInput.create("first"));
+ Iterable> results1 = handler.request(inputs1);
+
+ List> resultList1 = iterableToList(results1);
+ assertEquals("FIRST", resultList1.get(0).getOutput().getModelResponse());
+
+ // Second request with different data
+ StructuredInputOutput output2 = new StructuredInputOutput();
+ Response response2 = new Response();
+ response2.input = "second";
+ response2.output = "SECOND";
+ output2.responses = Collections.singletonList(response2);
+ handler.setResponsesToReturn(Collections.singletonList(output2));
+
+ List inputs2 = Collections.singletonList(
+ OpenAIModelInput.create("second"));
+ Iterable> results2 = handler.request(inputs2);
+
+ List> resultList2 = iterableToList(results2);
+ assertEquals("SECOND", resultList2.get(0).getOutput().getModelResponse());
+ }
+
+ // Copies every element of the given Iterable, in iteration order, into a new ArrayList.
+ private List iterableToList(Iterable iterable) {
+ List list = new java.util.ArrayList<>();
+ iterable.forEach(list::add);
+ return list;
+ }
+}
diff --git a/sdks/java/ml/inference/remote/build.gradle b/sdks/java/ml/inference/remote/build.gradle
new file mode 100644
index 000000000000..7cbea0c594d2
--- /dev/null
+++ b/sdks/java/ml/inference/remote/build.gradle
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+plugins {
+ id 'org.apache.beam.module'
+ id 'java-library'
+}
+
+description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: Remote"
+
+dependencies {
+ // Core Beam SDK
+ implementation project(path: ":sdks:java:core", configuration: "shadow")
+ // AutoValue: annotations on the compile classpath, code generation via the annotation processor
+ compileOnly "com.google.auto.value:auto-value-annotations:1.11.0"
+ annotationProcessor "com.google.auto.value:auto-value:1.11.0"
+ implementation library.java.checker_qual
+ implementation library.java.vendored_guava_32_1_2_jre
+ implementation library.java.slf4j_api
+ implementation library.java.joda_time
+
+ // testing
+ testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
+ testImplementation library.java.junit
+ testRuntimeOnly library.java.hamcrest
+ testRuntimeOnly library.java.slf4j_simple
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java
new file mode 100644
index 000000000000..73bc43684a94
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseInput.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for defining input types used with remote inference transforms.
+ * Implementations hold the data needed for inference (text, images, etc.).
+ */
+public interface BaseInput extends Serializable {
+
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java
new file mode 100644
index 000000000000..314aec34cf9b
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelHandler.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.util.List;
+
+/**
+ * Interface for model-specific handlers that perform remote inference operations.
+ *
+ * <p>Implementations of this interface encapsulate all logic for communicating with a
+ * specific remote inference service. Each handler is responsible for:
+ * <ul>
+ *   <li>Initializing and managing client connections
+ *   <li>Converting Beam inputs to service-specific request formats
+ *   <li>Making inference API calls
+ *   <li>Converting service responses to Beam output types
+ *   <li>Handling errors and retries if applicable
+ * </ul>
+ *
+ * <h3>Lifecycle</h3>
+ *
+ * <p>Handler instances follow this lifecycle:
+ * <ol>
+ *   <li>Instantiation via no-argument constructor
+ *   <li>{@link #createClient} called with parameters during setup
+ *   <li>{@link #request} called for each batch of inputs
+ * </ol>
+ *
+ * <p>Handlers typically contain non-serializable client objects.
+ * Mark client fields as {@code transient} and initialize them in {@link #createClient}.
+ *
+ * <h3>Batching Considerations</h3>
+ *
+ * <p>The {@link #request} method receives a list of inputs. Implementations should:
+ * <ul>
+ *   <li>Batch inputs efficiently if the service supports batch inference
+ *   <li>Return results in the same order as inputs
+ *   <li>Maintain input-output correspondence in {@link PredictionResult}
+ * </ul>
+ */
+public interface BaseModelHandler {
+ /**
+ * Initializes the remote model client with the provided parameters.
+ * @param parameters the parameters used to initialize the client
+ */
+ public void createClient(ParamT parameters);
+
+ /**
+ * Performs inference on a batch of inputs and returns the results.
+ * @param input the batch of inputs to run inference on
+ */
+ public Iterable> request(List input);
+
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java
new file mode 100644
index 000000000000..f285377da977
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseModelParameters.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for defining model-specific parameters used to configure remote inference clients.
+ *
+ * <p>Implementations of this interface encapsulate all configuration needed to initialize
+ * and communicate with a remote model inference service. This typically includes:
+ *
+ * <ul>
+ *   <li>Authentication credentials (API keys, tokens)
+ *   <li>Model identifiers or names
+ *   <li>Endpoint URLs or connection settings
+ *   <li>Inference configuration (temperature, max tokens, timeout values, etc.)
+ * </ul>
+ *
+ * <p>Parameters must be serializable. Consider using
+ * the builder pattern for complex parameter objects.
+ */
+public interface BaseModelParameters extends Serializable {
+
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java
new file mode 100644
index 000000000000..b92a8e2d4228
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/BaseResponse.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.io.Serializable;
+
+/**
+ * Base interface for defining response types returned from remote inference operations.
+ *
+ * <p>Implementations:
+ *
+ * <ul>
+ *   <li>Contain the inference results (predictions, classifications, generated text, etc.)
+ *   <li>Include any relevant metadata
+ * </ul>
+ */
+public interface BaseResponse extends Serializable {
+
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java
new file mode 100644
index 000000000000..bf1ae66127cf
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/PredictionResult.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.io.Serializable;
+
+/**
+ * Pairs an input with its corresponding inference output.
+ *
+ * <p>This class maintains the association between input data and its model's results
+ * for downstream processing.
+ */
+public class PredictionResult implements Serializable {
+
+ private final InputT input;
+ private final OutputT output;
+
+ private PredictionResult(InputT input, OutputT output) {
+ this.input = input;
+ this.output = output;
+
+ }
+
+ /** Returns the input half of this pair. */
+ public InputT getInput() {
+ return input;
+ }
+
+ /** Returns the output half of this pair. */
+ public OutputT getOutput() {
+ return output;
+ }
+
+ /** Creates a {@code PredictionResult} that pairs the given input with the given output. */
+ public static PredictionResult create(InputT input, OutputT output) {
+ return new PredictionResult<>(input, output);
+ }
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
new file mode 100644
index 000000000000..da9217bfd52e
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RemoteInference.java
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import org.apache.beam.sdk.transforms.*;
+import org.checkerframework.checker.nullness.qual.Nullable;
+import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument;
+import org.apache.beam.sdk.values.PCollection;
+import com.google.auto.value.AutoValue;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * A {@link PTransform} for making remote inference calls to external machine learning services.
+ *
+ * <p>{@code RemoteInference} provides a framework for integrating remote ML model
+ * inference into Apache Beam pipelines and handles the communication between pipelines
+ * and external inference APIs.
+ *
+ * <h3>Example: OpenAI Model Inference</h3>
+ *
+ * <pre>{@code
+ * // Create model parameters
+ * OpenAIModelParameters params = OpenAIModelParameters.builder()
+ *     .apiKey("your-api-key")
+ *     .modelName("gpt-4")
+ *     .instructionPrompt("Analyse sentiment as positive or negative")
+ *     .build();
+ *
+ * // Apply remote inference transform
+ * PCollection<OpenAIModelInput> inputs = pipeline.apply(Create.of(
+ *     OpenAIModelInput.create("An excellent B2B SaaS solution that streamlines business processes efficiently."),
+ *     OpenAIModelInput.create("Really impressed with the innovative features!")
+ * ));
+ *
+ * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
+ *     inputs.apply(
+ *         RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
+ *             .handler(OpenAIModelHandler.class)
+ *             .withParameters(params)
+ *     );
+ * }</pre>
+ *
+ */
+@SuppressWarnings({ "rawtypes", "unchecked" })
+public class RemoteInference {
+
+ /** Creates an {@link Invoke} transform with no handler or parameters configured yet. */
+ public static Invoke invoke() {
+ return new AutoValue_RemoteInference_Invoke.Builder().setParameters(null)
+ .build();
+ }
+
+ private RemoteInference() {
+ }
+
+ @AutoValue
+ public abstract static class Invoke
+ extends PTransform, PCollection>>> {
+
+ abstract @Nullable Class extends BaseModelHandler> handler();
+
+ abstract @Nullable BaseModelParameters parameters();
+
+ // AutoValue toBuilder-style method: the fluent setters below start from this builder.
+ abstract Builder builder();
+
+ @AutoValue.Builder
+ abstract static class Builder {
+
+ abstract Builder setHandler(Class extends BaseModelHandler> modelHandler);
+
+ abstract Builder setParameters(BaseModelParameters modelParameters);
+
+
+ abstract Invoke build();
+ }
+
+ /**
+ * Returns an {@code Invoke} rebuilt with the given model handler class.
+ */
+ public Invoke handler(Class extends BaseModelHandler> modelHandler) {
+ return builder().setHandler(modelHandler).build();
+ }
+
+ /**
+ * Returns an {@code Invoke} rebuilt with the given model parameters.
+ */
+ public Invoke withParameters(BaseModelParameters modelParameters) {
+ return builder().setParameters(modelParameters).build();
+ }
+
+ /** Checks that handler and parameters are set, then sends each element as a one-item list. */
+ @Override
+ public PCollection>> expand(PCollection input) {
+ checkArgument(handler() != null, "handler() is required");
+ checkArgument(parameters() != null, "withParameters() is required");
+ return input
+ .apply("WrapInputInList", MapElements.via(new SimpleFunction>() {
+ @Override
+ public List apply(InputT element) {
+ return Collections.singletonList(element);
+ }
+ }))
+ // Pass the list to the inference function
+ .apply("RemoteInference", ParDo.of(new RemoteInferenceFn(this)));
+ }
+
+ /**
+ * A {@link DoFn} that performs the remote inference operation.
+ *
+ * <p>This function manages the lifecycle of the model handler:
+ * <ul>
+ *   <li>Instantiates the handler during {@link Setup}
+ *   <li>Initializes the remote client via {@link BaseModelHandler#createClient}
+ *   <li>Processes elements by calling {@link BaseModelHandler#request}
+ * </ul>
+ */
+ static class RemoteInferenceFn
+ extends DoFn, Iterable>> {
+
+ private final Class extends BaseModelHandler> handlerClass;
+ private final BaseModelParameters parameters;
+ private transient BaseModelHandler modelHandler;
+ private final RetryHandler retryHandler;
+
+ RemoteInferenceFn(Invoke spec) {
+ this.handlerClass = spec.handler();
+ this.parameters = spec.parameters();
+ retryHandler = RetryHandler.withDefaults();
+ }
+
+ /** Instantiates the handler through its no-argument constructor and creates its client. */
+ @Setup
+ public void setupHandler() {
+ try {
+ this.modelHandler = handlerClass.getDeclaredConstructor().newInstance();
+ this.modelHandler.createClient(parameters);
+ } catch (Exception e) {
+ throw new RuntimeException("Failed to instantiate handler: "
+ + handlerClass.getName(), e);
+ }
+ }
+ /** Runs the handler request for one element through the retry handler and emits the result. */
+ @ProcessElement
+ public void processElement(ProcessContext c) throws Exception {
+ Iterable> response = retryHandler
+ .execute(() -> modelHandler.request(c.element()));
+ c.output(response);
+ }
+ }
+
+ }
+}
diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
new file mode 100644
index 000000000000..27041d8cb237
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import org.apache.beam.sdk.util.BackOff;
+import org.apache.beam.sdk.util.FluentBackoff;
+import org.apache.beam.sdk.util.Sleeper;
+import org.joda.time.Duration;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+
+/**
+ * Runs a request and retries it with exponential backoff whenever it throws.
+ *
+ * <p>Serializable so that it can be kept in a field of a serialized {@code DoFn}.
+ */
+public class RetryHandler implements Serializable {
+  private static final Logger LOG = LoggerFactory.getLogger(RetryHandler.class);
+
+  private final int maxRetries;
+  private final Duration initialBackoff;
+  private final Duration maxBackoff;
+  private final Duration maxCumulativeBackoff;
+
+  private RetryHandler(
+      int maxRetries, Duration initialBackoff, Duration maxBackoff, Duration maxCumulativeBackoff) {
+    this.maxRetries = maxRetries;
+    this.initialBackoff = initialBackoff;
+    this.maxBackoff = maxBackoff;
+    this.maxCumulativeBackoff = maxCumulativeBackoff;
+  }
+
+  /** Returns a handler with 3 retries, 1 s initial, 10 s max and 1 min cumulative backoff. */
+  public static RetryHandler withDefaults() {
+    return new RetryHandler(
+        3, // maxRetries
+        Duration.standardSeconds(1), // initialBackoff
+        Duration.standardSeconds(10), // maxBackoff per retry
+        Duration.standardMinutes(1)); // maxCumulativeBackoff
+  }
+
+  /**
+   * Calls {@code request} until it returns, sleeping between failed attempts.
+   *
+   * @param request the call to attempt
+   * @return the value returned by the first successful call
+   * @throws InterruptedException if the call or the backoff sleep is interrupted (not retried)
+   * @throws RuntimeException wrapping the last failure once the backoff policy says stop
+   */
+  public <T> T execute(RetryableRequest<T> request) throws Exception {
+    BackOff backoff =
+        FluentBackoff.DEFAULT
+            .withMaxRetries(maxRetries)
+            .withInitialBackoff(initialBackoff)
+            .withMaxBackoff(maxBackoff)
+            .withMaxCumulativeBackoff(maxCumulativeBackoff)
+            .backoff();
+    int attempt = 0;
+    while (true) {
+      try {
+        return request.call();
+      } catch (InterruptedException e) {
+        // Interruption signals cancellation rather than a transient failure: never retry it.
+        throw e;
+      } catch (Exception e) {
+        long backoffMillis = backoff.nextBackOffMillis();
+        if (backoffMillis == BackOff.STOP) {
+          LOG.error("Request failed after {} retry attempts.", attempt);
+          throw new RuntimeException(
+              "Request failed after exhausting retries. Max retries: " + maxRetries, e);
+        }
+        attempt++;
+        LOG.warn(
+            "Retry request attempt {} failed with: {}. Retrying in {} ms",
+            attempt, e.getMessage(), backoffMillis);
+        Sleeper.DEFAULT.sleep(backoffMillis);
+      }
+    }
+  }
+
+  /** A call that may throw and be attempted again. */
+  @FunctionalInterface
+  public interface RetryableRequest<T> {
+    T call() throws Exception;
+  }
+}
diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
new file mode 100644
index 000000000000..41e4be2dcb33
--- /dev/null
+++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RemoteInferenceTest.java
@@ -0,0 +1,598 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * License); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an AS IS BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.beam.sdk.ml.inference.remote;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import org.apache.beam.sdk.coders.SerializableCoder;
+import org.apache.beam.sdk.testing.PAssert;
+import org.apache.beam.sdk.testing.TestPipeline;
+import org.apache.beam.sdk.transforms.Create;
+import org.apache.beam.sdk.values.PCollection;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+
+
+@RunWith(JUnit4.class)
+public class RemoteInferenceTest {
+
+ @Rule
+ public final transient TestPipeline pipeline = TestPipeline.create();
+
+ // Test input class
+ public static class TestInput implements BaseInput {
+ private final String value;
+
+ private TestInput(String value) {
+ this.value = value;
+ }
+
+ public static TestInput create(String value) {
+ return new TestInput(value);
+ }
+
+ public String getModelInput() {
+ return value;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestInput))
+ return false;
+ TestInput testInput = (TestInput) o;
+ return value.equals(testInput.value);
+ }
+
+ @Override
+ public int hashCode() {
+ return value.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return "TestInput{value='" + value + "'}";
+ }
+ }
+
+ // Test output class
+ public static class TestOutput implements BaseResponse {
+ private final String result;
+
+ private TestOutput(String result) {
+ this.result = result;
+ }
+
+ public static TestOutput create(String result) {
+ return new TestOutput(result);
+ }
+
+ public String getModelResponse() {
+ return result;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestOutput))
+ return false;
+ TestOutput that = (TestOutput) o;
+ return result.equals(that.result);
+ }
+
+ @Override
+ public int hashCode() {
+ return result.hashCode();
+ }
+
+ @Override
+ public String toString() {
+ return "TestOutput{result='" + result + "'}";
+ }
+ }
+
+ // Test parameters class
+ public static class TestParameters implements BaseModelParameters {
+ private final String config;
+
+ private TestParameters(Builder builder) {
+ this.config = builder.config;
+ }
+
+ public String getConfig() {
+ return config;
+ }
+
+ @Override
+ public String toString() {
+ return "TestParameters{config='" + config + "'}";
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o)
+ return true;
+ if (!(o instanceof TestParameters))
+ return false;
+ TestParameters that = (TestParameters) o;
+ return config.equals(that.config);
+ }
+
+ @Override
+ public int hashCode() {
+ return config.hashCode();
+ }
+
+ // Builder
+ public static class Builder {
+ private String config;
+
+ public Builder setConfig(String config) {
+ this.config = config;
+ return this;
+ }
+
+ public TestParameters build() {
+ return new TestParameters(this);
+ }
+ }
+
+ public static Builder builder() {
+ return new Builder();
+ }
+ }
+
+ // Mock handler for successful inference
+ public static class MockSuccessHandler
+ implements BaseModelHandler {
+
+ private TestParameters parameters;
+ private boolean clientCreated = false;
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ this.parameters = parameters;
+ this.clientCreated = true;
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ if (!clientCreated) {
+ throw new IllegalStateException("Client not initialized");
+ }
+ return input.stream()
+ .map(i -> PredictionResult.create(
+ i,
+ new TestOutput("processed-" + i.getModelInput())))
+ .collect(Collectors.toList());
+ }
+ }
+
+ // Mock handler that returns empty results
+ public static class MockEmptyResultHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ // Setup succeeds
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ // Mock handler that throws exception during setup
+ public static class MockFailingSetupHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ throw new RuntimeException("Setup failed intentionally");
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ // Mock handler that throws exception during request
+ public static class MockFailingRequestHandler
+ implements BaseModelHandler {
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ // Setup succeeds
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ throw new RuntimeException("Request failed intentionally");
+ }
+ }
+
+ // Mock handler without default constructor (to test error handling)
+ public static class MockNoDefaultConstructorHandler
+ implements BaseModelHandler {
+
+ private final String required;
+
+ public MockNoDefaultConstructorHandler(String required) {
+ this.required = required;
+ }
+
+ @Override
+ public void createClient(TestParameters parameters) {
+ }
+
+ @Override
+ public Iterable> request(List input) {
+ return Collections.emptyList();
+ }
+ }
+
+ private static boolean containsMessage(Throwable e, String message) {
+ Throwable current = e;
+ while (current != null) {
+ if (current.getMessage() != null && current.getMessage().contains(message)) {
+ return true;
+ }
+ current = current.getCause();
+ }
+ return false;
+ }
+
+ @Test
+ public void testInvokeWithSingleElement() {
+ TestInput input = TestInput.create("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // Verify the output contains expected predictions
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ List> resultList = StreamSupport.stream(batch.spliterator(), false)
+ .collect(Collectors.toList());
+
+ assertEquals("Expected exactly 1 result", 1, resultList.size());
+
+ PredictionResult result = resultList.get(0);
+ assertEquals("test-value", result.getInput().getModelInput());
+ assertEquals("processed-test-value", result.getOutput().getModelResponse());
+
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testInvokeWithMultipleElements() {
+ List inputs = Arrays.asList(
+ new TestInput("input1"),
+ new TestInput("input2"),
+ new TestInput("input3"));
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // Count total results across all batches
+ PAssert.that(results).satisfies(batches -> {
+ int totalCount = 0;
+ for (Iterable> batch : batches) {
+ for (PredictionResult result : batch) {
+ totalCount++;
+ assertTrue("Output should start with 'processed-'",
+ result.getOutput().getModelResponse().startsWith("processed-"));
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ }
+ }
+ assertEquals("Expected 3 total results", 3, totalCount);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testInvokeWithEmptyCollection() {
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateEmptyInput", Create.empty(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ // assertion for empty PCollection
+ PAssert.that(results).empty();
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testHandlerReturnsEmptyResults() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockEmptyResultHandler.class)
+ .withParameters(params));
+
+ // Verify we still get a result, but it's empty
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ List> resultList = StreamSupport.stream(batch.spliterator(), false)
+ .collect(Collectors.toList());
+ assertEquals("Expected empty result list", 0, resultList.size());
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testHandlerSetupFailure() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingSetupHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails with expected error
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to handler setup failure");
+ } catch (Exception e) {
+ String message = e.getMessage();
+ assertTrue("Exception should mention setup failure or handler instantiation failure",
+ message != null && (message.contains("Setup failed intentionally") ||
+ message.contains("Failed to instantiate handler")));
+ }
+ }
+
+ @Test
+ public void testHandlerRequestFailure() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockFailingRequestHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails with expected error
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to request failure");
+ } catch (Exception e) {
+
+ assertTrue(
+ "Expected 'Request failed intentionally' in exception chain",
+ containsMessage(e, "Request failed intentionally"));
+ }
+ }
+
+ @Test
+ public void testHandlerWithoutDefaultConstructor() {
+ TestInput input = new TestInput("test-value");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockNoDefaultConstructorHandler.class)
+ .withParameters(params));
+
+ // Verify pipeline fails when handler cannot be instantiated
+ try {
+ pipeline.run().waitUntilFinish();
+ fail("Expected pipeline to fail due to missing default constructor");
+ } catch (Exception e) {
+ String message = e.getMessage();
+ assertTrue("Exception should mention handler instantiation failure",
+ message != null && message.contains("Failed to instantiate handler"));
+ }
+ }
+
+ @Test
+ public void testBuilderPattern() {
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ RemoteInference.Invoke transform = RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params);
+
+ assertNotNull("Transform should not be null", transform);
+ }
+
+ @Test
+ public void testPredictionResultMapping() {
+ TestInput input = new TestInput("mapping-test");
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInput", Create.of(input).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.thatSingleton(results).satisfies(batch -> {
+ for (PredictionResult result : batch) {
+ // Verify that input is preserved in the result
+ assertNotNull("Input should not be null", result.getInput());
+ assertNotNull("Output should not be null", result.getOutput());
+ assertEquals("mapping-test", result.getInput().getModelInput());
+ assertTrue("Output should contain input value",
+ result.getOutput().getModelResponse().contains("mapping-test"));
+ }
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ // Temporary behaviour until we introduce java BatchElements transform
+ // to batch elements in RemoteInference
+ @Test
+ public void testMultipleInputsProduceSeparateBatches() {
+ List inputs = Arrays.asList(
+ new TestInput("input1"),
+ new TestInput("input2"));
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ PCollection inputCollection = pipeline
+ .apply("CreateInputs", Create.of(inputs).withCoder(SerializableCoder.of(TestInput.class)));
+
+ PCollection>> results = inputCollection
+ .apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)
+ .withParameters(params));
+
+ PAssert.that(results).satisfies(batches -> {
+ int batchCount = 0;
+ for (Iterable> batch : batches) {
+ batchCount++;
+ int elementCount = 0;
+ elementCount += StreamSupport.stream(batch.spliterator(), false).count();
+ // Each batch should contain exactly 1 element
+ assertEquals("Each batch should contain 1 element", 1, elementCount);
+ }
+ assertEquals("Expected 2 batches", 2, batchCount);
+ return null;
+ });
+
+ pipeline.run().waitUntilFinish();
+ }
+
+ @Test
+ public void testWithEmptyParameters() {
+
+ pipeline.enableAbandonedNodeEnforcement(false);
+
+ TestInput input = TestInput.create("test-value");
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ IllegalArgumentException thrown = assertThrows(
+ IllegalArgumentException.class,
+ () -> inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .handler(MockSuccessHandler.class)));
+
+ assertTrue(
+ "Expected message to contain 'withParameters() is required', but got: " + thrown.getMessage(),
+ thrown.getMessage().contains("withParameters() is required"));
+ }
+
+ @Test
+ public void testWithEmptyHandler() {
+
+ pipeline.enableAbandonedNodeEnforcement(false);
+
+ TestParameters params = TestParameters.builder()
+ .setConfig("test-config")
+ .build();
+
+ TestInput input = TestInput.create("test-value");
+ PCollection inputCollection = pipeline.apply(Create.of(input));
+
+ IllegalArgumentException thrown = assertThrows(
+ IllegalArgumentException.class,
+ () -> inputCollection.apply("RemoteInference",
+ RemoteInference.invoke()
+ .withParameters(params)));
+
+ assertTrue(
+ "Expected message to contain 'handler() is required', but got: " + thrown.getMessage(),
+ thrown.getMessage().contains("handler() is required"));
+ }
+}
diff --git a/settings.gradle.kts b/settings.gradle.kts
index 72c5194ec93d..d3a57ed6b1a4 100644
--- a/settings.gradle.kts
+++ b/settings.gradle.kts
@@ -383,3 +383,6 @@ include("sdks:java:extensions:sql:iceberg")
findProject(":sdks:java:extensions:sql:iceberg")?.name = "iceberg"
include("examples:java:iceberg")
findProject(":examples:java:iceberg")?.name = "iceberg"
+
+include("sdks:java:ml:inference:remote")
+include("sdks:java:ml:inference:openai")