Skip to content

Commit 4ca04db

Browse files
authored
Conversation API: add cache support, add huggingface+mistral models (#3567)
Signed-off-by: yaron2 <[email protected]>
1 parent 1cbedb3 commit 4ca04db

File tree

15 files changed

+447
-97
lines changed

15 files changed

+447
-97
lines changed

conversation/anthropic/anthropic.go

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,11 @@ import (
2828
)
2929

3030
type Anthropic struct {
31-
llm *anthropic.LLM
31+
llm llms.Model
3232

3333
logger logger.Logger
3434
}
3535

36-
type AnthropicMetadata struct {
37-
Key string `json:"key"`
38-
Model string `json:"model"`
39-
}
40-
4136
func NewAnthropic(logger logger.Logger) conversation.Conversation {
4237
a := &Anthropic{
4338
logger: logger,
@@ -49,7 +44,7 @@ func NewAnthropic(logger logger.Logger) conversation.Conversation {
4944
const defaultModel = "claude-3-5-sonnet-20240620"
5045

5146
func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error {
52-
m := AnthropicMetadata{}
47+
m := conversation.LangchainMetadata{}
5348
err := kmeta.DecodeMetadata(meta.Properties, &m)
5449
if err != nil {
5550
return err
@@ -69,11 +64,21 @@ func (a *Anthropic) Init(ctx context.Context, meta conversation.Metadata) error
6964
}
7065

7166
a.llm = llm
67+
68+
if m.CacheTTL != "" {
69+
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, a.llm)
70+
if cacheErr != nil {
71+
return cacheErr
72+
}
73+
74+
a.llm = cachedModel
75+
}
76+
7277
return nil
7378
}
7479

7580
func (a *Anthropic) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
76-
metadataStruct := AnthropicMetadata{}
81+
metadataStruct := conversation.LangchainMetadata{}
7782
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
7883
return
7984
}

conversation/anthropic/metadata.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,9 @@ metadata:
2727
The Anthropic LLM to use. Defaults to claude-3-5-sonnet-20240620
2828
type: string
2929
example: 'claude-3-5-sonnet-20240620'
30+
- name: cacheTTL
31+
required: false
32+
description: |
33+
Time-to-live after which a cached prompt expires. Uses Go duration format (e.g. '10m')
34+
type: string
35+
example: '10m'

conversation/aws/bedrock/bedrock.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import (
3131

3232
type AWSBedrock struct {
3333
model string
34-
llm *bedrock.LLM
34+
llm llms.Model
3535

3636
logger logger.Logger
3737
}
@@ -43,6 +43,7 @@ type AWSBedrockMetadata struct {
4343
SecretKey string `json:"secretKey"`
4444
SessionToken string `json:"sessionToken"`
4545
Model string `json:"model"`
46+
CacheTTL string `json:"cacheTTL"`
4647
}
4748

4849
func NewAWSBedrock(logger logger.Logger) conversation.Conversation {
@@ -81,6 +82,15 @@ func (b *AWSBedrock) Init(ctx context.Context, meta conversation.Metadata) error
8182
}
8283

8384
b.llm = llm
85+
86+
if m.CacheTTL != "" {
87+
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, b.llm)
88+
if cacheErr != nil {
89+
return cacheErr
90+
}
91+
92+
b.llm = cachedModel
93+
}
8494
return nil
8595
}
8696

conversation/aws/bedrock/metadata.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,9 @@ metadata:
2424
The LLM to use. Defaults to Bedrock's default provider model from Amazon.
2525
type: string
2626
example: 'amazon.titan-text-express-v1'
27+
- name: cacheTTL
28+
required: false
29+
description: |
30+
Time-to-live after which a cached prompt expires. Uses Go duration format (e.g. '10m')
31+
type: string
32+
example: '10m'
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
/*
2+
Copyright 2024 The Dapr Authors
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
*/
15+
package huggingface
16+
17+
import (
18+
"context"
19+
"reflect"
20+
21+
"github.com/dapr/components-contrib/conversation"
22+
"github.com/dapr/components-contrib/metadata"
23+
"github.com/dapr/kit/logger"
24+
kmeta "github.com/dapr/kit/metadata"
25+
26+
"github.com/tmc/langchaingo/llms"
27+
"github.com/tmc/langchaingo/llms/huggingface"
28+
)
29+
30+
// Huggingface is a conversation component backed by a langchaingo
// Hugging Face model.
type Huggingface struct {
	// llm is the underlying langchaingo model; Init may wrap it with a
	// prompt cache when cacheTTL metadata is set.
	llm llms.Model

	logger logger.Logger
}
36+
func NewHuggingface(logger logger.Logger) conversation.Conversation {
37+
h := &Huggingface{
38+
logger: logger,
39+
}
40+
41+
return h
42+
}
43+
44+
const defaultModel = "meta-llama/Meta-Llama-3-8B"
45+
46+
func (h *Huggingface) Init(ctx context.Context, meta conversation.Metadata) error {
47+
m := conversation.LangchainMetadata{}
48+
err := kmeta.DecodeMetadata(meta.Properties, &m)
49+
if err != nil {
50+
return err
51+
}
52+
53+
model := defaultModel
54+
if m.Model != "" {
55+
model = m.Model
56+
}
57+
58+
llm, err := huggingface.New(
59+
huggingface.WithModel(model),
60+
huggingface.WithToken(m.Key),
61+
)
62+
if err != nil {
63+
return err
64+
}
65+
66+
h.llm = llm
67+
68+
if m.CacheTTL != "" {
69+
cachedModel, cacheErr := conversation.CacheModel(ctx, m.CacheTTL, h.llm)
70+
if cacheErr != nil {
71+
return cacheErr
72+
}
73+
74+
h.llm = cachedModel
75+
}
76+
77+
return nil
78+
}
79+
80+
func (h *Huggingface) GetComponentMetadata() (metadataInfo metadata.MetadataMap) {
81+
metadataStruct := conversation.LangchainMetadata{}
82+
metadata.GetMetadataInfoFromStructType(reflect.TypeOf(metadataStruct), &metadataInfo, metadata.ConversationType)
83+
return
84+
}
85+
86+
func (h *Huggingface) Converse(ctx context.Context, r *conversation.ConversationRequest) (res *conversation.ConversationResponse, err error) {
87+
messages := make([]llms.MessageContent, 0, len(r.Inputs))
88+
89+
for _, input := range r.Inputs {
90+
role := conversation.ConvertLangchainRole(input.Role)
91+
92+
messages = append(messages, llms.MessageContent{
93+
Role: role,
94+
Parts: []llms.ContentPart{
95+
llms.TextPart(input.Message),
96+
},
97+
})
98+
}
99+
100+
opts := []llms.CallOption{}
101+
102+
if r.Temperature > 0 {
103+
opts = append(opts, conversation.LangchainTemperature(r.Temperature))
104+
}
105+
106+
resp, err := h.llm.GenerateContent(ctx, messages, opts...)
107+
if err != nil {
108+
return nil, err
109+
}
110+
111+
outputs := make([]conversation.ConversationResult, 0, len(resp.Choices))
112+
113+
for i := range resp.Choices {
114+
outputs = append(outputs, conversation.ConversationResult{
115+
Result: resp.Choices[i].Content,
116+
Parameters: r.Parameters,
117+
})
118+
}
119+
120+
res = &conversation.ConversationResponse{
121+
Outputs: outputs,
122+
}
123+
124+
return res, nil
125+
}
126+
127+
// Close implements the component close contract. The Hugging Face client
// holds no persistent resources, so there is nothing to release.
func (h *Huggingface) Close() error {
	return nil
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# yaml-language-server: $schema=../../../component-metadata-schema.json
2+
schemaVersion: v1
3+
type: conversation
4+
name: huggingface
5+
version: v1
6+
status: alpha
7+
title: "Huggingface"
8+
urls:
9+
- title: Reference
10+
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-huggingface/
11+
authenticationProfiles:
12+
- title: "API Key"
13+
description: "Authenticate using an API key"
14+
metadata:
15+
- name: key
16+
type: string
17+
required: true
18+
sensitive: true
19+
description: |
20+
API key for Huggingface.
21+
example: "**********"
22+
default: ""
23+
metadata:
24+
- name: model
25+
required: false
26+
description: |
27+
The Huggingface LLM to use. Defaults to meta-llama/Meta-Llama-3-8B
28+
type: string
29+
example: 'meta-llama/Meta-Llama-3-8B'
30+
- name: cacheTTL
31+
required: false
32+
description: |
33+
Time-to-live after which a cached prompt expires. Uses Go duration format (e.g. '10m')
34+
type: string
35+
example: '10m'

conversation/metadata.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,10 @@ import "github.com/dapr/components-contrib/metadata"
2020
type Metadata struct {
2121
metadata.Base `json:",inline"`
2222
}
23+
24+
// LangchainMetadata is a common metadata structure for langchain supported implementations.
type LangchainMetadata struct {
	Key      string `json:"key"`      // API key/token passed to the provider client
	Model    string `json:"model"`    // model name; implementations fall back to their own default when empty
	CacheTTL string `json:"cacheTTL"` // optional prompt-cache time-to-live; when set, Init wraps the model via conversation.CacheModel
}

conversation/mistral/metadata.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# yaml-language-server: $schema=../../../component-metadata-schema.json
2+
schemaVersion: v1
3+
type: conversation
4+
name: mistral
5+
version: v1
6+
status: alpha
7+
title: "Mistral"
8+
urls:
9+
- title: Reference
10+
url: https://docs.dapr.io/reference/components-reference/supported-conversation/setup-mistral/
11+
authenticationProfiles:
12+
- title: "API Key"
13+
description: "Authenticate using an API key"
14+
metadata:
15+
- name: key
16+
type: string
17+
required: true
18+
sensitive: true
19+
description: |
20+
API key for Mistral.
21+
example: "**********"
22+
default: ""
23+
metadata:
24+
- name: model
25+
required: false
26+
description: |
27+
The Mistral LLM to use. Defaults to open-mistral-7b
28+
type: string
29+
example: 'open-mistral-7b'
30+
- name: cacheTTL
31+
required: false
32+
description: |
33+
A time-to-live value for a prompt cache to expire. Uses Golang durations
34+
type: string
35+
example: '10m'

0 commit comments

Comments
 (0)