Skip to content

Commit 44a550a

Browse files
committed
#153: Unified the ChatMessage struct and removed copy-pasted structs like that
1 parent 03f1805 commit 44a550a

File tree

13 files changed

+150
-140
lines changed

13 files changed

+150
-140
lines changed

pkg/api/http/server.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@ import (
44
"context"
55
"errors"
66
"fmt"
7+
"github.com/gofiber/contrib/otelfiber"
78
"time"
89

910
"github.com/gofiber/swagger"
1011

1112
"github.com/EinStack/glide/docs"
1213

13-
"github.com/gofiber/contrib/fiberzap/v2"
14-
"github.com/gofiber/contrib/otelfiber"
15-
1614
_ "github.com/EinStack/glide/docs" // importing docs package to include them into the binary
15+
"github.com/gofiber/contrib/fiberzap/v2"
1716

1817
"github.com/gofiber/fiber/v2"
1918

@@ -31,7 +30,6 @@ type Server struct {
3130

3231
func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) {
3332
srv := config.ToServer()
34-
srv.Use(otelfiber.Middleware())
3533

3634
return &Server{
3735
config: config,
@@ -47,6 +45,8 @@ func (srv *Server) Run() error {
4745
return c.Status(fiber.StatusOK).Type("json").Send(docs.SwaggerJSON)
4846
})
4947

48+
srv.server.Use(otelfiber.Middleware())
49+
5050
srv.server.Use(fiberzap.New(fiberzap.Config{
5151
Logger: srv.telemetry.Logger,
5252
}))

pkg/api/schemas/chat.go

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,68 @@ package schemas
22

33
// ChatRequest defines Glide's Chat Request Schema unified across all language models
44
type ChatRequest struct {
5-
Message ChatMessage `json:"message" validate:"required"`
6-
MessageHistory []ChatMessage `json:"message_history"`
7-
OverrideParams *OverrideChatRequest `json:"override_params,omitempty"`
5+
Message ChatMessage `json:"message" validate:"required"`
6+
MessageHistory []ChatMessage `json:"message_history"`
7+
OverrideParams *map[string]ModelParamsOverride `json:"override_params,omitempty"`
88
}
99

10-
type OverrideChatRequest struct {
11-
ModelID string `json:"model_id" validate:"required"`
12-
Message ChatMessage `json:"message" validate:"required"`
10+
func (r *ChatRequest) ModelParams(modelNameOrID string) *ModelParamsOverride {
11+
if r.OverrideParams == nil {
12+
return nil
13+
}
14+
15+
if override, found := (*r.OverrideParams)[modelNameOrID]; found {
16+
return &override
17+
}
18+
19+
return nil
20+
}
21+
22+
// ModelParamsOverride allows redefining the chat message and model params based on the model ID
23+
//
24+
// Glide provides an abstraction around concrete models and this is a way to be able to provide model-specific params if needed.
25+
// The override is going to be applied if Glide picks the model referenced there (it may pick another model to serve a given request)
26+
type ModelParamsOverride struct {
27+
// TODO: should be just string?
28+
Message ChatMessage `json:"message,omitempty"`
29+
// TODO(185): Add an ability to override model params
30+
}
31+
32+
// ChatParams represents chat request params that override the default model params from configs
33+
type ChatParams struct {
34+
Messages []ChatMessage
35+
// TODO(185): set other params
36+
}
37+
38+
// ChatParams returns the chat request params for a specific model, accounting for model-specific overrides.
39+
func (r *ChatRequest) ChatParams(modelID string, modelName string) *ChatParams {
40+
params := &ChatParams{
41+
Messages: make([]ChatMessage, 0, len(r.MessageHistory)+1),
42+
}
43+
44+
reqMessage := r.Message
45+
46+
if override := r.ModelParams(modelName); override != nil {
47+
reqMessage = override.Message
48+
// TODO(185): set other params
49+
}
50+
51+
if override := r.ModelParams(modelID); override != nil {
52+
reqMessage = override.Message
53+
// TODO(185): set other params
54+
}
55+
56+
params.Messages = append(params.Messages, r.MessageHistory...)
57+
params.Messages = append(params.Messages, reqMessage)
58+
59+
return params
1360
}
1461

1562
func NewChatFromStr(message string) *ChatRequest {
1663
return &ChatRequest{
1764
Message: ChatMessage{
1865
"user",
1966
message,
20-
"glide",
2167
},
2268
}
2369
}
@@ -35,7 +81,6 @@ type ChatResponse struct {
3581
}
3682

3783
// ModelResponse is the unified response from the provider.
38-
3984
type ModelResponse struct {
4085
Metadata map[string]string `json:"metadata,omitempty"`
4186
Message ChatMessage `json:"message"`
@@ -54,7 +99,4 @@ type ChatMessage struct {
5499
Role string `json:"role" validate:"required"`
55100
// The content of the message.
56101
Content string `json:"content" validate:"required"`
57-
// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores,
58-
// with a maximum length of 64 characters.
59-
Name string `json:"name,omitempty"`
60102
}

pkg/api/schemas/chat_stream.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,18 @@ type StreamRequestID = string
2020

2121
// ChatStreamRequest defines a message that requests a new streaming chat
2222
type ChatStreamRequest struct {
23-
ID StreamRequestID `json:"id" validate:"required"`
24-
Message ChatMessage `json:"message" validate:"required"`
25-
MessageHistory []ChatMessage `json:"message_history" validate:"required"`
26-
OverrideParams *OverrideChatRequest `json:"override_params,omitempty"`
27-
Metadata *Metadata `json:"metadata,omitempty"`
23+
ID StreamRequestID `json:"id" validate:"required"`
24+
Message ChatMessage `json:"message" validate:"required"`
25+
MessageHistory []ChatMessage `json:"message_history" validate:"required"`
26+
OverrideParams *map[string]ModelParamsOverride `json:"override_params,omitempty"`
27+
Metadata *Metadata `json:"metadata,omitempty"`
2828
}
2929

3030
func NewChatStreamFromStr(message string) *ChatStreamRequest {
3131
return &ChatStreamRequest{
3232
Message: ChatMessage{
3333
"user",
3434
message,
35-
"glide",
3635
},
3736
}
3837
}

pkg/providers/anthropic/chat.go

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,18 @@ import (
1313
"go.uber.org/zap"
1414
)
1515

16-
type ChatMessage struct {
17-
Role string `json:"role"`
18-
Content string `json:"content"`
19-
}
20-
2116
// ChatRequest is an Anthropic-specific request schema
2217
type ChatRequest struct {
23-
Model string `json:"model"`
24-
Messages []ChatMessage `json:"messages"`
25-
System string `json:"system,omitempty"`
26-
Temperature float64 `json:"temperature,omitempty"`
27-
TopP float64 `json:"top_p,omitempty"`
28-
TopK int `json:"top_k,omitempty"`
29-
MaxTokens int `json:"max_tokens,omitempty"`
30-
Stream bool `json:"stream,omitempty"`
31-
Metadata *string `json:"metadata,omitempty"`
32-
StopSequences []string `json:"stop_sequences,omitempty"`
18+
Model string `json:"model"`
19+
Messages []schemas.ChatMessage `json:"messages"`
20+
System string `json:"system,omitempty"`
21+
Temperature float64 `json:"temperature,omitempty"`
22+
TopP float64 `json:"top_p,omitempty"`
23+
TopK int `json:"top_k,omitempty"`
24+
MaxTokens int `json:"max_tokens,omitempty"`
25+
Stream bool `json:"stream,omitempty"`
26+
Metadata *string `json:"metadata,omitempty"`
27+
StopSequences []string `json:"stop_sequences,omitempty"`
3328
}
3429

3530
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
package azureopenai
22

3-
type ChatMessage struct {
4-
Role string `json:"role"`
5-
Content string `json:"content"`
6-
}
3+
import "github.com/EinStack/glide/pkg/api/schemas"
74

85
// ChatRequest is an Azure openai-specific request schema
96
type ChatRequest struct {
10-
Messages []ChatMessage `json:"messages"`
11-
Temperature float64 `json:"temperature,omitempty"`
12-
TopP float64 `json:"top_p,omitempty"`
13-
MaxTokens int `json:"max_tokens,omitempty"`
14-
N int `json:"n,omitempty"`
15-
StopWords []string `json:"stop,omitempty"`
16-
Stream bool `json:"stream,omitempty"`
17-
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
18-
PresencePenalty int `json:"presence_penalty,omitempty"`
19-
LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
20-
User *string `json:"user,omitempty"`
21-
Seed *int `json:"seed,omitempty"`
22-
Tools []string `json:"tools,omitempty"`
23-
ToolChoice interface{} `json:"tool_choice,omitempty"`
24-
ResponseFormat interface{} `json:"response_format,omitempty"`
7+
Messages []schemas.ChatMessage `json:"messages"`
8+
Temperature float64 `json:"temperature,omitempty"`
9+
TopP float64 `json:"top_p,omitempty"`
10+
MaxTokens int `json:"max_tokens,omitempty"`
11+
N int `json:"n,omitempty"`
12+
StopWords []string `json:"stop,omitempty"`
13+
Stream bool `json:"stream,omitempty"`
14+
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
15+
PresencePenalty int `json:"presence_penalty,omitempty"`
16+
LogitBias *map[int]float64 `json:"logit_bias,omitempty"`
17+
User *string `json:"user,omitempty"`
18+
Seed *int `json:"seed,omitempty"`
19+
Tools []string `json:"tools,omitempty"`
20+
ToolChoice interface{} `json:"tool_choice,omitempty"`
21+
ResponseFormat interface{} `json:"response_format,omitempty"`
2522
}
2623

2724
// ChatCompletion
@@ -37,10 +34,10 @@ type ChatCompletion struct {
3734
}
3835

3936
type Choice struct {
40-
Index int `json:"index"`
41-
Message ChatMessage `json:"message"`
42-
Logprobs interface{} `json:"logprobs"`
43-
FinishReason string `json:"finish_reason"`
37+
Index int `json:"index"`
38+
Message schemas.ChatMessage `json:"message"`
39+
Logprobs interface{} `json:"logprobs"`
40+
FinishReason string `json:"finish_reason"`
4441
}
4542

4643
type Usage struct {
@@ -61,7 +58,7 @@ type ChatCompletionChunk struct {
6158
}
6259

6360
type StreamChoice struct {
64-
Index int `json:"index"`
65-
Delta ChatMessage `json:"delta"`
66-
FinishReason string `json:"finish_reason"`
61+
Index int `json:"index"`
62+
Delta schemas.ChatMessage `json:"delta"`
63+
FinishReason string `json:"finish_reason"`
6764
}

pkg/providers/bedrock/chat.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,6 @@ import (
1616
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
1717
)
1818

19-
type ChatMessage struct {
20-
Role string `json:"role"`
21-
Content string `json:"content"`
22-
}
23-
2419
// ChatRequest is a Bedrock-specific request schema
2520
type ChatRequest struct {
2621
Messages string `json:"inputText"`
@@ -116,7 +111,6 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
116111
Message: schemas.ChatMessage{
117112
Role: "assistant",
118113
Content: bedrockCompletion.Results[0].OutputText,
119-
Name: "",
120114
},
121115
TokenUsage: schemas.TokenUsage{
122116
PromptTokens: bedrockCompletion.Results[0].TokenCount,

pkg/providers/cohere/chat.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,8 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche
148148
"responseId": cohereCompletion.ResponseID,
149149
},
150150
Message: schemas.ChatMessage{
151-
Role: "model",
151+
Role: "assistant",
152152
Content: cohereCompletion.Text,
153-
Name: "",
154153
},
155154
TokenUsage: schemas.TokenUsage{
156155
PromptTokens: cohereCompletion.TokenCount.PromptTokens,

pkg/providers/lang.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ func (m *LanguageModel) Chat(ctx context.Context, request *schemas.ChatRequest)
108108

109109
func (m *LanguageModel) ChatStream(ctx context.Context, req *schemas.ChatStreamRequest) (<-chan *clients.ChatStreamResult, error) {
110110
stream, err := m.client.ChatStream(ctx, req)
111+
111112
if err != nil {
112113
m.healthTracker.TrackErr(err)
113114

pkg/providers/octoml/chat.go

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,17 @@ import (
1515
"go.uber.org/zap"
1616
)
1717

18-
type ChatMessage struct {
19-
Role string `json:"role"`
20-
Content string `json:"content"`
21-
}
22-
2318
// ChatRequest is an octoml-specific request schema
2419
type ChatRequest struct {
25-
Model string `json:"model"`
26-
Messages []ChatMessage `json:"messages"`
27-
Temperature float64 `json:"temperature,omitempty"`
28-
TopP float64 `json:"top_p,omitempty"`
29-
MaxTokens int `json:"max_tokens,omitempty"`
30-
StopWords []string `json:"stop,omitempty"`
31-
Stream bool `json:"stream,omitempty"`
32-
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
33-
PresencePenalty int `json:"presence_penalty,omitempty"`
20+
Model string `json:"model"`
21+
Messages []schemas.ChatMessage `json:"messages"`
22+
Temperature float64 `json:"temperature,omitempty"`
23+
TopP float64 `json:"top_p,omitempty"`
24+
MaxTokens int `json:"max_tokens,omitempty"`
25+
StopWords []string `json:"stop,omitempty"`
26+
Stream bool `json:"stream,omitempty"`
27+
FrequencyPenalty int `json:"frequency_penalty,omitempty"`
28+
PresencePenalty int `json:"presence_penalty,omitempty"`
3429
}
3530

3631
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives

pkg/providers/ollama/chat.go

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,31 +18,26 @@ import (
1818
"go.uber.org/zap"
1919
)
2020

21-
type ChatMessage struct {
22-
Role string `json:"role"`
23-
Content string `json:"content"`
24-
}
25-
2621
// ChatRequest is an ollama-specific request schema
2722
type ChatRequest struct {
28-
Model string `json:"model"`
29-
Messages []ChatMessage `json:"messages"`
30-
Microstat int `json:"microstat,omitempty"`
31-
MicrostatEta float64 `json:"microstat_eta,omitempty"`
32-
MicrostatTau float64 `json:"microstat_tau,omitempty"`
33-
NumCtx int `json:"num_ctx,omitempty"`
34-
NumGqa int `json:"num_gqa,omitempty"`
35-
NumGpu int `json:"num_gpu,omitempty"`
36-
NumThread int `json:"num_thread,omitempty"`
37-
RepeatLastN int `json:"repeat_last_n,omitempty"`
38-
Temperature float64 `json:"temperature,omitempty"`
39-
Seed int `json:"seed,omitempty"`
40-
StopWords []string `json:"stop,omitempty"`
41-
Tfsz float64 `json:"tfs_z,omitempty"`
42-
NumPredict int `json:"num_predict,omitempty"`
43-
TopK int `json:"top_k,omitempty"`
44-
TopP float64 `json:"top_p,omitempty"`
45-
Stream bool `json:"stream"`
23+
Model string `json:"model"`
24+
Messages []schemas.ChatMessage `json:"messages"`
25+
Microstat int `json:"microstat,omitempty"`
26+
MicrostatEta float64 `json:"microstat_eta,omitempty"`
27+
MicrostatTau float64 `json:"microstat_tau,omitempty"`
28+
NumCtx int `json:"num_ctx,omitempty"`
29+
NumGqa int `json:"num_gqa,omitempty"`
30+
NumGpu int `json:"num_gpu,omitempty"`
31+
NumThread int `json:"num_thread,omitempty"`
32+
RepeatLastN int `json:"repeat_last_n,omitempty"`
33+
Temperature float64 `json:"temperature,omitempty"`
34+
Seed int `json:"seed,omitempty"`
35+
StopWords []string `json:"stop,omitempty"`
36+
Tfsz float64 `json:"tfs_z,omitempty"`
37+
NumPredict int `json:"num_predict,omitempty"`
38+
TopK int `json:"top_k,omitempty"`
39+
TopP float64 `json:"top_p,omitempty"`
40+
Stream bool `json:"stream"`
4641
}
4742

4843
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
@@ -64,7 +59,6 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
6459
NumPredict: cfg.DefaultParams.NumPredict,
6560
TopP: cfg.DefaultParams.TopP,
6661
TopK: cfg.DefaultParams.TopK,
67-
Stream: cfg.DefaultParams.Stream,
6862
}
6963
}
7064

0 commit comments

Comments
 (0)