Java Native Remote Inference #36623
Merged
Changes from 13 commits
Commits
19 commits
d511c3d ml module setup (Ganeshsivakumar)
20086a8 openai handler, remoteinference impl (Ganeshsivakumar)
cf5d0a4 prompt, example pipeline (Ganeshsivakumar)
d3f4f18 output format and comments (Ganeshsivakumar)
3935a6d add license header (Ganeshsivakumar)
768472c batching and structured outputs (Ganeshsivakumar)
6bf1e23 unit test and handler IT (Ganeshsivakumar)
af49dbc api docs (Ganeshsivakumar)
b3a2c67 api docs (Ganeshsivakumar)
95e4841 retries (Ganeshsivakumar)
5eeff77 refactor inference module (Ganeshsivakumar)
5358fec handle json string (Ganeshsivakumar)
f8c49ca module comments (Ganeshsivakumar)
9d30f5d comments (Ganeshsivakumar)
c88fb26 update beam/build.gradle (Ganeshsivakumar)
9b6aecb commit (Ganeshsivakumar)
78829d6 fix workflow (Ganeshsivakumar)
f6f5d14 fix openai dependencies (Ganeshsivakumar)
793d3bc update dependencies (Ganeshsivakumar)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * License); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an AS IS BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| plugins { | ||
| id 'org.apache.beam.module' | ||
| id 'java' | ||
| } | ||
|
|
||
| description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI" | ||
|
|
||
| dependencies { | ||
| implementation project(":sdks:java:ml:inference:remote") | ||
| implementation project(path: ":sdks:java:core", configuration: "shadow") | ||
|
|
||
| implementation "com.openai:openai-java:4.3.0" | ||
| implementation library.java.jackson_databind | ||
|
|
||
| testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow") | ||
| testImplementation library.java.slf4j_api | ||
| testRuntimeOnly library.java.slf4j_simple | ||
| testImplementation library.java.junit | ||
| testImplementation project(":sdks:java:testing:test-utils") | ||
| } |
164 changes: 164 additions & 0 deletions
...ence/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelHandler.java
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,164 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * License); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an AS IS BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.beam.sdk.ml.inference.openai; | ||
|
|
||
| import com.fasterxml.jackson.annotation.JsonProperty; | ||
| import com.fasterxml.jackson.annotation.JsonPropertyDescription; | ||
| import com.fasterxml.jackson.core.JsonProcessingException; | ||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
| import com.openai.client.OpenAIClient; | ||
| import com.openai.client.okhttp.OpenAIOkHttpClient; | ||
| import com.openai.core.JsonSchemaLocalValidation; | ||
| import com.openai.models.responses.ResponseCreateParams; | ||
| import com.openai.models.responses.StructuredResponseCreateParams; | ||
| import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler; | ||
| import org.apache.beam.sdk.ml.inference.remote.PredictionResult; | ||
|
|
||
| import java.util.List; | ||
| import java.util.stream.Collectors; | ||
|
|
||
| /** | ||
| * Model handler for OpenAI API inference requests. | ||
| * | ||
| * <p>This handler manages communication with OpenAI's API, including client initialization, | ||
| * request formatting, and response parsing. It uses OpenAI's structured output feature to | ||
| * ensure reliable input-output pairing. | ||
| * | ||
| * <h3>Usage</h3> | ||
| * <pre>{@code | ||
| * OpenAIModelParameters params = OpenAIModelParameters.builder() | ||
| * .apiKey("sk-...") | ||
| * .modelName("gpt-4") | ||
| * .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}") | ||
| * .build(); | ||
| * | ||
| * PCollection<OpenAIModelInput> inputs = ...; | ||
| * PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results = | ||
| * inputs.apply( | ||
| * RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke() | ||
| * .handler(OpenAIModelHandler.class) | ||
| * .withParameters(params) | ||
| * ); | ||
| * }</pre> | ||
| * | ||
| */ | ||
| public class OpenAIModelHandler | ||
| implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> { | ||
|
|
||
| private transient OpenAIClient client; | ||
| private OpenAIModelParameters modelParameters; | ||
| private transient ObjectMapper objectMapper; | ||
|
|
||
| /** | ||
| * Initializes the OpenAI client with the provided parameters. | ||
| * | ||
| * <p>This method is called once during setup. It creates an authenticated | ||
| * OpenAI client using the API key from the parameters. | ||
| * | ||
| * @param parameters the configuration parameters including API key and model name | ||
| */ | ||
| @Override | ||
| public void createClient(OpenAIModelParameters parameters) { | ||
| this.modelParameters = parameters; | ||
| this.client = OpenAIOkHttpClient.builder() | ||
| .apiKey(this.modelParameters.getApiKey()) | ||
| .build(); | ||
| this.objectMapper = new ObjectMapper(); | ||
| } | ||
|
|
||
| /** | ||
| * Performs inference on a batch of inputs using the OpenAI Client. | ||
| * | ||
| * <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured | ||
| * output requirements, and parses the response into {@link PredictionResult} objects | ||
| * that pair each input with its corresponding output. | ||
| * | ||
| * @param input the list of inputs to process | ||
| * @return an iterable of model results and input pairs | ||
| */ | ||
| @Override | ||
| public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) { | ||
|
|
||
| try { | ||
| // Convert input list to JSON string | ||
| String inputBatch = objectMapper | ||
| .writeValueAsString(input.stream().map(OpenAIModelInput::getModelInput).collect(Collectors.toList())); | ||
|
|
||
| // Build structured response parameters | ||
| StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder() | ||
| .model(modelParameters.getModelName()) | ||
| .input(inputBatch) | ||
| .text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO) | ||
| .instructions(modelParameters.getInstructionPrompt()) | ||
| .build(); | ||
|
|
||
| // Get structured output from the model | ||
| StructuredInputOutput structuredOutput = client.responses() | ||
| .create(clientParams) | ||
| .output() | ||
| .stream() | ||
| .flatMap(item -> item.message().stream()) | ||
| .flatMap(message -> message.content().stream()) | ||
| .flatMap(content -> content.outputText().stream()) | ||
| .findFirst() | ||
| .orElse(null); | ||
|
|
||
| if (structuredOutput == null || structuredOutput.responses == null) { | ||
| throw new RuntimeException("Model returned no structured responses"); | ||
| } | ||
|
|
||
| // return PredictionResults | ||
| return structuredOutput.responses.stream() | ||
| .map(response -> PredictionResult.create( | ||
| OpenAIModelInput.create(response.input), | ||
| OpenAIModelResponse.create(response.output))) | ||
| .collect(Collectors.toList()); | ||
|
|
||
| } catch (JsonProcessingException e) { | ||
| throw new RuntimeException("Failed to serialize input batch", e); | ||
| } | ||
| } | ||
|
|
||
| /** | ||
| * Schema class for structured output response. | ||
| * | ||
| * <p>Represents a single input-output pair returned by the OpenAI API. | ||
| */ | ||
| public static class Response { | ||
| @JsonProperty(required = true) | ||
| @JsonPropertyDescription("The input string") | ||
| public String input; | ||
|
|
||
| @JsonProperty(required = true) | ||
| @JsonPropertyDescription("The output string") | ||
| public String output; | ||
| } | ||
|
|
||
| /** | ||
| * Schema class for structured output containing multiple responses. | ||
| * | ||
| * <p>This class defines the expected JSON structure for OpenAI's structured output, | ||
| * ensuring reliable parsing of batched inference results. | ||
| */ | ||
| public static class StructuredInputOutput { | ||
| @JsonProperty(required = true) | ||
| @JsonPropertyDescription("Array of input-output pairs") | ||
| public List<Response> responses; | ||
| } | ||
|
|
||
| } |
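
The `request` method above builds each `PredictionResult` from the input string that the model echoes back, so a dropped, duplicated, or reworded item would not be detected. One possible safeguard is to compare the number of returned items with the batch size and pair inputs and outputs by position. The sketch below is an assumption for illustration only and is not part of this PR; `BatchPairingSketch` and `pairByPosition` are hypothetical names, and plain strings stand in for `OpenAIModelInput` and `OpenAIModelResponse`.

```java
import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Hypothetical helper, not part of the PR: pairs inputs with outputs by position.
final class BatchPairingSketch {
  private BatchPairingSketch() {}

  /**
   * Pairs the i-th input with the i-th output.
   * Fails fast when the two lists differ in length, since positional pairing
   * would otherwise attach outputs to the wrong inputs.
   */
  static List<Map.Entry<String, String>> pairByPosition(
      List<String> inputs, List<String> outputs) {
    if (inputs.size() != outputs.size()) {
      throw new IllegalStateException(
          "Expected " + inputs.size() + " outputs but received " + outputs.size());
    }
    List<Map.Entry<String, String>> pairs = new ArrayList<>(inputs.size());
    for (int i = 0; i < inputs.size(); i++) {
      // Use the original input, not the model's echoed copy, as the key.
      pairs.add(new SimpleImmutableEntry<>(inputs.get(i), outputs.get(i)));
    }
    return pairs;
  }
}
```

Pairing by position assumes the model preserves the order of the batch, which the structured output schema alone does not guarantee; the size check only catches missing or extra items.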
63 changes: 63 additions & 0 deletions
...erence/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelInput.java
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,63 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * License); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an AS IS BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.beam.sdk.ml.inference.openai; | ||
|
|
||
| import org.apache.beam.sdk.ml.inference.remote.BaseInput; | ||
| /** | ||
| * Input for OpenAI model inference requests. | ||
| * | ||
| * <p>This class encapsulates text input to be sent to OpenAI models. | ||
| * | ||
| * <h3>Example Usage</h3> | ||
| * <pre>{@code | ||
| * OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello"); | ||
| * String text = input.getModelInput(); // "Translate to French: Hello" | ||
| * }</pre> | ||
| * | ||
| * @see OpenAIModelHandler | ||
| * @see OpenAIModelResponse | ||
| */ | ||
| public class OpenAIModelInput extends BaseInput { | ||
|
|
||
| private final String input; | ||
|
|
||
| private OpenAIModelInput(String input) { | ||
|
|
||
| this.input = input; | ||
| } | ||
|
|
||
| /** | ||
| * Returns the text input for the model. | ||
| * | ||
| * @return the input text string | ||
| */ | ||
| public String getModelInput() { | ||
| return input; | ||
| } | ||
|
|
||
| /** | ||
| * Creates a new input instance with the specified text. | ||
| * | ||
| * @param input the text to send to the model | ||
| * @return a new {@link OpenAIModelInput} instance | ||
| */ | ||
| public static OpenAIModelInput create(String input) { | ||
| return new OpenAIModelInput(input); | ||
| } | ||
|
|
||
| } |
114 changes: 114 additions & 0 deletions
...e/openai/src/main/java/org/apache/beam/sdk/ml/inference/openai/OpenAIModelParameters.java
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * License); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an AS IS BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
| package org.apache.beam.sdk.ml.inference.openai; | ||
|
|
||
| import org.apache.beam.sdk.ml.inference.remote.BaseModelParameters; | ||
|
|
||
| /** | ||
| * Configuration parameters required for OpenAI model inference. | ||
| * | ||
| * <p>This class encapsulates all configuration needed to initialize and communicate with | ||
| * OpenAI's API, including authentication credentials, model selection, and inference instructions. | ||
| * | ||
| * <h3>Example Usage</h3> | ||
| * <pre>{@code | ||
| * OpenAIModelParameters params = OpenAIModelParameters.builder() | ||
| * .apiKey("sk-...") | ||
| * .modelName("gpt-4") | ||
| * .instructionPrompt("Translate the following text to French:") | ||
| * .build(); | ||
| * }</pre> | ||
| * | ||
| * @see OpenAIModelHandler | ||
| */ | ||
| public class OpenAIModelParameters implements BaseModelParameters { | ||
|
|
||
| private final String apiKey; | ||
| private final String modelName; | ||
| private final String instructionPrompt; | ||
|
|
||
| private OpenAIModelParameters(Builder builder) { | ||
| this.apiKey = builder.apiKey; | ||
| this.modelName = builder.modelName; | ||
| this.instructionPrompt = builder.instructionPrompt; | ||
| } | ||
|
|
||
| public String getApiKey() { | ||
| return apiKey; | ||
| } | ||
|
|
||
| public String getModelName() { | ||
| return modelName; | ||
| } | ||
|
|
||
| public String getInstructionPrompt() { | ||
| return instructionPrompt; | ||
| } | ||
|
|
||
| public static Builder builder() { | ||
| return new Builder(); | ||
| } | ||
|
|
||
|
|
||
| public static class Builder { | ||
| private String apiKey; | ||
| private String modelName; | ||
| private String instructionPrompt; | ||
|
|
||
| private Builder() { | ||
| } | ||
|
|
||
| /** | ||
| * Sets the OpenAI API key for authentication. | ||
| * | ||
| * @param apiKey the API key (required) | ||
| */ | ||
| public Builder apiKey(String apiKey) { | ||
| this.apiKey = apiKey; | ||
| return this; | ||
| } | ||
|
|
||
| /** | ||
| * Sets the name of the OpenAI model to use. | ||
| * | ||
| * @param modelName the model name, e.g., "gpt-4" (required) | ||
| */ | ||
| public Builder modelName(String modelName) { | ||
| this.modelName = modelName; | ||
| return this; | ||
| } | ||
| /** | ||
| * Sets the instruction prompt for the model. | ||
| * This prompt provides context or instructions to the model about how to process | ||
| * the input text. | ||
| * | ||
| * @param prompt the instruction text (required) | ||
| */ | ||
| public Builder instructionPrompt(String prompt) { | ||
| this.instructionPrompt = prompt; | ||
| return this; | ||
| } | ||
|
|
||
| /** | ||
| * Builds the {@link OpenAIModelParameters} instance. | ||
| */ | ||
| public OpenAIModelParameters build() { | ||
| return new OpenAIModelParameters(this); | ||
| } | ||
| } | ||
| } | ||
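
The builder's Javadoc marks `apiKey`, `modelName`, and `instructionPrompt` as required, but `build()` as shown does not check them, so a missing value would only surface later when the client is created or a request is sent. A minimal sketch of how `build()` could enforce this is given below, assuming null checks are the desired behavior. `ValidatedParametersSketch` is a hypothetical stand-in class, not code from this PR.

```java
import java.util.Objects;

// Hypothetical stand-in for OpenAIModelParameters with validation added in build().
final class ValidatedParametersSketch {
  private final String apiKey;
  private final String modelName;
  private final String instructionPrompt;

  private ValidatedParametersSketch(Builder builder) {
    this.apiKey = builder.apiKey;
    this.modelName = builder.modelName;
    this.instructionPrompt = builder.instructionPrompt;
  }

  String getApiKey() { return apiKey; }
  String getModelName() { return modelName; }
  String getInstructionPrompt() { return instructionPrompt; }

  static Builder builder() { return new Builder(); }

  static final class Builder {
    private String apiKey;
    private String modelName;
    private String instructionPrompt;

    private Builder() {}

    Builder apiKey(String apiKey) { this.apiKey = apiKey; return this; }
    Builder modelName(String modelName) { this.modelName = modelName; return this; }
    Builder instructionPrompt(String prompt) { this.instructionPrompt = prompt; return this; }

    ValidatedParametersSketch build() {
      // Reject missing required fields at construction time.
      Objects.requireNonNull(apiKey, "apiKey is required");
      Objects.requireNonNull(modelName, "modelName is required");
      Objects.requireNonNull(instructionPrompt, "instructionPrompt is required");
      return new ValidatedParametersSketch(this);
    }
  }
}
```

With this variant, omitting any of the three setters causes `build()` to throw a `NullPointerException` whose message names the missing field.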
The classes `OpenAIModelParameters`, `OpenAIModelInput`, and `OpenAIModelResponse` do not override `equals()` and `hashCode()`. This can lead to unexpected behavior when these objects are used in collections (like `Set` or as keys in a `Map`) or in tests that rely on object equality. The test classes you've written for `RemoteInferenceTest` correctly implement these methods, and the production classes should as well.

For `OpenAIModelParameters`, you can add the following:

Similar implementations should be added to `OpenAIModelInput` and `OpenAIModelResponse`.
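
The reviewer's suggested snippet is not reproduced in this copy of the page. As a rough sketch of what field-based `equals()` and `hashCode()` overrides for the parameters class might look like, consider the following. `ModelParametersSketch` is a hypothetical stand-in that models only the three fields shown in the diff; it is not the reviewer's actual suggestion.

```java
import java.util.Objects;

// Hypothetical stand-in for OpenAIModelParameters; only the three fields from the diff are modeled.
final class ModelParametersSketch {
  private final String apiKey;
  private final String modelName;
  private final String instructionPrompt;

  ModelParametersSketch(String apiKey, String modelName, String instructionPrompt) {
    this.apiKey = apiKey;
    this.modelName = modelName;
    this.instructionPrompt = instructionPrompt;
  }

  @Override
  public boolean equals(Object o) {
    if (this == o) {
      return true;
    }
    if (!(o instanceof ModelParametersSketch)) {
      return false;
    }
    ModelParametersSketch that = (ModelParametersSketch) o;
    // Objects.equals is null-safe, so unset fields compare cleanly.
    return Objects.equals(apiKey, that.apiKey)
        && Objects.equals(modelName, that.modelName)
        && Objects.equals(instructionPrompt, that.instructionPrompt);
  }

  @Override
  public int hashCode() {
    // Hash the same fields that equals() compares, keeping the two consistent.
    return Objects.hash(apiKey, modelName, instructionPrompt);
  }
}
```

Two instances built from the same three values then compare equal and share a hash code, which is what `Set` membership, `Map` lookups, and equality-based test assertions depend on.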