Commit 5b5b542

OCI GenAi - Cohere tool history fix #391 (#431)

1 parent 2018a08 commit 5b5b542
File tree

6 files changed: +244 -15 lines changed

models/langchain4j-community-oci-genai/src/main/java/dev/langchain4j/community/model/oracle/oci/genai/BaseCohereChatModel.java

Lines changed: 17 additions & 13 deletions
@@ -11,7 +11,6 @@
 import com.oracle.bmc.generativeaiinference.model.CohereResponseFormat;
 import com.oracle.bmc.generativeaiinference.model.CohereResponseJsonFormat;
 import com.oracle.bmc.generativeaiinference.model.CohereResponseTextFormat;
-import com.oracle.bmc.generativeaiinference.model.CohereSystemMessage;
 import com.oracle.bmc.generativeaiinference.model.CohereTool;
 import com.oracle.bmc.generativeaiinference.model.CohereToolCall;
 import com.oracle.bmc.generativeaiinference.model.CohereToolResult;
@@ -172,9 +171,11 @@ private CohereChatRequest.Builder map(ChatRequest chatRequest) {
                 }
                 case SYSTEM -> {
                     var systemMessage = (dev.langchain4j.data.message.SystemMessage) chatMessage;
-                    chatHistory.add(CohereSystemMessage.builder()
-                            .message(systemMessage.text())
-                            .build());
+                    // https://docs.cohere.com/v1/reference/chat
+                    // The chat_history parameter should not be used for SYSTEM messages in most cases.
+                    // Instead, to add a SYSTEM role message at the beginning of a conversation,
+                    // the preamble parameter should be used.
+                    builder.preambleOverride(systemMessage.text());
                 }
                 case AI -> {
                     var aiMessage = (dev.langchain4j.data.message.AiMessage) chatMessage;
@@ -184,19 +185,15 @@ private CohereChatRequest.Builder map(ChatRequest chatRequest) {
                     if (aiMessage.hasToolExecutionRequests()) {
                         var toolCalls = new ArrayList<CohereToolCall>();
                         for (ToolExecutionRequest toolExecReq : aiMessage.toolExecutionRequests()) {
-                            toolCalls.add(CohereToolCall.builder()
-                                    .name(toolExecReq.name())
-                                    .parameters(fromJson(toolExecReq.arguments(), Map.class))
-                                    .build());
+                            toolCalls.add(map(toolExecReq));
                         }
                         assistantMessageBuilder.toolCalls(toolCalls);
                     }
                     // https://docs.cohere.com/v1/reference/chat
-                    // Chat calls with tool_results should not be included in the Chat history
-                    // to avoid duplication of the message text.
-                    // chatHistory.add(assistantMessageBuilder
-                    //         .message(aiMessage.text())
-                    //         .build());
+                    // "Chat calls with tool_results should not be included in the Chat history
+                    // to avoid duplication of the message text."
+                    // BUT - sequential tool calls wouldn't work!
+                    chatHistory.add(assistantMessageBuilder.build());
                 }
                 default -> throw new UnsupportedOperationException("Unsupported message type: " + chatMessage.type());
             }
@@ -228,6 +225,13 @@ private CohereChatRequest.Builder map(ChatRequest chatRequest) {
         return builder;
     }
 
+    CohereToolCall map(ToolExecutionRequest toolExecReq) {
+        return CohereToolCall.builder()
+                .name(toolExecReq.name())
+                .parameters(fromJson(toolExecReq.arguments(), Map.class))
+                .build();
+    }
+
     CohereResponseFormat map(ResponseFormat responseFormat) {
         if (responseFormat == null) {
             return null;

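To make the intent of this mapping change concrete, here is a minimal sketch (not part of the commit) of the request shape it targets: the SYSTEM message now becomes the Cohere preamble via preambleOverride instead of a chat_history entry, while the assistant turn that carried the first tool call is kept in chat_history so a follow-up, sequential tool call still has context. The sketch assumes langchain4j's standard message and ChatRequest API as used elsewhere in this module; the class name, tool-call id, and arguments below are illustrative assumptions, with the tool name borrowed from the new test.

// Minimal sketch, assuming langchain4j's core message/request API; not part of this commit.
import dev.langchain4j.agent.tool.ToolExecutionRequest;
import dev.langchain4j.data.message.AiMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.ToolExecutionResultMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.chat.request.ChatRequest;
import java.util.List;

class SequentialToolTurnSketch {

    static ChatRequest secondTurn() {
        // First tool round, already executed by the application (id and arguments are hypothetical).
        ToolExecutionRequest firstCall = ToolExecutionRequest.builder()
                .id("call-1")
                .name("storeEmbedding")
                .arguments("{\"input\":\"It's bucketing down\"}")
                .build();

        return ChatRequest.builder()
                .messages(List.of(
                        // Mapped to preambleOverride by BaseCohereChatModel, not added to chat_history.
                        SystemMessage.from("Always use the provided tools."),
                        UserMessage.from("Store embeddings of both texts, then compare them."),
                        // With this fix, this assistant turn (and its tool call) stays in chat_history,
                        // so the model has the context needed to issue the next, sequential tool call.
                        AiMessage.from(firstCall),
                        ToolExecutionResultMessage.from(firstCall, "0")))
                .build();
    }
}
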
models/langchain4j-community-oci-genai/src/test/java/dev/langchain4j/community/model/oracle/oci/genai/CohereChatModelIT.java

Lines changed: 15 additions & 1 deletion
@@ -43,7 +43,7 @@ protected List<ChatModel> models() {
                 .authProvider(authProvider)
                 .region(Region.fromRegionCodeOrId(OCI_GENAI_MODEL_REGION))
                 .seed(TestEnvProps.SEED)
-                .maxTokens(600)
+                .maxTokens(1000)
                 .temperature(0.7)
                 .topP(1.0)
                 .build());
@@ -57,6 +57,7 @@ protected ChatModel createModelWith(final ChatRequestParameters parameters) {
                 .authProvider(authProvider)
                 .region(Region.fromRegionCodeOrId(OCI_GENAI_MODEL_REGION))
                 .seed(TestEnvProps.SEED)
+                .maxTokens(1000)
                 .defaultRequestParameters(parameters)
                 .build();
     }
@@ -105,6 +106,11 @@ protected boolean supportsSingleImageInputAsPublicURL() {
         return false;
     }
 
+    @Override
+    protected boolean supportsToolChoiceRequiredWithMultipleTools() {
+        return false;
+    }
+
     @Override
     protected boolean supportsToolChoiceRequiredWithSingleTool() {
         return false;
@@ -114,6 +120,14 @@ protected boolean assertResponseId() {
         return false;
     }
 
+    protected boolean assertTokenUsage() {
+        return false;
+    }
+
+    protected boolean supportsJsonResponseFormatWithRawSchema() {
+        return false;
+    }
+
     @Override
     @Disabled("Not supported by testing model")
     protected void should_execute_multiple_tools_in_parallel_then_answer(ChatModel model) {

models/langchain4j-community-oci-genai/src/test/java/dev/langchain4j/community/model/oracle/oci/genai/CohereStreamingChatModelIT.java

Lines changed: 22 additions & 0 deletions
@@ -9,7 +9,9 @@
 import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_COMPARTMENT_ID_PROPERTY;
 import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_MODEL_REGION;
 import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_MODEL_REGION_PROPERTY;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.Mockito.atLeast;
 import static org.mockito.Mockito.atLeastOnce;
 
 import com.oracle.bmc.Region;
@@ -82,10 +84,17 @@ protected ChatRequestParameters createIntegrationSpecificParameters(final int ma
 
     @Override
     protected void verifyToolCallbacks(StreamingChatResponseHandler handler, InOrder io, String id) {
+        io.verify(handler, atLeast(0)).onPartialResponse(any(), any());
         io.verify(handler, atLeastOnce()).onPartialResponse(anyString());
         io.verify(handler).onCompleteToolCall(complete(0, id, "getWeather", "{\"city\":\"Munich\"}"));
     }
 
+    @Override
+    protected void verifyToolCallbacks(StreamingChatResponseHandler handler, InOrder io, StreamingChatModel model) {
+        io.verify(handler, atLeastOnce()).onPartialResponse(anyString());
+        super.verifyToolCallbacks(handler, io, model);
+    }
+
     @Disabled("Know issue: response_format is not supported with RAG")
     @Override
     protected void should_execute_a_tool_then_answer_respecting_JSON_response_format_with_schema(StreamingChatModel m) {
@@ -150,6 +159,19 @@ protected boolean assertThreads() {
         return false;
     }
 
+    protected boolean assertTokenUsage() {
+        return false;
+    }
+
+    protected boolean supportsJsonResponseFormatWithRawSchema() {
+        return false;
+    }
+
+    @Override
+    protected boolean supportsStreamingCancellation() {
+        return false;
+    }
+
     @Override
     @Disabled("Not supported by testing model")
     protected void should_execute_multiple_tools_in_parallel_then_answer(StreamingChatModel model) {
models/langchain4j-community-oci-genai/src/test/java/dev/langchain4j/community/model/oracle/oci/genai/CohereToolHistoryTest.java

Lines changed: 169 additions & 0 deletions

@@ -0,0 +1,169 @@
+package dev.langchain4j.community.model.oracle.oci.genai;
+
+import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.NON_EMPTY;
+import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_COHERE_CHAT_MODEL_NAME_PROPERTY;
+import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_COMPARTMENT_ID_PROPERTY;
+import static dev.langchain4j.community.model.oracle.oci.genai.TestEnvProps.OCI_GENAI_MODEL_REGION_PROPERTY;
+import static org.hamcrest.MatcherAssert.assertThat;
+import static org.hamcrest.Matchers.contains;
+import static org.hamcrest.Matchers.containsString;
+import static org.junit.jupiter.api.Assertions.fail;
+
+import com.oracle.bmc.Region;
+import com.oracle.bmc.auth.AuthenticationDetailsProvider;
+import com.oracle.bmc.generativeaiinference.GenerativeAiInferenceClient;
+import com.oracle.bmc.generativeaiinference.model.EmbedTextDetails;
+import com.oracle.bmc.generativeaiinference.model.OnDemandServingMode;
+import com.oracle.bmc.generativeaiinference.requests.EmbedTextRequest;
+import com.oracle.bmc.generativeaiinference.responses.EmbedTextResponse;
+import dev.langchain4j.agent.tool.P;
+import dev.langchain4j.agent.tool.Tool;
+import dev.langchain4j.agent.tool.ToolExecutionRequest;
+import dev.langchain4j.service.AiServices;
+import dev.langchain4j.service.Result;
+import dev.langchain4j.service.SystemMessage;
+import dev.langchain4j.service.UserMessage;
+import dev.langchain4j.service.V;
+import dev.langchain4j.service.tool.ToolExecution;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.IntStream;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
+import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariables;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+@EnabledIfEnvironmentVariables({
+    @EnabledIfEnvironmentVariable(named = OCI_GENAI_MODEL_REGION_PROPERTY, matches = NON_EMPTY),
+    @EnabledIfEnvironmentVariable(named = OCI_GENAI_COMPARTMENT_ID_PROPERTY, matches = NON_EMPTY),
+    @EnabledIfEnvironmentVariable(named = OCI_GENAI_COHERE_CHAT_MODEL_NAME_PROPERTY, matches = NON_EMPTY)
+})
+public class CohereToolHistoryTest {
+
+    static final Logger LOGGER = LoggerFactory.getLogger(CohereToolHistoryTest.class);
+    static final AuthenticationDetailsProvider authProvider = TestEnvProps.createAuthProvider();
+    static final List<List<Float>> EMBEDDINGS = new ArrayList<>();
+    static final AtomicInteger EMBEDDINGS_SEQ = new AtomicInteger(0);
+
+    @Test
+    public void sequentialToolCalls() throws ExecutionException, InterruptedException, TimeoutException {
+        try (var model = OciGenAiCohereChatModel.builder()
+                .modelName(TestEnvProps.OCI_GENAI_COHERE_CHAT_MODEL_NAME)
+                .compartmentId(TestEnvProps.OCI_GENAI_COMPARTMENT_ID)
+                .region(Region.fromRegionCodeOrId(TestEnvProps.OCI_GENAI_MODEL_REGION))
+                .authProvider(authProvider)
+                .temperature(0.3)
+                .seed(TestEnvProps.SEED)
+                .maxTokens(4000)
+                .build()) {
+
+            var tools = new TestTools();
+
+            var embeddingAiService = AiServices.builder(TestEmbeddingAiService.class)
+                    .tools(tools)
+                    .toolExecutionErrorHandler((throwable, context) -> fail(throwable))
+                    .chatModel(model)
+                    .build();
+
+            var result = embeddingAiService.embed("It's bucketing down", "It's raining cats and dogs");
+            LOGGER.info("Result> {}", result.content());
+
+            assertThat(
+                    result.toolExecutions().stream()
+                            .map(ToolExecution::request)
+                            .map(ToolExecutionRequest::name)
+                            .toList(),
+                    contains("storeEmbedding", "storeEmbedding", "calculateCosineSimilarity"));
+
+            assertThat(result.content(), containsString(String.valueOf(tools.similarity.get(10, TimeUnit.SECONDS))));
+        }
+    }
+
+    interface TestEmbeddingAiService {
+
+        @SystemMessage(
+                """
+                You must use provided tool to calculate cosine similarity between the two embedding ids.
+                You must never calculate cosine similarity yourself, always use tool.
+                """)
+        @UserMessage(
+                """
+                Store embeddings of following two strings "{{firstEmbedString}}", "{{secondEmbedString}}"..
+                When you have two resulting embedding ids use them to calculate cosine similarity, use tool for that.
+                """)
+        Result<String> embed(
+                @V("firstEmbedString") String firstEmbedString, @V("secondEmbedString") String secondEmbedString);
+    }
+
+    static class TestTools {
+
+        CompletableFuture<Double> similarity = new CompletableFuture<>();
+        EmbeddingClient embeddingClient = new EmbeddingClient();
+
+        @Tool("Store embedding of an input in the embedding database. Return the result id of the stored embedding.")
+        int storeEmbedding(@P("String input for embeddings") String input) {
+            LOGGER.info("Storing embedding \"{}\"", input);
+            var nextId = EMBEDDINGS_SEQ.getAndIncrement();
+            EMBEDDINGS.add(nextId, embeddingClient.getEmbeddings(List.of(input)).get(0));
+            return nextId;
+        }
+
+        @Tool("Calculate cosine similarity between the two embeddings identified by provided ids.")
+        double calculateCosineSimilarity(@P("First embedding id") int id1, @P("Second embedding id") int id2) {
+            LOGGER.info("Computing similarity id1={} id2={}", id1, id2);
+            var similarity = getCosineSimilarity(EMBEDDINGS.get(id1), EMBEDDINGS.get(id2));
+            LOGGER.info("Computed similarity is {}", similarity);
+            this.similarity.complete(similarity);
+            return similarity;
+        }
+
+        public static double[] getL2Normed(List<Float> vector) {
+            var norm = (float) Math.sqrt(vector.stream().mapToDouble(e -> e * e).sum());
+            return vector.stream().mapToDouble(e -> e / norm).toArray();
+        }
+
+        public static double getCosineSimilarity(List<Float> vector1, List<Float> vector2) {
+            if (vector1.size() != vector2.size()) throw new RuntimeException("Vectors are having different size");
+
+            var vector1Normed = getL2Normed(vector1);
+            var vector2Normed = getL2Normed(vector2);
+
+            return IntStream.range(0, vector1.size())
+                    .mapToDouble(i -> vector1Normed[i] * vector2Normed[i])
+                    .sum();
+        }
+    }
+
+    public static class EmbeddingClient {
+        public EmbeddingClient() {}
+
+        public List<List<Float>> getEmbeddings(List<String> input) {
+            var clientBuilder = GenerativeAiInferenceClient.builder()
+                    .region(Region.fromRegionCodeOrId(TestEnvProps.OCI_GENAI_MODEL_REGION));
+
+            try (var embedClient = clientBuilder.build(authProvider)) {
+                EmbedTextDetails embedTextDetails = EmbedTextDetails.builder()
+                        .inputs(input)
+                        .compartmentId(TestEnvProps.OCI_GENAI_COMPARTMENT_ID)
+                        .servingMode(OnDemandServingMode.builder()
+                                .modelId("cohere.embed-v4.0")
+                                .build())
+                        .build();
+
+                EmbedTextRequest request = EmbedTextRequest.builder()
+                        .embedTextDetails(embedTextDetails)
+                        .build();
+                EmbedTextResponse response = embedClient.embedText(request);
+                return response.getEmbedTextResult().getEmbeddings();
+            } catch (Exception ex) {
+                throw new RuntimeException(ex);
+            }
+        }
+    }
+}

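For reference, the getL2Normed / getCosineSimilarity helpers in this new test compute the standard cosine similarity; normalizing each vector to unit length first and then taking the dot product is equivalent to the usual ratio form:

\[
\cos(u, v) = \frac{u \cdot v}{\lVert u \rVert_2 \, \lVert v \rVert_2} = \sum_{i} \frac{u_i}{\lVert u \rVert_2} \cdot \frac{v_i}{\lVert v \rVert_2}
\]
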
models/langchain4j-community-oci-genai/src/test/java/dev/langchain4j/community/model/oracle/oci/genai/GenericStreamingChatModelIT.java

Lines changed: 21 additions & 0 deletions
@@ -167,6 +167,27 @@ protected boolean assertTokenUsage() {
         return false;
     }
 
+    @Override
+    protected boolean supportsStreamingCancellation() {
+        return false;
+    }
+
+    @Override
+    protected void verifyToolCallbacks(StreamingChatResponseHandler handler, InOrder io, StreamingChatModel model) {
+        // Some providers can talk before calling a tool. "atLeast(0)" is meant to ignore it.
+        io.verify(handler, atLeast(0)).onPartialResponse(any(), any());
+
+        if (supportsPartialToolStreaming(model)) {
+            io.verify(handler, atLeast(0)).onPartialToolCall(any());
+        }
+        io.verify(handler).onCompleteToolCall(any());
+    }
+
+    @Override
+    protected boolean supportsPartialToolStreaming(final StreamingChatModel model) {
+        return true;
+    }
+
     @Override
     @Disabled("Enable when token usage is supported by SDK")
     protected void should_respect_maxOutputTokens_in_default_model_parameters() {

models/langchain4j-community-oci-genai/src/test/java/dev/langchain4j/community/model/oracle/oci/genai/GenericStreamingTest.java

Lines changed: 0 additions & 1 deletion
@@ -96,7 +96,6 @@ void streamedText2() {
         var toolExecutionRequests = chatResponse.aiMessage().toolExecutionRequests();
         assertThat(toolExecutionRequests.size(), is(0));
         assertThat(chatResponse.aiMessage().text(), is("Hello "));
-        System.out.println(handler.partialResponses);
         assertThat(handler.partialResponses, contains("Hello "));
         assertThat(handler.completeResponses, contains("Hello "));
     }
