Skip to content

Commit 6e71630

Browse files
committed
#153: Aligning chat & chatStream methods with new ChatParams
1 parent 56eb570 commit 6e71630

File tree

10 files changed

+94
-124
lines changed

10 files changed

+94
-124
lines changed

pkg/providers/anthropic/chat.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
5656
chatReq := *c.chatRequestTemplate
5757
chatReq.ApplyParams(params)
5858

59+
chatReq.Stream = false
60+
5961
chatResponse, err := c.doChatRequest(ctx, &chatReq)
6062

6163
if err != nil {

pkg/providers/anthropic/chat_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
1212
return false
1313
}
1414

15-
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
15+
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
1616
return nil, clients.ErrChatStreamNotImplemented
1717
}

pkg/providers/azureopenai/chat.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
4343
chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
4444
chatReq.ApplyParams(params)
4545

46+
chatReq.Stream = false
47+
4648
chatResponse, err := c.doChatRequest(ctx, &chatReq)
4749

4850
if err != nil {

pkg/providers/azureopenai/chat_stream.go

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,10 @@ func (c *Client) SupportChatStream() bool {
155155
return true
156156
}
157157

158-
func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (clients.ChatStream, error) {
158+
func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) {
159159
// Create a new chat request
160-
httpRequest, err := c.makeStreamReq(ctx, req)
160+
httpRequest, err := c.makeStreamReq(ctx, params)
161+
161162
if err != nil {
162163
return nil, err
163164
}
@@ -171,28 +172,13 @@ func (c *Client) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest)
171172
), nil
172173
}
173174

174-
func (c *Client) createRequestFromStream(request *schemas.ChatStreamRequest) *ChatRequest {
175-
// TODO: consider using objectpool to optimize memory allocation
176-
chatRequest := *c.chatRequestTemplate // hoping to get a copy of the template
177-
178-
chatRequest.Messages = make([]ChatMessage, 0, len(request.MessageHistory)+1)
179-
180-
// Add items from messageHistory first and the new chat message last
181-
for _, message := range request.MessageHistory {
182-
chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: message.Role, Content: message.Content})
183-
}
184-
185-
chatRequest.Messages = append(chatRequest.Messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
186-
187-
return &chatRequest
188-
}
189-
190-
func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamRequest) (*http.Request, error) {
191-
chatRequest := c.createRequestFromStream(req)
175+
func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) {
176+
chatReq := *c.chatRequestTemplate
177+
chatReq.ApplyParams(params)
192178

193-
chatRequest.Stream = true
179+
chatReq.Stream = true
194180

195-
rawPayload, err := json.Marshal(chatRequest)
181+
rawPayload, err := json.Marshal(chatReq)
196182
if err != nil {
197183
return nil, fmt.Errorf("unable to marshal AzureOpenAI chat stream request payload: %w", err)
198184
}
@@ -212,7 +198,7 @@ func (c *Client) makeStreamReq(ctx context.Context, req *schemas.ChatStreamReque
212198
c.tel.L().Debug(
213199
"Stream chat request",
214200
zap.String("chatURL", c.chatURL),
215-
zap.Any("payload", chatRequest),
201+
zap.Any("payload", chatReq),
216202
)
217203

218204
return request, nil

pkg/providers/bedrock/chat.go

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,21 @@ import (
1616
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
1717
)
1818

19-
// ChatRequest is an Bedrock-specific request schema
19+
// ChatRequest is a Bedrock-specific request schema
2020
type ChatRequest struct {
2121
Messages string `json:"inputText"`
2222
TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"`
2323
}
2424

25+
func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
26+
// message history not yet supported for AWS models
27+
// TODO: do something about lack of message history. Maybe just concatenate all messages?
28+
// in any case, silently ignoring message history is not a viable long-term approach
29+
message := params.Messages[len(params.Messages)-1]
30+
31+
r.Messages = fmt.Sprintf("Role: %s, Content: %s", message.Role, message.Content)
32+
}
33+
2534
type TextGenerationConfig struct {
2635
Temperature float64 `json:"temperature"`
2736
TopP float64 `json:"topP"`
@@ -41,38 +50,22 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
4150
}
4251
}
4352

