Skip to content

Commit c499867

Browse files
ilayaperumalgchedim
authored andcommitted
Revert "Resolve OpenAI ApiKey for every request"
This reverts commit 3a527ee.
1 parent 296ba3b commit c499867

File tree

8 files changed

+20
-752
lines changed

8 files changed

+20
-752
lines changed

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
* @author Thomas Vitale
6363
* @author David Frizelle
6464
* @author Alexandros Pappas
65-
* @author Filip Hrisafov
6665
*/
6766
public class OpenAiApi {
6867

@@ -129,28 +128,22 @@ public OpenAiApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, String> he
129128

130129
// @formatter:off
131130
Consumer<HttpHeaders> finalHeaders = h -> {
131+
if (!(apiKey instanceof NoopApiKey)) {
132+
h.setBearerAuth(apiKey.getValue());
133+
}
134+
132135
h.setContentType(MediaType.APPLICATION_JSON);
133136
h.addAll(headers);
134137
};
135138
this.restClient = restClientBuilder.clone()
136139
.baseUrl(baseUrl)
137140
.defaultHeaders(finalHeaders)
138141
.defaultStatusHandler(responseErrorHandler)
139-
.defaultRequest(requestHeadersSpec -> {
140-
if (!(apiKey instanceof NoopApiKey)) {
141-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
142-
}
143-
})
144142
.build();
145143

146144
this.webClient = webClientBuilder.clone()
147145
.baseUrl(baseUrl)
148146
.defaultHeaders(finalHeaders)
149-
.defaultRequest(requestHeadersSpec -> {
150-
if (!(apiKey instanceof NoopApiKey)) {
151-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
152-
}
153-
})
154147
.build(); // @formatter:on
155148
}
156149

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiAudioApi.java

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
* @author Christian Tzolov
5050
* @author Ilayaperumal Gopinathan
5151
* @author Jonghoon Park
52-
* @author Filip Hrisafov
5352
* @since 0.8.1
5453
*/
5554
public class OpenAiAudioApi {
@@ -72,30 +71,20 @@ public OpenAiAudioApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
7271
ResponseErrorHandler responseErrorHandler) {
7372

7473
Consumer<HttpHeaders> authHeaders = h -> {
74+
if (!(apiKey instanceof NoopApiKey)) {
75+
h.setBearerAuth(apiKey.getValue());
76+
}
7577
h.addAll(headers);
78+
// h.setContentType(MediaType.APPLICATION_JSON);
7679
};
7780

78-
// @formatter:off
7981
this.restClient = restClientBuilder.clone()
8082
.baseUrl(baseUrl)
8183
.defaultHeaders(authHeaders)
8284
.defaultStatusHandler(responseErrorHandler)
83-
.defaultRequest(requestHeadersSpec -> {
84-
if (!(apiKey instanceof NoopApiKey)) {
85-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
86-
}
87-
})
8885
.build();
8986

90-
this.webClient = webClientBuilder.clone()
91-
.baseUrl(baseUrl)
92-
.defaultHeaders(authHeaders)
93-
.defaultRequest(requestHeadersSpec -> {
94-
if (!(apiKey instanceof NoopApiKey)) {
95-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
96-
}
97-
})
98-
.build(); // @formatter:on
87+
this.webClient = webClientBuilder.clone().baseUrl(baseUrl).defaultHeaders(authHeaders).build();
9988
}
10089

