Skip to content

Commit 78911cf

Browse files
author
David Grieve
committed
make ChatHistory thread safe
1 parent 61219b5 commit 78911cf

File tree

3 files changed

+26
-19
lines changed

3 files changed

+26
-19
lines changed

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatCompletion.java

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ private static class ChatMessages {
183183

184184
private final List<ChatRequestMessage> newMessages;
185185
private final List<ChatRequestMessage> allMessages;
186-
private final List<OpenAIChatMessageContent> newChatMessageContent;
186+
private final List<OpenAIChatMessageContent<?>> newChatMessageContent;
187187

188188
public ChatMessages(List<ChatRequestMessage> allMessages) {
189189
this.allMessages = Collections.unmodifiableList(allMessages);
@@ -194,7 +194,7 @@ public ChatMessages(List<ChatRequestMessage> allMessages) {
194194
private ChatMessages(
195195
List<ChatRequestMessage> allMessages,
196196
List<ChatRequestMessage> newMessages,
197-
List<OpenAIChatMessageContent> newChatMessageContent) {
197+
List<OpenAIChatMessageContent<?>> newChatMessageContent) {
198198
this.allMessages = Collections.unmodifiableList(allMessages);
199199
this.newMessages = Collections.unmodifiableList(newMessages);
200200
this.newChatMessageContent = Collections.unmodifiableList(newChatMessageContent);
@@ -218,8 +218,8 @@ public ChatMessages add(ChatRequestMessage requestMessage) {
218218
}
219219

220220
@CheckReturnValue
221-
public ChatMessages addChatMessage(List<OpenAIChatMessageContent> chatMessageContent) {
222-
ArrayList<OpenAIChatMessageContent> tmpChatMessageContent = new ArrayList<>(
221+
public ChatMessages addChatMessage(List<OpenAIChatMessageContent<?>> chatMessageContent) {
222+
ArrayList<OpenAIChatMessageContent<?>> tmpChatMessageContent = new ArrayList<>(
223223
newChatMessageContent);
224224
tmpChatMessageContent.addAll(chatMessageContent);
225225

@@ -580,7 +580,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(
580580
arguments);
581581
}
582582

583-
private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
583+
private Mono<List<OpenAIChatMessageContent<?>>> getChatMessageContentsAsync(
584584
ChatCompletions completions) {
585585
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
586586
completions.getId(),
@@ -594,22 +594,27 @@ private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
594594
.filter(Objects::nonNull)
595595
.collect(Collectors.toList());
596596

597-
return Flux.fromIterable(responseMessages)
598-
.flatMap(response -> {
597+
List<OpenAIChatMessageContent<?>> chatMessageContent =
598+
responseMessages
599+
.stream()
600+
.map(response -> {
599601
try {
600-
return Mono.just(new OpenAIChatMessageContent(
602+
return new OpenAIChatMessageContent<>(
601603
AuthorRole.ASSISTANT,
602604
response.getContent(),
603605
this.getModelId(),
604606
null,
605607
null,
606608
completionMetadata,
607-
formOpenAiToolCalls(response)));
609+
formOpenAiToolCalls(response));
608610
} catch (Exception e) {
609-
return Mono.error(e);
611+
return null;
610612
}
611613
})
612-
.collectList();
614+
.filter(Objects::nonNull)
615+
.collect(Collectors.toList());
616+
617+
return Mono.just(chatMessageContent);
613618
}
614619

615620
private List<ChatMessageContent<?>> toOpenAIChatMessageContent(

aiservices/openai/src/main/java/com/microsoft/semantickernel/aiservices/openai/chatcompletion/OpenAIChatMessageContent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ public OpenAIChatMessageContent(
3636
@Nullable String modelId,
3737
@Nullable T innerContent,
3838
@Nullable Charset encoding,
39-
@Nullable FunctionResultMetadata metadata,
39+
@Nullable FunctionResultMetadata<?> metadata,
4040
@Nullable List<OpenAIFunctionToolCall> toolCall) {
4141
super(authorRole, content, modelId, innerContent, encoding, metadata);
4242

semantickernel-api/src/main/java/com/microsoft/semantickernel/services/chatcompletion/ChatHistory.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
import com.microsoft.semantickernel.services.chatcompletion.message.ChatMessageTextContent;
66
import java.nio.charset.Charset;
77
import java.util.ArrayList;
8+
import java.util.Collection;
89
import java.util.Collections;
910
import java.util.Iterator;
1011
import java.util.List;
1112
import java.util.Optional;
1213
import java.util.Spliterator;
14+
import java.util.concurrent.ConcurrentLinkedQueue;
1315
import java.util.function.Consumer;
1416
import javax.annotation.Nullable;
1517

@@ -18,7 +20,7 @@
1820
*/
1921
public class ChatHistory implements Iterable<ChatMessageContent<?>> {
2022

21-
private final List<ChatMessageContent<?>> chatMessageContents;
23+
private final Collection<ChatMessageContent<?>> chatMessageContents;
2224

2325
/**
2426
* The default constructor
@@ -33,7 +35,7 @@ public ChatHistory() {
3335
* @param instructions The instructions to add to the chat history
3436
*/
3537
public ChatHistory(@Nullable String instructions) {
36-
this.chatMessageContents = new ArrayList<>();
38+
this.chatMessageContents = new ConcurrentLinkedQueue<>();
3739
if (instructions != null) {
3840
this.chatMessageContents.add(
3941
ChatMessageTextContent.systemMessage(instructions));
@@ -45,8 +47,8 @@ public ChatHistory(@Nullable String instructions) {
4547
*
4648
* @param chatMessageContents The chat message contents to add to the chat history
4749
*/
48-
public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
49-
this.chatMessageContents = new ArrayList(chatMessageContents);
50+
public ChatHistory(List<? extends ChatMessageContent<?>> chatMessageContents) {
51+
this.chatMessageContents = new ConcurrentLinkedQueue<>(chatMessageContents);
5052
}
5153

5254
/**
@@ -55,7 +57,7 @@ public ChatHistory(List<? extends ChatMessageContent> chatMessageContents) {
5557
* @return List of messages in the chat
5658
*/
5759
public List<ChatMessageContent<?>> getMessages() {
58-
return Collections.unmodifiableList(chatMessageContents);
60+
return Collections.unmodifiableList(new ArrayList<>(chatMessageContents));
5961
}
6062

6163
/**
@@ -67,7 +69,7 @@ public Optional<ChatMessageContent<?>> getLastMessage() {
6769
if (chatMessageContents.isEmpty()) {
6870
return Optional.empty();
6971
}
70-
return Optional.of(chatMessageContents.get(chatMessageContents.size() - 1));
72+
return Optional.of(((ConcurrentLinkedQueue<ChatMessageContent<?>>)chatMessageContents).peek());
7173
}
7274

7375
/**
@@ -114,7 +116,7 @@ public Spliterator<ChatMessageContent<?>> spliterator() {
114116
* @param metadata The metadata of the message
115117
*/
116118
public void addMessage(AuthorRole authorRole, String content, Charset encoding,
117-
FunctionResultMetadata metadata) {
119+
FunctionResultMetadata<?> metadata) {
118120
chatMessageContents.add(
119121
ChatMessageTextContent.builder()
120122
.withAuthorRole(authorRole)

0 commit comments

Comments
 (0)