
Commit 4adaa21

[OpenAI] Added a new overload getChatCompletionsStreamWithResponse (Azure#39258)

Parent: 5f7abaa
16 files changed: +325 -102 lines

sdk/openai/azure-ai-openai/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
```diff
@@ -4,6 +4,9 @@
 
 ### Features Added
 
+- Added a new overload `getChatCompletionsStreamWithResponse` that takes `RequestOptions` to provide the flexibility to
+  modify the HTTP request.
+
 ### Breaking Changes
 
 ### Bugs Fixed
```
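
To make the new API concrete, here is a minimal, hypothetical usage sketch of the synchronous overload announced above (not part of the commit). The client construction and the environment variable names are assumptions, and `{deploymentOrModelName}` is a placeholder; the `RequestOptions` header call mirrors the Javadoc samples added later in this commit.

```java
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.http.rest.Response;
import com.azure.core.util.Configuration;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.IterableStream;

import java.util.ArrayList;
import java.util.List;

public class GetChatCompletionsStreamWithResponseSketch {
    public static void main(String[] args) {
        // Hypothetical environment variable names; adjust to your own configuration.
        OpenAIClient client = new OpenAIClientBuilder()
            .endpoint(Configuration.getGlobalConfiguration().get("AZURE_OPENAI_ENDPOINT"))
            .credential(new AzureKeyCredential(Configuration.getGlobalConfiguration().get("AZURE_OPENAI_KEY")))
            .buildClient();

        List<ChatRequestMessage> chatMessages = new ArrayList<>();
        chatMessages.add(new ChatRequestUserMessage("What's the best way to train a parrot?"));

        // The new overload accepts RequestOptions, so callers can modify the outgoing HTTP request
        // (for example, add a custom header) and still receive the Response wrapper around the stream.
        Response<IterableStream<ChatCompletions>> response = client.getChatCompletionsStreamWithResponse(
            "{deploymentOrModelName}",
            new ChatCompletionsOptions(chatMessages),
            new RequestOptions().setHeader("my-header", "my-header-value"));

        response.getValue().forEach(chatCompletions -> {
            // Guard against stream events that carry no choices (the older snippets skipped the first event instead).
            if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
                return;
            }
            String content = chatCompletions.getChoices().get(0).getDelta().getContent();
            if (content != null) {
                System.out.print(content);
            }
        });
    }
}
```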

sdk/openai/azure-ai-openai/README.md

Lines changed: 14 additions & 18 deletions
````diff
@@ -234,24 +234,20 @@ chatMessages.add(new ChatRequestUserMessage("Can you help me?"));
 chatMessages.add(new ChatRequestAssistantMessage("Of course, me hearty! What can I do for ye?"));
 chatMessages.add(new ChatRequestUserMessage("What's the best way to train a parrot?"));
 
-IterableStream<ChatCompletions> chatCompletionsStream = client.getChatCompletionsStream("{deploymentOrModelName}",
-    new ChatCompletionsOptions(chatMessages));
-
-chatCompletionsStream
-    .stream()
-    // Remove .skip(1) when using Non-Azure OpenAI API
-    // Note: the first chat completions can be ignored when using Azure OpenAI service which is a known service bug.
-    // TODO: remove .skip(1) when service fix the issue.
-    .skip(1)
-    .forEach(chatCompletions -> {
-        ChatResponseMessage delta = chatCompletions.getChoices().get(0).getDelta();
-        if (delta.getRole() != null) {
-            System.out.println("Role = " + delta.getRole());
-        }
-        if (delta.getContent() != null) {
-            System.out.print(delta.getContent());
-        }
-    });
+client.getChatCompletionsStream("{deploymentOrModelName}", new ChatCompletionsOptions(chatMessages))
+    .forEach(chatCompletions -> {
+        if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
+            return;
+        }
+        ChatResponseMessage delta = chatCompletions.getChoices().get(0).getDelta();
+        if (delta.getRole() != null) {
+            System.out.println("Role = " + delta.getRole());
+        }
+        if (delta.getContent() != null) {
+            String content = delta.getContent();
+            System.out.print(content);
+        }
+    });
 ```
 
 To compute tokens in streaming chat completions, see sample [Streaming Chat Completions][sample_get_chat_completions_streaming].
````
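
The token-counting sample referenced above relies on the jtokkit library, which `StreamingChatSample.java` imports later in this commit. As a rough, hypothetical sketch of that idea (not part of the commit), assuming a configured `OpenAIClient`, a populated message list, and that the `cl100k_base` encoding fits the chosen model:

```java
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.core.util.CoreUtils;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;

import java.util.List;

final class StreamingTokenCountSketch {
    // Streams a chat completion, prints the content, and returns an approximate completion-token count.
    static int streamAndCountTokens(OpenAIClient client, String deploymentOrModelName,
            List<ChatRequestMessage> chatMessages) {
        StringBuilder reply = new StringBuilder();
        client.getChatCompletionsStream(deploymentOrModelName, new ChatCompletionsOptions(chatMessages))
            .forEach(chatCompletions -> {
                if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
                    return; // skip stream events that carry no choices
                }
                String content = chatCompletions.getChoices().get(0).getDelta().getContent();
                if (content != null) {
                    reply.append(content);
                    System.out.print(content);
                }
            });
        // cl100k_base is an assumption here; pick the encoding that matches your model.
        Encoding encoding = Encodings.newDefaultEncodingRegistry().getEncoding(EncodingType.CL100K_BASE);
        return encoding.countTokens(reply.toString());
    }
}
```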

sdk/openai/azure-ai-openai/assets.json

Lines changed: 1 addition & 1 deletion
```diff
@@ -2,5 +2,5 @@
   "AssetsRepo": "Azure/azure-sdk-assets",
   "AssetsRepoPrefixPath": "java",
   "TagPrefix": "java/openai/azure-ai-openai",
-  "Tag": "java/openai/azure-ai-openai_915389e465"
+  "Tag": "java/openai/azure-ai-openai_76031b0cb0"
 }
```

sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIAsyncClient.java

Lines changed: 56 additions & 15 deletions
```diff
@@ -46,6 +46,7 @@
 import com.azure.core.util.logging.ClientLogger;
 import com.fasterxml.jackson.core.JsonProcessingException;
 import java.nio.ByteBuffer;
+import java.util.concurrent.atomic.AtomicReference;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
@@ -556,6 +557,57 @@ public Mono<Response<ChatCompletions>> getChatCompletionsWithResponse(String dep
             .map(response -> new SimpleResponse<>(response, response.getValue().toObject(ChatCompletions.class)));
     }
 
+    /**
+     * Gets chat completions for the provided chat messages. Chat completions support a wide variety of tasks and
+     * generate text that continues from or "completes" provided prompt data.
+     *
+     * <p>
+     * <strong>Code Samples</strong>
+     * </p>
+     * <!-- src_embed
+     * com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload -->
+     * <pre>
+     * openAIAsyncClient.getChatCompletionsStreamWithResponse&#40;deploymentOrModelId, new ChatCompletionsOptions&#40;chatMessages&#41;,
+     *     new RequestOptions&#40;&#41;.setHeader&#40;&quot;my-header&quot;, &quot;my-header-value&quot;&#41;&#41;
+     *     .subscribe&#40;
+     *         response -&gt; System.out.print&#40;response.getValue&#40;&#41;.getId&#40;&#41;&#41;,
+     *         error -&gt; System.err.println&#40;&quot;There was an error getting chat completions.&quot; + error&#41;,
+     *         &#40;&#41; -&gt; System.out.println&#40;&quot;Completed called getChatCompletionsStreamWithResponse.&quot;&#41;&#41;;
+     * </pre>
+     * <!-- end com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload
+     * -->
+     *
+     * @param deploymentOrModelName Specifies either the model deployment name (when using Azure OpenAI) or model name
+     * (when using non-Azure OpenAI) to use for this request.
+     * @param chatCompletionsOptions The configuration information for a chat completions request. Completions support a
+     * wide variety of tasks and generate text that continues from or "completes" provided prompt data.
+     * @param requestOptions The options to configure the HTTP request before HTTP client sends it.
+     * @throws IllegalArgumentException thrown if parameters fail the validation.
+     * @throws HttpResponseException thrown if the request is rejected by server.
+     * @throws ClientAuthenticationException thrown if the request is rejected by server on status code 401.
+     * @throws ResourceNotFoundException thrown if the request is rejected by server on status code 404.
+     * @throws ResourceModifiedException thrown if the request is rejected by server on status code 409.
+     * @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
+     * @return chat completions stream for the provided chat messages. Completions support a wide variety of tasks and
+     * generate text that continues from or "completes" provided prompt data.
+     */
+    @ServiceMethod(returns = ReturnType.COLLECTION)
+    public Flux<Response<ChatCompletions>> getChatCompletionsStreamWithResponse(String deploymentOrModelName,
+        ChatCompletionsOptions chatCompletionsOptions, RequestOptions requestOptions) {
+        chatCompletionsOptions.setStream(true);
+        Mono<Response<BinaryData>> chatCompletionsWithResponse = getChatCompletionsWithResponse(deploymentOrModelName,
+            BinaryData.fromObject(chatCompletionsOptions), requestOptions);
+        AtomicReference<Response<BinaryData>> responseCopy = new AtomicReference<>();
+        Flux<ByteBuffer> responseStream = chatCompletionsWithResponse.flatMapMany(response -> {
+            responseCopy.set(response);
+            return response.getValue().toFluxByteBuffer();
+        });
+        OpenAIServerSentEvents<ChatCompletions> chatCompletionsStream
+            = new OpenAIServerSentEvents<>(responseStream, ChatCompletions.class);
+        return chatCompletionsStream.getEvents()
+            .map(chatCompletions -> new SimpleResponse<>(responseCopy.get(), chatCompletions));
+    }
+
     /**
      * Return the embeddings for a given prompt.
      *
@@ -646,21 +698,10 @@ public Mono<Completions> getCompletions(String deploymentOrModelName, String pro
      * <pre>
      * openAIAsyncClient
      *     .getChatCompletionsStream&#40;deploymentOrModelId, new ChatCompletionsOptions&#40;chatMessages&#41;&#41;
-     *     .toStream&#40;&#41;
-     *     &#47;&#47; Remove .skip&#40;1&#41; when using Non-Azure OpenAI API
-     *     &#47;&#47; Note: the first chat completions can be ignored when using Azure OpenAI service which is a known service bug.
-     *     &#47;&#47; TODO: remove .skip&#40;1&#41; after service fixes the issue.
-     *     .skip&#40;1&#41;
-     *     .forEach&#40;chatCompletions -&gt; &#123;
-     *         ChatResponseMessage delta = chatCompletions.getChoices&#40;&#41;.get&#40;0&#41;.getDelta&#40;&#41;;
-     *         if &#40;delta.getRole&#40;&#41; != null&#41; &#123;
-     *             System.out.println&#40;&quot;Role = &quot; + delta.getRole&#40;&#41;&#41;;
-     *         &#125;
-     *         if &#40;delta.getContent&#40;&#41; != null&#41; &#123;
-     *             String content = delta.getContent&#40;&#41;;
-     *             System.out.print&#40;content&#41;;
-     *         &#125;
-     *     &#125;&#41;;
+     *     .subscribe&#40;
+     *         chatCompletions -&gt; System.out.print&#40;chatCompletions.getId&#40;&#41;&#41;,
+     *         error -&gt; System.err.println&#40;&quot;There was an error getting chat completions.&quot; + error&#41;,
+     *         &#40;&#41; -&gt; System.out.println&#40;&quot;Completed called getChatCompletionsStream.&quot;&#41;&#41;;
      * </pre>
      * <!-- end com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptions -->
      *
```
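
The async overload above captures the single underlying HTTP response in an `AtomicReference` and re-wraps it around every streamed `ChatCompletions`, so HTTP-level details remain available alongside each event. A minimal, hypothetical usage sketch (not part of the commit; the client and message list are assumed to be provided by the caller):

```java
import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.util.CoreUtils;

import java.util.List;

final class AsyncStreamWithResponseSketch {
    static void streamWithResponse(OpenAIAsyncClient openAIAsyncClient, String deploymentOrModelName,
            List<ChatRequestMessage> chatMessages) {
        openAIAsyncClient.getChatCompletionsStreamWithResponse(deploymentOrModelName,
                new ChatCompletionsOptions(chatMessages),
                new RequestOptions().setHeader("my-header", "my-header-value"))
            .subscribe(response -> {
                // response.getStatusCode() and response.getHeaders() describe the single streaming HTTP
                // call; each emitted element re-wraps that same underlying response (the AtomicReference
                // in the method body above).
                ChatCompletions chatCompletions = response.getValue();
                if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
                    return;
                }
                String content = chatCompletions.getChoices().get(0).getDelta().getContent();
                if (content != null) {
                    System.out.print(content);
                }
            },
                error -> System.err.println("There was an error getting chat completions." + error),
                () -> System.out.println("Completed."));
    }
}
```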

sdk/openai/azure-ai-openai/src/main/java/com/azure/ai/openai/OpenAIClient.java

Lines changed: 57 additions & 5 deletions
```diff
@@ -689,12 +689,10 @@ public ChatCompletions getChatCompletions(String deploymentOrModelName,
      * <!-- src_embed com.azure.ai.openai.OpenAIClient.getChatCompletionsStream#String-ChatCompletionsOptions -->
      * <pre>
      * openAIClient.getChatCompletionsStream&#40;deploymentOrModelId, new ChatCompletionsOptions&#40;chatMessages&#41;&#41;
-     *     .stream&#40;&#41;
-     *     &#47;&#47; Remove .skip&#40;1&#41; when using Non-Azure OpenAI API
-     *     &#47;&#47; Note: the first chat completions can be ignored when using Azure OpenAI service which is a known service bug.
-     *     &#47;&#47; TODO: remove .skip&#40;1&#41; after service fixes the issue.
-     *     .skip&#40;1&#41;
      *     .forEach&#40;chatCompletions -&gt; &#123;
+     *         if &#40;CoreUtils.isNullOrEmpty&#40;chatCompletions.getChoices&#40;&#41;&#41;&#41; &#123;
+     *             return;
+     *         &#125;
      *         ChatResponseMessage delta = chatCompletions.getChoices&#40;&#41;.get&#40;0&#41;.getDelta&#40;&#41;;
      *         if &#40;delta.getRole&#40;&#41; != null&#41; &#123;
      *             System.out.println&#40;&quot;Role = &quot; + delta.getRole&#40;&#41;&#41;;
@@ -732,6 +730,60 @@ public IterableStream<ChatCompletions> getChatCompletionsStream(String deploymen
         return new IterableStream<>(chatCompletionsStream.getEvents());
     }
 
+    /**
+     * Gets chat completions for the provided chat messages in streaming mode. Chat completions support a wide variety
+     * of tasks and generate text that continues from or "completes" provided prompt data.
+     * <p>
+     * <strong>Code Samples</strong>
+     * </p>
+     * <!-- src_embed com.azure.ai.openai.OpenAIClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload
+     * -->
+     * <pre>
+     * openAIClient.getChatCompletionsStreamWithResponse&#40;deploymentOrModelId, new ChatCompletionsOptions&#40;chatMessages&#41;,
+     *     new RequestOptions&#40;&#41;.setHeader&#40;&quot;my-header&quot;, &quot;my-header-value&quot;&#41;&#41;
+     *     .getValue&#40;&#41;
+     *     .forEach&#40;chatCompletions -&gt; &#123;
+     *         if &#40;CoreUtils.isNullOrEmpty&#40;chatCompletions.getChoices&#40;&#41;&#41;&#41; &#123;
+     *             return;
+     *         &#125;
+     *         ChatResponseMessage delta = chatCompletions.getChoices&#40;&#41;.get&#40;0&#41;.getDelta&#40;&#41;;
+     *         if &#40;delta.getRole&#40;&#41; != null&#41; &#123;
+     *             System.out.println&#40;&quot;Role = &quot; + delta.getRole&#40;&#41;&#41;;
+     *         &#125;
+     *         if &#40;delta.getContent&#40;&#41; != null&#41; &#123;
+     *             String content = delta.getContent&#40;&#41;;
+     *             System.out.print&#40;content&#41;;
+     *         &#125;
+     *     &#125;&#41;;
+     * </pre>
+     * <!-- end com.azure.ai.openai.OpenAIClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload -->
+     *
+     * @param deploymentOrModelName Specifies either the model deployment name (when using Azure OpenAI) or model name
+     * (when using non-Azure OpenAI) to use for this request.
+     * @param chatCompletionsOptions The configuration information for a chat completions request. Completions support a
+     * wide variety of tasks and generate text that continues from or "completes" provided prompt data.
+     * @param requestOptions The options to configure the HTTP request before HTTP client sends it.
+     * @throws IllegalArgumentException thrown if parameters fail the validation.
+     * @throws HttpResponseException thrown if the request is rejected by server.
+     * @throws ClientAuthenticationException thrown if the request is rejected by server on status code 401.
+     * @throws ResourceNotFoundException thrown if the request is rejected by server on status code 404.
+     * @throws ResourceModifiedException thrown if the request is rejected by server on status code 409.
+     * @throws RuntimeException all other wrapped checked exceptions if the request fails to be sent.
+     * @return chat completions stream for the provided chat messages. Completions support a wide variety of tasks and
+     * generate text that continues from or "completes" provided prompt data.
+     */
+    @ServiceMethod(returns = ReturnType.COLLECTION)
+    public Response<IterableStream<ChatCompletions>> getChatCompletionsStreamWithResponse(String deploymentOrModelName,
+        ChatCompletionsOptions chatCompletionsOptions, RequestOptions requestOptions) {
+        chatCompletionsOptions.setStream(true);
+        Response<BinaryData> response = getChatCompletionsWithResponse(deploymentOrModelName,
+            BinaryData.fromObject(chatCompletionsOptions), requestOptions);
+        Flux<ByteBuffer> responseStream = response.getValue().toFluxByteBuffer();
+        OpenAIServerSentEvents<ChatCompletions> chatCompletionsStream
+            = new OpenAIServerSentEvents<>(responseStream, ChatCompletions.class);
+        return new SimpleResponse<>(response, new IterableStream<>(chatCompletionsStream.getEvents()));
+    }
+
     /**
      * Gets transcribed text and associated metadata from provided spoken audio file data. Audio will be transcribed in
      * the written language corresponding to the language it was spoken in.
```
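
Unlike the async variant, the synchronous overload has the `Response<BinaryData>` in hand before the server-sent events are parsed, so no `AtomicReference` is needed and the HTTP status and headers can be inspected before the stream is consumed. A short, hypothetical sketch of that (client and message list assumed to exist; not part of the commit):

```java
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.models.ChatCompletions;
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.core.http.rest.RequestOptions;
import com.azure.core.http.rest.Response;
import com.azure.core.util.CoreUtils;
import com.azure.core.util.IterableStream;

import java.util.List;

final class SyncStreamWithResponseSketch {
    static void streamWithResponse(OpenAIClient client, String deploymentOrModelName,
            List<ChatRequestMessage> chatMessages) {
        Response<IterableStream<ChatCompletions>> response = client.getChatCompletionsStreamWithResponse(
            deploymentOrModelName, new ChatCompletionsOptions(chatMessages),
            new RequestOptions().setHeader("my-header", "my-header-value"));

        // HTTP-level details are available before the stream is consumed.
        System.out.println("Status code: " + response.getStatusCode());
        response.getHeaders().forEach(header -> System.out.println(header.getName() + ": " + header.getValue()));

        // getValue() yields the same IterableStream<ChatCompletions> as the existing overload.
        response.getValue().forEach(chatCompletions -> {
            if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
                return;
            }
            String content = chatCompletions.getChoices().get(0).getDelta().getContent();
            if (content != null) {
                System.out.print(content);
            }
        });
    }
}
```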

sdk/openai/azure-ai-openai/src/samples/java/com/azure/ai/openai/StreamingChatSample.java

Lines changed: 5 additions & 4 deletions
```diff
@@ -12,6 +12,7 @@
 import com.azure.ai.openai.models.ChatResponseMessage;
 import com.azure.core.credential.AzureKeyCredential;
 import com.azure.core.util.Configuration;
+import com.azure.core.util.CoreUtils;
 import com.azure.core.util.IterableStream;
 import com.knuddels.jtokkit.Encodings;
 import com.knuddels.jtokkit.api.Encoding;
@@ -73,11 +74,11 @@ public static void main(String[] args) {
         // }
         chatCompletionsStream
             .stream()
-            // Remove .skip(1) when using Non-Azure OpenAI API
-            // Note: the first chat completions can be ignored when using Azure OpenAI service which is a known service bug.
-            // TODO: remove .skip(1) after service fixes the issue.
-            .skip(1)
             .forEach(chatCompletions -> {
+                if (CoreUtils.isNullOrEmpty(chatCompletions.getChoices())) {
+                    return;
+                }
+
                 ChatResponseMessage delta = chatCompletions.getChoices().get(0).getDelta();
 
                 if (delta.getRole() != null) {
```

sdk/openai/azure-ai-openai/src/samples/java/com/azure/ai/openai/impl/OpenAIAsyncClientJavaDocCodeSnippets.java

Lines changed: 22 additions & 18 deletions
```diff
@@ -11,13 +11,14 @@
 import com.azure.ai.openai.models.ChatRequestMessage;
 import com.azure.ai.openai.models.ChatRequestSystemMessage;
 import com.azure.ai.openai.models.ChatRequestUserMessage;
-import com.azure.ai.openai.models.ChatResponseMessage;
 import com.azure.core.credential.AzureKeyCredential;
+import com.azure.core.http.rest.RequestOptions;
 import com.azure.core.util.Configuration;
 import org.junit.jupiter.api.Test;
 
 import java.util.ArrayList;
 import java.util.List;
+import java.util.concurrent.TimeUnit;
 
 /**
  * Code snippets for {@link OpenAIAsyncClient}
@@ -29,32 +30,35 @@ public class OpenAIAsyncClientJavaDocCodeSnippets {
      * Code snippets for {@link OpenAIClient#getChatCompletionsStream(String, ChatCompletionsOptions)}
      */
     @Test
-    public void getChatCompletionsStream() {
-        String deploymentOrModelId = "gpt-4-1106-preview";
+    public void getChatCompletionsStream() throws InterruptedException {
+        String deploymentOrModelId = Configuration.getGlobalConfiguration().get("OPENAI_DEPLOYMENT_OR_MODEL_ID");
         List<ChatRequestMessage> chatMessages = new ArrayList<>();
         chatMessages.add(new ChatRequestSystemMessage("You are a helpful assistant. You will talk like a pirate."));
         chatMessages.add(new ChatRequestUserMessage("Can you help me?"));
         chatMessages.add(new ChatRequestAssistantMessage("Of course, me hearty! What can I do for ye?"));
         chatMessages.add(new ChatRequestUserMessage("What's the best way to train a parrot?"));
+
         // BEGIN: com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptions
         openAIAsyncClient
             .getChatCompletionsStream(deploymentOrModelId, new ChatCompletionsOptions(chatMessages))
-            .toStream()
-            // Remove .skip(1) when using Non-Azure OpenAI API
-            // Note: the first chat completions can be ignored when using Azure OpenAI service which is a known service bug.
-            // TODO: remove .skip(1) after service fixes the issue.
-            .skip(1)
-            .forEach(chatCompletions -> {
-                ChatResponseMessage delta = chatCompletions.getChoices().get(0).getDelta();
-                if (delta.getRole() != null) {
-                    System.out.println("Role = " + delta.getRole());
-                }
-                if (delta.getContent() != null) {
-                    String content = delta.getContent();
-                    System.out.print(content);
-                }
-            });
+            .subscribe(
+                chatCompletions -> System.out.print(chatCompletions.getId()),
+                error -> System.err.println("There was an error getting chat completions." + error),
+                () -> System.out.println("Completed called getChatCompletionsStream."));
         // END: com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptions
+
+        // With Response Code Snippet
+
+        // BEGIN: com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload
+        openAIAsyncClient.getChatCompletionsStreamWithResponse(deploymentOrModelId, new ChatCompletionsOptions(chatMessages),
+            new RequestOptions().setHeader("my-header", "my-header-value"))
+            .subscribe(
+                response -> System.out.print(response.getValue().getId()),
+                error -> System.err.println("There was an error getting chat completions." + error),
+                () -> System.out.println("Completed called getChatCompletionsStreamWithResponse."));
+        // END: com.azure.ai.openai.OpenAIAsyncClient.getChatCompletionsStream#String-ChatCompletionsOptionsMaxOverload
+
+        TimeUnit.SECONDS.sleep(10);
     }
 
     private OpenAIAsyncClient getOpenAIAsyncClient() {
```
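
The updated test keeps the JVM alive with a fixed `TimeUnit.SECONDS.sleep(10)` so the asynchronous subscriptions have time to finish. A possible alternative, sketched here and not part of the commit, is to release a `CountDownLatch` from the subscriber's error and completion callbacks:

```java
// Hypothetical alternative to TimeUnit.SECONDS.sleep(10): wait until the subscription finishes.
CountDownLatch latch = new CountDownLatch(1); // requires: import java.util.concurrent.CountDownLatch;
openAIAsyncClient.getChatCompletionsStream(deploymentOrModelId, new ChatCompletionsOptions(chatMessages))
    .subscribe(
        chatCompletions -> System.out.print(chatCompletions.getId()),
        error -> {
            System.err.println("There was an error getting chat completions." + error);
            latch.countDown();
        },
        latch::countDown);
// The enclosing test already declares `throws InterruptedException`, so await() can be called directly,
// waiting at most 30 seconds instead of always sleeping for a fixed interval.
latch.await(30, TimeUnit.SECONDS);
```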
