Skip to content

Commit c8a2d18

Browse files
committed
#153: Extended the stream chat request
1 parent 6e71630 commit c8a2d18

33 files changed

+127
-141
lines changed

pkg/api/http/server.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@ import (
44
"context"
55
"errors"
66
"fmt"
7-
"github.com/gofiber/contrib/otelfiber"
87
"time"
98

9+
"github.com/gofiber/contrib/otelfiber"
10+
1011
"github.com/gofiber/swagger"
1112

1213
"github.com/EinStack/glide/docs"

pkg/api/schemas/chat_stream.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@ type StreamRequestID = string
2020

2121
// ChatStreamRequest defines a message that requests a new streaming chat
2222
type ChatStreamRequest struct {
23-
ID StreamRequestID `json:"id" validate:"required"`
24-
Message ChatMessage `json:"message" validate:"required"`
25-
MessageHistory []ChatMessage `json:"message_history" validate:"required"`
23+
ID StreamRequestID `json:"id" validate:"required"`
24+
*ChatRequest
2625
OverrideParams *map[string]ModelParamsOverride `json:"override_params,omitempty"`
2726
Metadata *Metadata `json:"metadata,omitempty"`
2827
}
2928

3029
func NewChatStreamFromStr(message string) *ChatStreamRequest {
3130
return &ChatStreamRequest{
32-
Message: ChatMessage{
33-
"user",
34-
message,
31+
ChatRequest: &ChatRequest{
32+
Message: ChatMessage{
33+
"user",
34+
message,
35+
},
3536
},
3637
}
3738
}

pkg/providers/anthropic/chat.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8-
"github.com/EinStack/glide/pkg/providers/clients"
98
"io"
109
"net/http"
1110
"time"
1211

12+
"github.com/EinStack/glide/pkg/providers/clients"
13+
1314
"github.com/EinStack/glide/pkg/api/schemas"
1415
"go.uber.org/zap"
1516
)
@@ -35,7 +36,7 @@ func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) {
3536
// NewChatRequestFromConfig fills the struct from the config. Not using reflection because of performance penalty it gives
3637
func NewChatRequestFromConfig(cfg *Config) *ChatRequest {
3738
return &ChatRequest{
38-
Model: cfg.Model,
39+
Model: cfg.ModelName,
3940
System: cfg.DefaultParams.System,
4041
Temperature: cfg.DefaultParams.Temperature,
4142
TopP: cfg.DefaultParams.TopP,
@@ -59,7 +60,6 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
5960
chatReq.Stream = false
6061

6162
chatResponse, err := c.doChatRequest(ctx, &chatReq)
62-
6363
if err != nil {
6464
return nil, err
6565
}

pkg/providers/anthropic/client.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
5656
func (c *Client) Provider() string {
5757
return providerName
5858
}
59+
60+
func (c *Client) ModelName() string {
61+
return c.config.ModelName
62+
}

pkg/providers/anthropic/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ type Config struct {
3939
BaseURL string `yaml:"base_url" json:"base_url" validate:"required"`
4040
APIVersion string `yaml:"api_version" json:"api_version" validate:"required"`
4141
ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"`
42-
Model string `yaml:"model" json:"model" validate:"required"`
42+
ModelName string `yaml:"model" json:"model" validate:"required"`
4343
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
4444
DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"`
4545
}
@@ -52,7 +52,7 @@ func DefaultConfig() *Config {
5252
BaseURL: "https://api.anthropic.com/v1",
5353
APIVersion: "2023-06-01",
5454
ChatEndpoint: "/messages",
55-
Model: "claude-instant-1.2",
55+
ModelName: "claude-instant-1.2",
5656
DefaultParams: &defaultParams,
5757
}
5858
}

pkg/providers/azureopenai/chat.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@ import (
55
"context"
66
"encoding/json"
77
"fmt"
8-
"github.com/EinStack/glide/pkg/providers/clients"
98
"io"
109
"net/http"
1110

11+
"github.com/EinStack/glide/pkg/providers/clients"
12+
1213
"github.com/EinStack/glide/pkg/providers/openai"
1314

1415
"github.com/EinStack/glide/pkg/api/schemas"
@@ -46,7 +47,6 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas
4647
chatReq.Stream = false
4748

4849
chatResponse, err := c.doChatRequest(ctx, &chatReq)
49-
5050
if err != nil {
5151
return nil, err
5252
}

pkg/providers/azureopenai/chat_stream.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,6 @@ func (c *Client) SupportChatStream() bool {
158158
func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) {
159159
// Create a new chat request
160160
httpRequest, err := c.makeStreamReq(ctx, params)
161-
162161
if err != nil {
163162
return nil, err
164163
}

pkg/providers/azureopenai/client.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
3333
chatURL := fmt.Sprintf(
3434
"%s/openai/deployments/%s/chat/completions?api-version=%s",
3535
providerConfig.BaseURL,
36-
providerConfig.Model,
36+
providerConfig.ModelName,
3737
providerConfig.APIVersion,
3838
)
3939

@@ -60,3 +60,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *
6060
func (c *Client) Provider() string {
6161
return providerName
6262
}
63+
64+
func (c *Client) ModelName() string {
65+
return c.config.ModelName
66+
}

pkg/providers/azureopenai/client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) {
8181

8282
// Verify the default configuration values
8383
require.Equal(t, "/chat/completions", providerCfg.ChatEndpoint)
84-
require.Equal(t, "", providerCfg.Model)
84+
require.Equal(t, "", providerCfg.ModelName)
8585
require.Equal(t, "2023-05-15", providerCfg.APIVersion)
8686
require.NotNil(t, providerCfg.DefaultParams)
8787

pkg/providers/azureopenai/config.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ func (p *Params) UnmarshalYAML(unmarshal func(interface{}) error) error {
4444
type Config struct {
4545
BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` // The name of your Azure OpenAI Resource (e.g https://glide-test.openai.azure.com/)
4646
ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint"`
47-
Model string `yaml:"model" json:"model" validate:"required"` // This is your deployment name. You're required to first deploy a model before you can make calls (e.g. glide-gpt-35)
47+
ModelName string `yaml:"model" json:"model" validate:"required"` // This is your deployment name. You're required to first deploy a model before you can make calls (e.g. glide-gpt-35)
4848
APIVersion string `yaml:"api_version" json:"apiVersion" validate:"required"` // The API version to use for this operation. This follows the YYYY-MM-DD format (e.g 2023-05-15)
4949
APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"`
5050
DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"`
@@ -57,7 +57,7 @@ func DefaultConfig() *Config {
5757
return &Config{
5858
BaseURL: "", // This needs to come from config
5959
ChatEndpoint: "/chat/completions",
60-
Model: "", // This needs to come from config
60+
ModelName: "", // This needs to come from config
6161
APIVersion: "2023-05-15",
6262
DefaultParams: &defaultParams,
6363
}

0 commit comments

Comments
 (0)