10190
public static Builder builder() {

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiImageApi.java

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30-
import org.springframework.http.HttpHeaders;
3130
import org.springframework.http.MediaType;
3231
import org.springframework.http.ResponseEntity;
3332
import org.springframework.util.Assert;
@@ -41,7 +40,6 @@
4140
*
4241
* @see <a href= "https://platform.openai.com/docs/api-reference/images">Images</a>
4342
* @author lambochen
44-
* @author Filip Hrisafov
4543
*/
4644
public class OpenAiImageApi {
4745

@@ -64,18 +62,15 @@ public OpenAiImageApi(String baseUrl, ApiKey apiKey, MultiValueMap<String, Strin
6462
RestClient.Builder restClientBuilder, ResponseErrorHandler responseErrorHandler) {
6563

6664
// @formatter:off
67-
this.restClient = restClientBuilder.clone()
68-
.baseUrl(baseUrl)
65+
this.restClient = restClientBuilder.baseUrl(baseUrl)
6966
.defaultHeaders(h -> {
67+
if (!(apiKey instanceof NoopApiKey)) {
68+
h.setBearerAuth(apiKey.getValue());
69+
}
7070
h.setContentType(MediaType.APPLICATION_JSON);
7171
h.addAll(headers);
7272
})
7373
.defaultStatusHandler(responseErrorHandler)
74-
.defaultRequest(requestHeadersSpec -> {
75-
if (!(apiKey instanceof NoopApiKey)) {
76-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
77-
}
78-
})
7974
.build();
8075
// @formatter:on
8176

models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiModerationApi.java

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import org.springframework.ai.model.SimpleApiKey;
2828
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
2929
import org.springframework.ai.retry.RetryUtils;
30-
import org.springframework.http.HttpHeaders;
3130
import org.springframework.http.MediaType;
3231
import org.springframework.http.ResponseEntity;
3332
import org.springframework.util.Assert;
@@ -41,7 +40,6 @@
4140
*
4241
* @author Ahmed Yousri
4342
* @author Ilayaperumal Gopinathan
44-
* @author Filip Hrisafov
4543
* @see <a href=
4644
* "https://platform.openai.com/docs/api-reference/moderations">https://platform.openai.com/docs/api-reference/moderations</a>
4745
*/
@@ -66,20 +64,13 @@ public OpenAiModerationApi(String baseUrl, ApiKey apiKey, MultiValueMap<String,
6664

6765
this.objectMapper = new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
6866

69-
// @formatter:off
70-
this.restClient = restClientBuilder.clone()
71-
.baseUrl(baseUrl)
72-
.defaultHeaders(h -> {
73-
h.setContentType(MediaType.APPLICATION_JSON);
74-
h.addAll(headers);
75-
})
76-
.defaultStatusHandler(responseErrorHandler)
77-
.defaultRequest(requestHeadersSpec -> {
78-
if (!(apiKey instanceof NoopApiKey)) {
79-
requestHeadersSpec.header(HttpHeaders.AUTHORIZATION, "Bearer " + apiKey.getValue());
80-
}
81-
})
82-
.build(); // @formatter:on
67+
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(h -> {
68+
if (!(apiKey instanceof NoopApiKey)) {
69+
h.setBearerAuth(apiKey.getValue());
70+
}
71+
h.setContentType(MediaType.APPLICATION_JSON);
72+
h.addAll(headers);
73+
}).defaultStatusHandler(responseErrorHandler).build();
8374
}
8475

8576
public ResponseEntity<OpenAiModerationResponse> createModeration(OpenAiModerationRequest openAiModerationRequest) {

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java

Lines changed: 0 additions & 139 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,10 @@
1616

1717
package org.springframework.ai.openai.api;
1818

19-
import java.io.IOException;
20-
import java.util.LinkedList;
21-
import java.util.List;
22-
import java.util.Objects;
23-
import java.util.Queue;
24-
25-
import okhttp3.mockwebserver.MockResponse;
26-
import okhttp3.mockwebserver.MockWebServer;
27-
import okhttp3.mockwebserver.RecordedRequest;
28-
29-
import org.junit.jupiter.api.AfterEach;
30-
import org.junit.jupiter.api.BeforeEach;
31-
import org.junit.jupiter.api.Nested;
3219
import org.junit.jupiter.api.Test;
3320

3421
import org.springframework.ai.model.ApiKey;
3522
import org.springframework.ai.model.SimpleApiKey;
36-
import org.springframework.http.HttpHeaders;
37-
import org.springframework.http.HttpStatus;
38-
import org.springframework.http.MediaType;
39-
import org.springframework.http.ResponseEntity;
4023
import org.springframework.util.LinkedMultiValueMap;
4124
import org.springframework.util.MultiValueMap;
4225
import org.springframework.web.client.ResponseErrorHandler;
@@ -159,126 +142,4 @@ void testInvalidResponseErrorHandler() {
159142
.hasMessageContaining("responseErrorHandler cannot be null");
160143
}
161144

162-
@Nested
163-
class MockRequests {
164-
165-
MockWebServer mockWebServer;
166-
167-
@BeforeEach
168-
void setUp() throws IOException {
169-
mockWebServer = new MockWebServer();
170-
mockWebServer.start();
171-
}
172-
173-
@AfterEach
174-
void tearDown() throws IOException {
175-
mockWebServer.shutdown();
176-
}
177-
178-
@Test
179-
void dynamicApiKeyRestClient() throws InterruptedException {
180-
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
181-
OpenAiApi api = OpenAiApi.builder()
182-
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
183-
.baseUrl(mockWebServer.url("/").toString())
184-
.build();
185-
186-
MockResponse mockResponse = new MockResponse().setResponseCode(200)
187-
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
188-
.setBody("""
189-
{
190-
"id": "chatcmpl-12345",
191-
"object": "chat.completion",
192-
"created": 1677858242,
193-
"model": "gpt-3.5-turbo",
194-
"choices": [
195-
{
196-
"index": 0,
197-
"message": {
198-
"role": "assistant",
199-
"content": "Hello world"
200-
},
201-
"finish_reason": "stop"
202-
}
203-
],
204-
"usage": {
205-
"prompt_tokens": 10,
206-
"completion_tokens": 5,
207-
"total_tokens": 15
208-
}
209-
}
210-
""");
211-
mockWebServer.enqueue(mockResponse);
212-
mockWebServer.enqueue(mockResponse);
213-
214-
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
215-
OpenAiApi.ChatCompletionMessage.Role.USER);
216-
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
217-
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, false);
218-
ResponseEntity<OpenAiApi.ChatCompletion> response = api.chatCompletionEntity(request);
219-
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
220-
RecordedRequest recordedRequest = mockWebServer.takeRequest();
221-
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
222-
223-
response = api.chatCompletionEntity(request);
224-
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.OK);
225-
226-
recordedRequest = mockWebServer.takeRequest();
227-
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
228-
}
229-
230-
@Test
231-
void dynamicApiKeyWebClient() throws InterruptedException {
232-
Queue<ApiKey> apiKeys = new LinkedList<>(List.of(new SimpleApiKey("key1"), new SimpleApiKey("key2")));
233-
OpenAiApi api = OpenAiApi.builder()
234-
.apiKey(() -> Objects.requireNonNull(apiKeys.poll()).getValue())
235-
.baseUrl(mockWebServer.url("/").toString())
236-
.build();
237-
238-
MockResponse mockResponse = new MockResponse().setResponseCode(200)
239-
.addHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
240-
.setBody("""
241-
{
242-
"id": "chatcmpl-12345",
243-
"object": "chat.completion",
244-
"created": 1677858242,
245-
"model": "gpt-3.5-turbo",
246-
"choices": [
247-
{
248-
"index": 0,
249-
"message": {
250-
"role": "assistant",
251-
"content": "Hello world"
252-
},
253-
"finish_reason": "stop"
254-
}
255-
],
256-
"usage": {
257-
"prompt_tokens": 10,
258-
"completion_tokens": 5,
259-
"total_tokens": 15
260-
}
261-
}
262-
""".replace("\n", ""));
263-
mockWebServer.enqueue(mockResponse);
264-
mockWebServer.enqueue(mockResponse);
265-
266-
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage("Hello world",
267-
OpenAiApi.ChatCompletionMessage.Role.USER);
268-
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
269-
List.of(chatCompletionMessage), "gpt-3.5-turbo", 0.8, true);
270-
List<OpenAiApi.ChatCompletionChunk> response = api.chatCompletionStream(request).collectList().block();
271-
assertThat(response).hasSize(1);
272-
RecordedRequest recordedRequest = mockWebServer.takeRequest();
273-
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key1");
274-
275-
response = api.chatCompletionStream(request).collectList().block();
276-
assertThat(response).hasSize(1);
277-
278-
recordedRequest = mockWebServer.takeRequest();
279-
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer key2");
280-
}
281-
282-
}
283-
284145
}

0 commit comments

Comments
 (0)