44-
func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) string {
45-
// message history not yet supported for AWS models
46-
message := fmt.Sprintf("Role: %s, Content: %s", request.Message.Role, request.Message.Content)
47-
48-
return message
49-
}
50-
5153
// Chat sends a chat request to the specified bedrock model.
52-
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
54+
func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
5355
// Create a new chat request
54-
chatRequest := c.createChatRequestSchema(request)
56+
// TODO: consider using objectpool to optimize memory allocation
57+
chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
58+
chatReq.ApplyParams(params)
59+
60+
chatResponse, err := c.doChatRequest(ctx, &chatReq)
5561

56-
chatResponse, err := c.doChatRequest(ctx, chatRequest)
5762
if err != nil {
5863
return nil, err
5964
}
6065

61-
if len(chatResponse.ModelResponse.Message.Content) == 0 {
62-
return nil, ErrEmptyResponse
63-
}
64-
6566
return chatResponse, nil
6667
}
6768

68-
func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
69-
// TODO: consider using objectpool to optimize memory allocation
70-
chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
71-
chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
72-
73-
return chatRequest
74-
}
75-
7669
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
7770
rawPayload, err := json.Marshal(payload)
7871
if err != nil {
@@ -84,6 +77,7 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
8477
ContentType: aws.String("application/json"),
8578
Body: rawPayload,
8679
})
80+
8781
if err != nil {
8882
c.telemetry.Logger.Error("Error: Couldn't invoke model. Here's why: %v\n", zap.Error(err))
8983
return nil, err
@@ -92,30 +86,36 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
9286
var bedrockCompletion ChatCompletion
9387

9488
err = json.Unmarshal(result.Body, &bedrockCompletion)
89+
9590
if err != nil {
9691
c.telemetry.Logger.Error("failed to parse bedrock chat response", zap.Error(err))
9792

9893
return nil, err
9994
}
10095

96+
modelResult := bedrockCompletion.Results[0]
97+
98+
if len(modelResult.OutputText) == 0 {
99+
return nil, ErrEmptyResponse
100+
}
101+
101102
response := schemas.ChatResponse{
102103
ID: uuid.NewString(),
103104
Created: int(time.Now().Unix()),
104105
Provider: "aws-bedrock",
105106
ModelName: c.config.Model,
106107
Cached: false,
107108
ModelResponse: schemas.ModelResponse{
108-
Metadata: map[string]string{
109-
"system_fingerprint": "none",
110-
},
109+
Metadata: map[string]string{},
111110
Message: schemas.ChatMessage{
112111
Role: "assistant",
113-
Content: bedrockCompletion.Results[0].OutputText,
112+
Content: modelResult.OutputText,
114113
},
115114
TokenUsage: schemas.TokenUsage{
116-
PromptTokens: bedrockCompletion.Results[0].TokenCount,
115+
// TODO: what would happen if there are multiple results? We need to sum their token counts
116+
PromptTokens: modelResult.TokenCount,
117117
ResponseTokens: -1,
118-
TotalTokens: bedrockCompletion.Results[0].TokenCount,
118+
TotalTokens: modelResult.TokenCount,
119119
},
120120
},
121121
}

pkg/providers/bedrock/chat_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
1212
return false
1313
}
1414

15-
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
15+
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
1616
return nil, clients.ErrChatStreamNotImplemented
1717
}

pkg/providers/octoml/chat.go

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ type ChatRequest struct {
2828
PresencePenalty int `json:"presence_penalty,omitempty"`
2929
}
3030

31+
func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
32+
r.Messages = params.Messages
33+
// TODO(185): set other params
34+
}
35+
3136
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
3237
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
3338
return &ChatRequest{
@@ -36,50 +41,29 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
3641
TopP: cfg.DefaultParams.TopP,
3742
MaxTokens: cfg.DefaultParams.MaxTokens,
3843
StopWords: cfg.DefaultParams.StopWords,
39-
Stream: false, // unsupported right now
4044
FrequencyPenalty: cfg.DefaultParams.FrequencyPenalty,
4145
PresencePenalty: cfg.DefaultParams.PresencePenalty,
4246
}
4347
}
4448

