Skip to content

Commit 1cbedb3

Browse files
authored
Add temperature support to conversational api (#3566)
Signed-off-by: yaron2 <[email protected]>
1 parent 28d46f6 commit 1cbedb3

File tree

5 files changed

+40
-4
lines changed

5 files changed

+40
-4
lines changed

conversation/anthropic/anthropic.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ func (a *Anthropic) Converse(ctx context.Context, r *conversation.ConversationRe
9292
})
9393
}
9494

95-
resp, err := a.llm.GenerateContent(ctx, messages)
95+
opts := []llms.CallOption{}
96+
97+
if r.Temperature > 0 {
98+
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
99+
}
100+
101+
resp, err := a.llm.GenerateContent(ctx, messages, opts...)
96102
if err != nil {
97103
return nil, err
98104
}

conversation/aws/bedrock/bedrock.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,13 @@ func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationR
104104
})
105105
}
106106

107-
resp, err := b.llm.GenerateContent(ctx, messages)
107+
opts := []llms.CallOption{}
108+
109+
if r.Temperature > 0 {
110+
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
111+
}
112+
113+
resp, err := b.llm.GenerateContent(ctx, messages, opts...)
108114
if err != nil {
109115
return nil, err
110116
}

conversation/converse.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ type ConversationRequest struct {
4242
Inputs []ConversationInput `json:"inputs"`
4343
Parameters map[string]*anypb.Any `json:"parameters"`
4444
ConversationContext string `json:"conversationContext"`
45+
Temperature float64 `json:"temperature"`
4546

4647
// from metadata
4748
Key string `json:"key"`

conversation/openai/openai.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,9 @@ func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationReque
111111
}
112112

113113
req := openai.ChatCompletionRequest{
114-
Model: o.model,
115-
Messages: messages,
114+
Model: o.model,
115+
Messages: messages,
116+
Temperature: float32(r.Temperature),
116117
}
117118

118119
// TODO: support ConversationContext

conversation/temperature.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
/*
2+
Copyright 2024 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
package conversation
16+
17+
import "github.com/tmc/langchaingo/llms"
18+
19+
// LangchainTemperature returns a langchain compliant LLM temperature
20+
func LangchainTemperature(temperature float64) llms.CallOption {
21+
return llms.WithTemperature(temperature)
22+
}

0 commit comments

Comments
 (0)