Skip to content

Commit ece2beb

Browse files
Java Native Remote Inference (#36623)
Squashed commits: ml module setup; OpenAI handler and RemoteInference implementation; prompt and example pipeline; output format and comments; license header; batching and structured outputs; unit test and handler integration test; API docs; retries; inference module refactor; JSON string handling; module comments; beam/build.gradle update; workflow fix; OpenAI dependency fixes and updates.
1 parent b59f3d9 commit ece2beb

File tree

19 files changed

+2436
-0
lines changed

19 files changed

+2436
-0
lines changed

build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ tasks.register("javaPreCommit") {
322322
dependsOn(":sdks:java:io:xml:build")
323323
dependsOn(":sdks:java:javadoc:allJavadoc")
324324
dependsOn(":sdks:java:managed:build")
325+
dependsOn("sdks:java:ml:inference:remote:build")
326+
dependsOn("sdks:java:ml:inference:openai:build")
325327
dependsOn(":sdks:java:testing:expansion-service:build")
326328
dependsOn(":sdks:java:testing:jpms-tests:build")
327329
dependsOn(":sdks:java:testing:junit:build")

sdks/java/extensions/ml/build.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ applyJavaNature(
2626
)
2727

2828
description = 'Apache Beam :: SDKs :: Java :: Extensions :: ML'
29+
ext.summary = """beam-sdks-java-extensions-ml provides Apache Beam Java SDK machine learning integration with Google Cloud AI Video Intelligence service. For machine learning run inference modules, see beam-sdks-java-ml-reference-* artifacts."""
2930

3031
dependencies {
3132
implementation project(path: ":sdks:java:core", configuration: "shadow")
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* License); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an AS IS BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
plugins {
19+
id 'org.apache.beam.module'
20+
id 'java'
21+
}
22+
23+
description = "Apache Beam :: SDKs :: Java :: ML :: Inference :: OpenAI"
24+
25+
dependencies {
26+
implementation project(":sdks:java:ml:inference:remote")
27+
implementation "com.openai:openai-java-core:4.3.0"
28+
implementation "com.openai:openai-java-client-okhttp:4.3.0"
29+
implementation library.java.jackson_databind
30+
implementation library.java.jackson_annotations
31+
implementation library.java.jackson_core
32+
33+
testRuntimeOnly project(path: ":runners:direct-java", configuration: "shadow")
34+
testImplementation project(path: ":sdks:java:core", configuration: "shadow")
35+
testImplementation library.java.slf4j_api
36+
testRuntimeOnly library.java.slf4j_simple
37+
testImplementation library.java.junit
38+
}
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* License); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an AS IS BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.ml.inference.openai;
19+
20+
import com.fasterxml.jackson.annotation.JsonProperty;
21+
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
22+
import com.fasterxml.jackson.core.JsonProcessingException;
23+
import com.fasterxml.jackson.databind.ObjectMapper;
24+
import com.openai.client.OpenAIClient;
25+
import com.openai.client.okhttp.OpenAIOkHttpClient;
26+
import com.openai.core.JsonSchemaLocalValidation;
27+
import com.openai.models.responses.ResponseCreateParams;
28+
import com.openai.models.responses.StructuredResponseCreateParams;
29+
import org.apache.beam.sdk.ml.inference.remote.BaseModelHandler;
30+
import org.apache.beam.sdk.ml.inference.remote.PredictionResult;
31+
32+
import java.util.List;
33+
import java.util.stream.Collectors;
34+
35+
/**
36+
* Model handler for OpenAI API inference requests.
37+
*
38+
* <p>This handler manages communication with OpenAI's API, including client initialization,
39+
* request formatting, and response parsing. It uses OpenAI's structured output feature to
40+
* ensure reliable input-output pairing.
41+
*
42+
* <h3>Usage</h3>
43+
* <pre>{@code
44+
* OpenAIModelParameters params = OpenAIModelParameters.builder()
45+
* .apiKey("sk-...")
46+
* .modelName("gpt-4")
47+
* .instructionPrompt("Classify the following text into one of the categories: {CATEGORIES}")
48+
* .build();
49+
*
50+
* PCollection<OpenAIModelInput> inputs = ...;
51+
* PCollection<Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>>> results =
52+
* inputs.apply(
53+
* RemoteInference.<OpenAIModelInput, OpenAIModelResponse>invoke()
54+
* .handler(OpenAIModelHandler.class)
55+
* .withParameters(params)
56+
* );
57+
* }</pre>
58+
*
59+
*/
60+
public class OpenAIModelHandler
61+
implements BaseModelHandler<OpenAIModelParameters, OpenAIModelInput, OpenAIModelResponse> {
62+
63+
private transient OpenAIClient client;
64+
private OpenAIModelParameters modelParameters;
65+
private transient ObjectMapper objectMapper;
66+
67+
/**
68+
* Initializes the OpenAI client with the provided parameters.
69+
*
70+
* <p>This method is called once during setup. It creates an authenticated
71+
* OpenAI client using the API key from the parameters.
72+
*
73+
* @param parameters the configuration parameters including API key and model name
74+
*/
75+
@Override
76+
public void createClient(OpenAIModelParameters parameters) {
77+
this.modelParameters = parameters;
78+
this.client = OpenAIOkHttpClient.builder()
79+
.apiKey(this.modelParameters.getApiKey())
80+
.build();
81+
this.objectMapper = new ObjectMapper();
82+
}
83+
84+
/**
85+
* Performs inference on a batch of inputs using the OpenAI Client.
86+
*
87+
* <p>This method serializes the input batch to JSON string, sends it to OpenAI with structured
88+
* output requirements, and parses the response into {@link PredictionResult} objects
89+
* that pair each input with its corresponding output.
90+
*
91+
* @param input the list of inputs to process
92+
* @return an iterable of model results and input pairs
93+
*/
94+
@Override
95+
public Iterable<PredictionResult<OpenAIModelInput, OpenAIModelResponse>> request(List<OpenAIModelInput> input) {
96+
97+
try {
98+
// Convert input list to JSON string
99+
String inputBatch =
100+
objectMapper.writeValueAsString(
101+
input.stream()
102+
.map(OpenAIModelInput::getModelInput)
103+
.collect(Collectors.toList()));
104+
// Build structured response parameters
105+
StructuredResponseCreateParams<StructuredInputOutput> clientParams = ResponseCreateParams.builder()
106+
.model(modelParameters.getModelName())
107+
.input(inputBatch)
108+
.text(StructuredInputOutput.class, JsonSchemaLocalValidation.NO)
109+
.instructions(modelParameters.getInstructionPrompt())
110+
.build();
111+
112+
// Get structured output from the model
113+
StructuredInputOutput structuredOutput = client.responses()
114+
.create(clientParams)
115+
.output()
116+
.stream()
117+
.flatMap(item -> item.message().stream())
118+
.flatMap(message -> message.content().stream())
119+
.flatMap(content -> content.outputText().stream())
120+
.findFirst()
121+
.orElse(null);
122+
123+
if (structuredOutput == null || structuredOutput.responses == null) {
124+
throw new RuntimeException("Model returned no structured responses");
125+
}
126+
127+
// return PredictionResults
128+
return structuredOutput.responses.stream()
129+
.map(response -> PredictionResult.create(
130+
OpenAIModelInput.create(response.input),
131+
OpenAIModelResponse.create(response.output)))
132+
.collect(Collectors.toList());
133+
134+
} catch (JsonProcessingException e) {
135+
throw new RuntimeException("Failed to serialize input batch", e);
136+
}
137+
}
138+
139+
/**
140+
* Schema class for structured output response.
141+
*
142+
* <p>Represents a single input-output pair returned by the OpenAI API.
143+
*/
144+
public static class Response {
145+
@JsonProperty(required = true)
146+
@JsonPropertyDescription("The input string")
147+
public String input;
148+
149+
@JsonProperty(required = true)
150+
@JsonPropertyDescription("The output string")
151+
public String output;
152+
}
153+
154+
/**
155+
* Schema class for structured output containing multiple responses.
156+
*
157+
* <p>This class defines the expected JSON structure for OpenAI's structured output,
158+
* ensuring reliable parsing of batched inference results.
159+
*/
160+
public static class StructuredInputOutput {
161+
@JsonProperty(required = true)
162+
@JsonPropertyDescription("Array of input-output pairs")
163+
public List<Response> responses;
164+
}
165+
166+
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* License); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an AS IS BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
package org.apache.beam.sdk.ml.inference.openai;
19+
20+
import org.apache.beam.sdk.ml.inference.remote.BaseInput;
21+
/**
22+
* Input for OpenAI model inference requests.
23+
*
24+
* <p>This class encapsulates text input to be sent to OpenAI models.
25+
*
26+
* <h3>Example Usage</h3>
27+
* <pre>{@code
28+
* OpenAIModelInput input = OpenAIModelInput.create("Translate to French: Hello");
29+
* String text = input.getModelInput(); // "Translate to French: Hello"
30+
* }</pre>
31+
*
32+
* @see OpenAIModelHandler
33+
* @see OpenAIModelResponse
34+
*/
35+
public class OpenAIModelInput implements BaseInput {
36+
37+
private final String input;
38+
39+
private OpenAIModelInput(String input) {
40+
41+
this.input = input;
42+
}
43+
44+
/**
45+
* Returns the text input for the model.
46+
*
47+
* @return the input text string
48+
*/
49+
public String getModelInput() {
50+
return input;
51+
}
52+
53+
/**
54+
* Creates a new input instance with the specified text.
55+
*
56+
* @param input the text to send to the model
57+
* @return a new {@link OpenAIModelInput} instance
58+
*/
59+
public static OpenAIModelInput create(String input) {
60+
return new OpenAIModelInput(input);
61+
}
62+
63+
}

0 commit comments

Comments (0)