45-
func NewChatMessagesFromUnifiedRequest(request *schemas.ChatRequest) []ChatMessage {
46-
messages := make([]ChatMessage, 0, len(request.MessageHistory)+1)
47-
48-
// Add items from messageHistory first and the new chat message last
49-
for _, message := range request.MessageHistory {
50-
messages = append(messages, ChatMessage{Role: message.Role, Content: message.Content})
51-
}
52-
53-
messages = append(messages, ChatMessage{Role: request.Message.Role, Content: request.Message.Content})
54-
55-
return messages
56-
}
57-
5849
// Chat sends a chat request to the specified octoml model.
59-
func (c *Client) Chat(ctx context.Context, request *schemas.ChatRequest) (*schemas.ChatResponse, error) {
50+
func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) {
6051
// Create a new chat request
61-
chatRequest := c.createChatRequestSchema(request)
52+
// TODO: consider using objectpool to optimize memory allocation
53+
chatReq := *c.chatRequestTemplate // hoping to get a copy of the template
54+
chatReq.ApplyParams(params)
55+
56+
chatReq.Stream = false
57+
58+
chatResponse, err := c.doChatRequest(ctx, &chatReq)
6259

63-
chatResponse, err := c.doChatRequest(ctx, chatRequest)
6460
if err != nil {
6561
return nil, err
6662
}
6763

68-
if len(chatResponse.ModelResponse.Message.Content) == 0 {
69-
return nil, ErrEmptyResponse
70-
}
71-
7264
return chatResponse, nil
7365
}
7466

75-
func (c *Client) createChatRequestSchema(request *schemas.ChatRequest) *ChatRequest {
76-
// TODO: consider using objectpool to optimize memory allocation
77-
chatRequest := c.chatRequestTemplate // hoping to get a copy of the template
78-
chatRequest.Messages = NewChatMessagesFromUnifiedRequest(request)
79-
80-
return chatRequest
81-
}
82-
8367
func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) {
8468
// Build request payload
8569
rawPayload, err := json.Marshal(payload)
@@ -121,33 +105,39 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
121105
}
122106

123107
// Parse the response JSON
124-
var openAICompletion openai.ChatCompletion // Octo uses the same response schema as OpenAI
108+
var completion openai.ChatCompletion // Octo uses the same response schema as OpenAI
125109

126-
err = json.Unmarshal(bodyBytes, &openAICompletion)
110+
err = json.Unmarshal(bodyBytes, &completion)
127111
if err != nil {
128112
c.telemetry.Logger.Error("failed to parse openai chat response", zap.Error(err))
129113
return nil, err
130114
}
131115

116+
modelChoice := completion.Choices[0]
117+
118+
if len(modelChoice.Message.Content) == 0 {
119+
return nil, ErrEmptyResponse
120+
}
121+
132122
// Map response to UnifiedChatResponse schema
133123
response := schemas.ChatResponse{
134-
ID: openAICompletion.ID,
135-
Created: openAICompletion.Created,
124+
ID: completion.ID,
125+
Created: completion.Created,
136126
Provider: providerName,
137-
ModelName: openAICompletion.ModelName,
127+
ModelName: completion.ModelName,
138128
Cached: false,
139129
ModelResponse: schemas.ModelResponse{
140130
Metadata: map[string]string{
141-
"system_fingerprint": openAICompletion.SystemFingerprint,
131+
"system_fingerprint": completion.SystemFingerprint,
142132
},
143133
Message: schemas.ChatMessage{
144-
Role: openAICompletion.Choices[0].Message.Role,
145-
Content: openAICompletion.Choices[0].Message.Content,
134+
Role: modelChoice.Message.Role,
135+
Content: modelChoice.Message.Content,
146136
},
147137
TokenUsage: schemas.TokenUsage{
148-
PromptTokens: openAICompletion.Usage.PromptTokens,
149-
ResponseTokens: openAICompletion.Usage.CompletionTokens,
150-
TotalTokens: openAICompletion.Usage.TotalTokens,
138+
PromptTokens: completion.Usage.PromptTokens,
139+
ResponseTokens: completion.Usage.CompletionTokens,
140+
TotalTokens: completion.Usage.TotalTokens,
151141
},
152142
},
153143
}

pkg/providers/octoml/chat_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ func (c *Client) SupportChatStream() bool {
1212
return false
1313
}
1414

15-
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatStreamRequest) (clients.ChatStream, error) {
15+
func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) {
1616
return nil, clients.ErrChatStreamNotImplemented
1717
}

0 commit comments

Comments
 (0)