diff --git a/.gitignore b/.gitignore index 066b8f56..18d81aa2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,11 @@ .idea dist .env +.env.bak config.yaml bin glide +glide.exe tmp coverage.txt precommit.txt diff --git a/pkg/cmd/cli.go b/cmd/cli.go similarity index 100% rename from pkg/cmd/cli.go rename to cmd/cli.go diff --git a/config.dev.yaml b/config.dev.yaml index 80c77a5b..8bd2af4a 100644 --- a/config.dev.yaml +++ b/config.dev.yaml @@ -8,5 +8,6 @@ routers: - id: default models: - id: openai - openai: - api_key: "${env:OPENAI_API_KEY}" + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/config.sample.yaml b/config.sample.yaml index 3ce72055..7118a6f5 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -6,3 +6,12 @@ telemetry: #api: # http: # ... + +routers: + language: + - id: default + models: + - id: openai + provider: + openai: + api_key: "${env:OPENAI_API_KEY}" diff --git a/main.go b/main.go index a6d84381..122d45ba 100644 --- a/main.go +++ b/main.go @@ -3,7 +3,7 @@ package main import ( "log" - "github.com/EinStack/glide/pkg/cmd" + "github.com/EinStack/glide/cmd" ) // @title Glide diff --git a/pkg/api/http/handlers.go b/pkg/api/http/handlers.go index 98e9f3a3..3db789eb 100644 --- a/pkg/api/http/handlers.go +++ b/pkg/api/http/handlers.go @@ -4,8 +4,10 @@ import ( "context" "sync" - "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/router" + "github.com/EinStack/glide/pkg/telemetry" "github.com/gofiber/contrib/websocket" "github.com/gofiber/fiber/v2" @@ -31,38 +33,38 @@ type Handler = func(c *fiber.Ctx) error // @Failure 400 {object} schemas.Error // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chat [POST] -func LangChatHandler(routerManager *routers.RouterManager) Handler { +func LangChatHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if !c.Is("json") { - 
return c.Status(fiber.StatusBadRequest).JSON(schemas.ErrUnsupportedMediaType) + return c.Status(fiber.StatusBadRequest).JSON(schema.ErrUnsupportedMediaType) } // Unmarshal request body - req := schemas.GetChatRequest() - defer schemas.ReleaseChatRequest(req) + req := schema.GetChatRequest() + defer schema.ReleaseChatRequest(req) err := c.BodyParser(&req) if err != nil { - return c.Status(fiber.StatusBadRequest).JSON(schemas.NewPayloadParseErr(err)) + return c.Status(fiber.StatusBadRequest).JSON(schema.NewPayloadParseErr(err)) } // Get router ID from path routerID := c.Params("router") - router, err := routerManager.GetLangRouter(routerID) + r, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } // Chat with router - resp := schemas.GetChatResponse() - defer schemas.ReleaseChatResponse(resp) + resp := schema.GetChatResponse() + defer schema.ReleaseChatResponse(resp) - resp, err = router.Chat(c.Context(), req) + resp, err = r.Chat(c.Context(), req) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -72,14 +74,14 @@ func LangChatHandler(routerManager *routers.RouterManager) Handler { } } -func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { +func LangStreamRouterValidator(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { routerID := c.Params("router") _, err := routerManager.GetLangRouter(routerID) if err != nil { - httpErr := schemas.FromErr(err) + httpErr := schema.FromErr(err) return c.Status(httpErr.Status).JSON(httpErr) } @@ -107,7 +109,7 @@ func LangStreamRouterValidator(routerManager *routers.RouterManager) Handler { // @Failure 426 // @Failure 404 {object} schemas.Error // @Router /v1/language/{router}/chatStream [GET] -func LangStreamChatHandler(tel 
*telemetry.Telemetry, routerManager *routers.RouterManager) Handler { +func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *router.Manager) Handler { // TODO: expose websocket connection configs https://github.com/gofiber/contrib/tree/main/websocket return websocket.New(func(c *websocket.Conn) { routerID := c.Params("router") @@ -118,9 +120,9 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout wg sync.WaitGroup ) - chatStreamC := make(chan *schemas.ChatStreamMessage) + chatStreamC := make(chan *schema.ChatStreamMessage) - router, _ := routerManager.GetLangRouter(routerID) + r, _ := routerManager.GetLangRouter(routerID) defer close(chatStreamC) defer c.Conn.Close() @@ -138,7 +140,7 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout }() for { - var chatRequest schemas.ChatStreamRequest + var chatRequest schema.ChatStreamRequest if err = c.ReadJSON(&chatRequest); err != nil { // TODO: handle bad request schemas gracefully and return back validation errors @@ -154,10 +156,10 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout // TODO: handle termination gracefully wg.Add(1) - go func(chatRequest schemas.ChatStreamRequest) { + go func(chatRequest schema.ChatStreamRequest) { defer wg.Done() - router.ChatStream(context.Background(), &chatRequest, chatStreamC) + r.ChatStream(context.Background(), &chatRequest, chatStreamC) }(chatRequest) } @@ -175,16 +177,16 @@ func LangStreamChatHandler(tel *telemetry.Telemetry, routerManager *routers.Rout // @Produce json // @Success 200 {object} schemas.RouterListSchema // @Router /v1/language/ [GET] -func LangRoutersHandler(routerManager *routers.RouterManager) Handler { +func LangRoutersHandler(routerManager *router.Manager) Handler { return func(c *fiber.Ctx) error { configuredRouters := routerManager.GetLangRouters() cfgs := make([]interface{}, 0, len(configuredRouters)) // opaque by design - for _, router := range 
configuredRouters { - cfgs = append(cfgs, router.Config) + for _, r := range configuredRouters { + cfgs = append(cfgs, r.Config) } - return c.Status(fiber.StatusOK).JSON(schemas.RouterListSchema{Routers: cfgs}) + return c.Status(fiber.StatusOK).JSON(schema.RouterListSchema{Routers: cfgs}) } } @@ -199,9 +201,9 @@ func LangRoutersHandler(routerManager *routers.RouterManager) Handler { // @Success 200 {object} schemas.HealthSchema // @Router /v1/health/ [get] func HealthHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusOK).JSON(schemas.HealthSchema{Healthy: true}) + return c.Status(fiber.StatusOK).JSON(schema.HealthSchema{Healthy: true}) } func NotFoundHandler(c *fiber.Ctx) error { - return c.Status(fiber.StatusNotFound).JSON(schemas.ErrRouteNotFound) + return c.Status(fiber.StatusNotFound).JSON(schema.ErrRouteNotFound) } diff --git a/pkg/api/http/server.go b/pkg/api/http/server.go index 9b70a05f..6422623a 100644 --- a/pkg/api/http/server.go +++ b/pkg/api/http/server.go @@ -6,6 +6,8 @@ import ( "fmt" "time" + "github.com/EinStack/glide/pkg/router" + "github.com/gofiber/contrib/otelfiber" "github.com/gofiber/swagger" @@ -17,19 +19,17 @@ import ( "github.com/gofiber/fiber/v2" - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/telemetry" ) type Server struct { config *ServerConfig telemetry *telemetry.Telemetry - routerManager *routers.RouterManager + routerManager *router.Manager server *fiber.App } -func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *routers.RouterManager) (*Server, error) { +func NewServer(config *ServerConfig, tel *telemetry.Telemetry, routerManager *router.Manager) (*Server, error) { srv := config.ToServer() return &Server{ diff --git a/pkg/api/schemas/chat.go b/pkg/api/schema/chat.go similarity index 99% rename from pkg/api/schemas/chat.go rename to pkg/api/schema/chat.go index bb846043..b833b367 100644 --- a/pkg/api/schemas/chat.go +++ b/pkg/api/schema/chat.go @@ -1,4 +1,4 @@ -package 
schemas +package schema // ChatRequest defines Glide's Chat Request Schema unified across all language models type ChatRequest struct { diff --git a/pkg/api/schemas/chat_stream.go b/pkg/api/schema/chat_stream.go similarity index 99% rename from pkg/api/schemas/chat_stream.go rename to pkg/api/schema/chat_stream.go index f7cf8b27..ee1cd228 100644 --- a/pkg/api/schemas/chat_stream.go +++ b/pkg/api/schema/chat_stream.go @@ -1,4 +1,4 @@ -package schemas +package schema import "time" diff --git a/pkg/api/schemas/chat_test.go b/pkg/api/schema/chat_test.go similarity index 99% rename from pkg/api/schemas/chat_test.go rename to pkg/api/schema/chat_test.go index 9b5ce407..9d77da62 100644 --- a/pkg/api/schemas/chat_test.go +++ b/pkg/api/schema/chat_test.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "testing" diff --git a/pkg/api/schema/embed.go b/pkg/api/schema/embed.go new file mode 100644 index 00000000..5698d330 --- /dev/null +++ b/pkg/api/schema/embed.go @@ -0,0 +1,9 @@ +package schema + +type EmbedRequest struct { + // TODO: implement +} + +type EmbedResponse struct { + // TODO: implement +} diff --git a/pkg/api/schemas/errors.go b/pkg/api/schema/errors.go similarity index 99% rename from pkg/api/schemas/errors.go rename to pkg/api/schema/errors.go index 2765f93e..0eecf0b5 100644 --- a/pkg/api/schemas/errors.go +++ b/pkg/api/schema/errors.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "fmt" diff --git a/pkg/api/schemas/health_checks.go b/pkg/api/schema/health.go similarity index 79% rename from pkg/api/schemas/health_checks.go rename to pkg/api/schema/health.go index 6078e769..896e00c5 100644 --- a/pkg/api/schemas/health_checks.go +++ b/pkg/api/schema/health.go @@ -1,4 +1,4 @@ -package schemas +package schema type HealthSchema struct { Healthy bool `json:"healthy"` diff --git a/pkg/api/schemas/pool.go b/pkg/api/schema/pool.go similarity index 97% rename from pkg/api/schemas/pool.go rename to pkg/api/schema/pool.go index dcd9ccf8..4b5c38ba 
100755 --- a/pkg/api/schemas/pool.go +++ b/pkg/api/schema/pool.go @@ -1,4 +1,4 @@ -package schemas +package schema import ( "sync" diff --git a/pkg/api/schemas/routers.go b/pkg/api/schema/routers.go similarity index 95% rename from pkg/api/schemas/routers.go rename to pkg/api/schema/routers.go index 9111a319..18dcee02 100644 --- a/pkg/api/schemas/routers.go +++ b/pkg/api/schema/routers.go @@ -1,4 +1,4 @@ -package schemas +package schema // RouterListSchema returns list of active configured routers. // diff --git a/pkg/api/servers.go b/pkg/api/servers.go index 3588e257..fd0a281e 100644 --- a/pkg/api/servers.go +++ b/pkg/api/servers.go @@ -4,9 +4,9 @@ import ( "context" "sync" - "go.uber.org/zap" + "github.com/EinStack/glide/pkg/router" - "github.com/EinStack/glide/pkg/routers" + "go.uber.org/zap" "github.com/EinStack/glide/pkg/telemetry" @@ -19,7 +19,7 @@ type ServerManager struct { telemetry *telemetry.Telemetry } -func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *routers.RouterManager) (*ServerManager, error) { +func NewServerManager(cfg *Config, tel *telemetry.Telemetry, router *router.Manager) (*ServerManager, error) { httpServer, err := http.NewServer(cfg.HTTP, tel, router) if err != nil { return nil, err diff --git a/pkg/providers/clients/config.go b/pkg/clients/config.go similarity index 100% rename from pkg/providers/clients/config.go rename to pkg/clients/config.go diff --git a/pkg/providers/clients/config_test.go b/pkg/clients/config_test.go similarity index 100% rename from pkg/providers/clients/config_test.go rename to pkg/clients/config_test.go diff --git a/pkg/providers/clients/errors.go b/pkg/clients/errors.go similarity index 100% rename from pkg/providers/clients/errors.go rename to pkg/clients/errors.go diff --git a/pkg/providers/clients/errors_test.go b/pkg/clients/errors_test.go similarity index 100% rename from pkg/providers/clients/errors_test.go rename to pkg/clients/errors_test.go diff --git a/pkg/providers/clients/sse.go 
b/pkg/clients/sse.go similarity index 100% rename from pkg/providers/clients/sse.go rename to pkg/clients/sse.go diff --git a/pkg/providers/clients/sse_test.go b/pkg/clients/sse_test.go similarity index 100% rename from pkg/providers/clients/sse_test.go rename to pkg/clients/sse_test.go diff --git a/pkg/clients/stream.go b/pkg/clients/stream.go new file mode 100644 index 00000000..4ab55fb0 --- /dev/null +++ b/pkg/clients/stream.go @@ -0,0 +1,29 @@ +package clients + +import "github.com/EinStack/glide/pkg/api/schema" + +type ChatStream interface { + Open() error + Recv() (*schema.ChatStreamChunk, error) + Close() error +} + +type ChatStreamResult struct { + chunk *schema.ChatStreamChunk + err error +} + +func (r *ChatStreamResult) Chunk() *schema.ChatStreamChunk { + return r.chunk +} + +func (r *ChatStreamResult) Error() error { + return r.err +} + +func NewChatStreamResult(chunk *schema.ChatStreamChunk, err error) *ChatStreamResult { + return &ChatStreamResult{ + chunk: chunk, + err: err, + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index dd520a9a..9f390a45 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -1,18 +1,16 @@ package config import ( - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers" - "github.com/EinStack/glide/pkg/api" + "github.com/EinStack/glide/pkg/router" + "github.com/EinStack/glide/pkg/telemetry" ) // Config is a general top-level Glide configuration type Config struct { - Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` - API *api.Config `yaml:"api" validate:"required"` - Routers routers.Config `yaml:"routers" validate:"required"` + Telemetry *telemetry.Config `yaml:"telemetry" validate:"required"` + API *api.Config `yaml:"api" validate:"required"` + Routers router.RoutersConfig `yaml:"routers" validate:"required"` } func DefaultConfig() *Config { diff --git a/pkg/extmodel/config.go b/pkg/extmodel/config.go new file mode 100644 index 00000000..3edd45f0 --- 
/dev/null +++ b/pkg/extmodel/config.go @@ -0,0 +1,51 @@ +package extmodel + +import ( + "fmt" + + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/router/latency" + "github.com/EinStack/glide/pkg/telemetry" +) + +// Config defines an extra configuration for a model wrapper around a provider +type Config[P provider.Configurer] struct { + ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? + ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` + Latency *latency.Config `yaml:"latency" json:"latency"` + Weight int `yaml:"weight" json:"weight"` + Client *clients.ClientConfig `yaml:"client" json:"client"` + + Provider P `yaml:"provider" json:"provider"` +} + +func NewConfig[P provider.Configurer](ID string) *Config[P] { + config := DefaultConfig[P]() + + config.ID = ID + + return &config +} + +func DefaultConfig[P provider.Configurer]() Config[P] { + return Config[P]{ + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Weight: 1, + } +} + +func (c *Config[P]) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { + client, err := c.Provider.ToClient(tel, c.Client) + if err != nil { + return nil, fmt.Errorf("error initializing client: %w", err) + } + + return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil +} diff --git a/pkg/providers/lang.go b/pkg/extmodel/lang.go similarity index 75% rename from pkg/providers/lang.go rename to pkg/extmodel/lang.go index d2a6aa06..7c95282a 100644 --- a/pkg/providers/lang.go +++ b/pkg/extmodel/lang.go @@ -1,37 +1,28 @@ -package providers +package extmodel import ( "context" "io" "time" - 
"github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/routers/health" + "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/clients" + health2 "github.com/EinStack/glide/pkg/resiliency/health" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/config/fields" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/router/latency" ) -// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests -type LangProvider interface { - ModelProvider - - SupportChatStream() bool - - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) -} - type LangModel interface { - Model + Interface Provider() string ModelName() string - Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) - ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) } // LanguageModel wraps provider client and expend it with health & latency tracking @@ -41,18 +32,18 @@ type LangModel interface { type LanguageModel struct { modelID string weight int - client LangProvider - healthTracker *health.Tracker + client provider.LangProvider + healthTracker *health2.Tracker chatLatency *latency.MovingAverage chatStreamLatency *latency.MovingAverage latencyUpdateInterval *fields.Duration } -func NewLangModel(modelID string, client LangProvider, budget *health.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { +func NewLangModel(modelID string, client provider.LangProvider, budget 
*health2.ErrorBudget, latencyConfig latency.Config, weight int) *LanguageModel { return &LanguageModel{ modelID: modelID, client: client, - healthTracker: health.NewTracker(budget), + healthTracker: health2.NewTracker(budget), chatLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), chatStreamLatency: latency.NewMovingAverage(latencyConfig.Decay, latencyConfig.WarmupSamples), latencyUpdateInterval: latencyConfig.UpdateInterval, @@ -88,7 +79,7 @@ func (m LanguageModel) ChatStreamLatency() *latency.MovingAverage { return m.chatStreamLatency } -func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (m *LanguageModel) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { startedAt := time.Now() resp, err := m.client.Chat(ctx, params) @@ -107,7 +98,7 @@ func (m *LanguageModel) Chat(ctx context.Context, params *schemas.ChatParams) (* return resp, err } -func (m *LanguageModel) ChatStream(ctx context.Context, params *schemas.ChatParams) (<-chan *clients.ChatStreamResult, error) { +func (m *LanguageModel) ChatStream(ctx context.Context, params *schema.ChatParams) (<-chan *clients.ChatStreamResult, error) { stream, err := m.client.ChatStream(ctx, params) if err != nil { m.healthTracker.TrackErr(err) @@ -179,10 +170,10 @@ func (m *LanguageModel) ModelName() string { return m.client.ModelName() } -func ChatLatency(model Model) *latency.MovingAverage { +func ChatLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatLatency() } -func ChatStreamLatency(model Model) *latency.MovingAverage { +func ChatStreamLatency(model Interface) *latency.MovingAverage { return model.(LanguageModel).ChatStreamLatency() } diff --git a/pkg/extmodel/model.go b/pkg/extmodel/model.go new file mode 100644 index 00000000..b250c470 --- /dev/null +++ b/pkg/extmodel/model.go @@ -0,0 +1,11 @@ +package extmodel + +import 
"github.com/EinStack/glide/pkg/config/fields" + +// Interface represent a configured external modality-agnostic model with its routing properties and status +type Interface interface { + ID() string + Healthy() bool + LatencyUpdateInterval() *fields.Duration + Weight() int +} diff --git a/pkg/providers/testing/models.go b/pkg/extmodel/testing.go similarity index 84% rename from pkg/providers/testing/models.go rename to pkg/extmodel/testing.go index d4ac3840..86829610 100644 --- a/pkg/providers/testing/models.go +++ b/pkg/extmodel/testing.go @@ -1,13 +1,10 @@ -package testing +package extmodel import ( "time" "github.com/EinStack/glide/pkg/config/fields" - - "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/router/latency" ) // LangModelMock @@ -55,6 +52,6 @@ func (m LangModelMock) Weight() int { return m.weight } -func ChatMockLatency(model providers.Model) *latency.MovingAverage { +func ChatMockLatency(model Interface) *latency.MovingAverage { return model.(LangModelMock).chatLatency } diff --git a/pkg/gateway.go b/pkg/gateway.go index 950ec26d..fd6c1878 100644 --- a/pkg/gateway.go +++ b/pkg/gateway.go @@ -7,7 +7,8 @@ import ( "os/signal" "syscall" - "github.com/EinStack/glide/pkg/routers" + "github.com/EinStack/glide/pkg/router" + "github.com/EinStack/glide/pkg/version" "go.opentelemetry.io/contrib/instrumentation/host" "go.opentelemetry.io/contrib/instrumentation/runtime" @@ -49,7 +50,7 @@ func NewGateway(configProvider *config.Provider) (*Gateway, error) { tel.L().Info("🐦Glide is starting up", zap.String("version", version.FullVersion)) tel.L().Debug("✅ Config loaded successfully:\n" + configProvider.GetStr()) - routerManager, err := routers.NewManager(&cfg.Routers, tel) + routerManager, err := router.NewManager(&cfg.Routers, tel) if err != nil { return nil, err } diff --git a/pkg/providers/anthropic/chat.go b/pkg/provider/anthropic/chat.go similarity index 76% rename from 
pkg/providers/anthropic/chat.go rename to pkg/provider/anthropic/chat.go index 80b45f2b..bb0559ad 100644 --- a/pkg/providers/anthropic/chat.go +++ b/pkg/provider/anthropic/chat.go @@ -9,27 +9,28 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) // ChatRequest is an Anthropic-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - System string `json:"system,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - TopK int `json:"top_k,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - Stream bool `json:"stream,omitempty"` - Metadata *string `json:"metadata,omitempty"` - StopSequences []string `json:"stop_sequences,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + System string `json:"system,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + TopK int `json:"top_k,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Metadata *string `json:"metadata,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { r.Messages = params.Messages } @@ -51,7 +52,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { // Chat sends a chat request to the specified anthropic model. 
// // Ref: https://docs.anthropic.com/claude/reference/messages_post -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -67,7 +68,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -130,19 +131,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche usage := anthropicResponse.Usage // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: anthropicResponse.ID, Created: int(time.Now().UTC().Unix()), // not provided by anthropic - Provider: providerName, + Provider: ProviderID, ModelName: anthropicResponse.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: completion.Type, Content: completion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: usage.InputTokens, ResponseTokens: usage.OutputTokens, TotalTokens: usage.InputTokens + usage.OutputTokens, diff --git a/pkg/provider/anthropic/chat_stream.go b/pkg/provider/anthropic/chat_stream.go new file mode 100644 index 00000000..1a9f88a4 --- /dev/null +++ b/pkg/provider/anthropic/chat_stream.go @@ -0,0 +1,17 @@ +package anthropic + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" + + 
"github.com/EinStack/glide/pkg/clients" +) + +func (c *Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/anthropic/client.go b/pkg/provider/anthropic/client.go similarity index 87% rename from pkg/providers/anthropic/client.go rename to pkg/provider/anthropic/client.go index bb34fe07..2e08b2e2 100644 --- a/pkg/providers/anthropic/client.go +++ b/pkg/provider/anthropic/client.go @@ -5,13 +5,15 @@ import ( "net/url" "time" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" ) const ( - providerName = "anthropic" + ProviderID = "anthropic" ) // Client is a client for accessing OpenAI API @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. 
func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) @@ -54,7 +61,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/anthropic/client_test.go b/pkg/provider/anthropic/client_test.go similarity index 91% rename from pkg/providers/anthropic/client_test.go rename to pkg/provider/anthropic/client_test.go index b0c11f36..2fe33334 100644 --- a/pkg/providers/anthropic/client_test.go +++ b/pkg/provider/anthropic/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestAnthropicClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -86,7 +86,7 @@ func TestAnthropicClient_BadChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/anthropic/config.go b/pkg/provider/anthropic/config.go similarity index 87% rename from pkg/providers/anthropic/config.go rename to pkg/provider/anthropic/config.go index abdb5b73..1d252811 100644 
--- a/pkg/providers/anthropic/config.go +++ b/pkg/provider/anthropic/config.go @@ -1,7 +1,10 @@ package anthropic import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines OpenAI-specific model params with the specific validation of values @@ -57,6 +60,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/providers/anthropic/errors.go b/pkg/provider/anthropic/errors.go similarity index 96% rename from pkg/providers/anthropic/errors.go rename to pkg/provider/anthropic/errors.go index 126de68d..5c7a1370 100644 --- a/pkg/providers/anthropic/errors.go +++ b/pkg/provider/anthropic/errors.go @@ -6,9 +6,10 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/provider/anthropic/register.go b/pkg/provider/anthropic/register.go new file mode 100644 index 00000000..9b00ffc6 --- /dev/null +++ b/pkg/provider/anthropic/register.go @@ -0,0 +1,7 @@ +package anthropic + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} diff --git a/pkg/providers/anthropic/schamas.go b/pkg/provider/anthropic/schamas.go similarity index 100% rename from pkg/providers/anthropic/schamas.go rename to pkg/provider/anthropic/schamas.go diff --git a/pkg/providers/anthropic/testdata/chat.req.json b/pkg/provider/anthropic/testdata/chat.req.json similarity index 100% rename from pkg/providers/anthropic/testdata/chat.req.json rename to pkg/provider/anthropic/testdata/chat.req.json 
diff --git a/pkg/providers/anthropic/testdata/chat.success.json b/pkg/provider/anthropic/testdata/chat.success.json similarity index 100% rename from pkg/providers/anthropic/testdata/chat.success.json rename to pkg/provider/anthropic/testdata/chat.success.json diff --git a/pkg/providers/azureopenai/chat.go b/pkg/provider/azureopenai/chat.go similarity index 88% rename from pkg/providers/azureopenai/chat.go rename to pkg/provider/azureopenai/chat.go index 22005fa3..d2f1200e 100644 --- a/pkg/providers/azureopenai/chat.go +++ b/pkg/provider/azureopenai/chat.go @@ -8,11 +8,11 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/provider/openai" "go.uber.org/zap" ) @@ -38,7 +38,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified azure openai model. 
-func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -54,7 +54,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -110,19 +110,19 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, Provider: providerName, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{}, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/providers/azureopenai/chat_stream.go b/pkg/provider/azureopenai/chat_stream.go similarity index 89% rename from pkg/providers/azureopenai/chat_stream.go rename to pkg/provider/azureopenai/chat_stream.go index 8e73a556..f75fae4c 100644 --- a/pkg/providers/azureopenai/chat_stream.go +++ b/pkg/provider/azureopenai/chat_stream.go @@ -8,16 +8,17 @@ import ( "io" 
"net/http" + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider/openai" - "github.com/EinStack/glide/pkg/providers/clients" "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) // TODO: Think about reducing the number of copy-pasted code btw OpenAI and Azure providers @@ -33,6 +34,11 @@ type ChatStream struct { errMapper *ErrorMapper } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( tel *telemetry.Telemetry, client *http.Client, @@ -67,7 +73,7 @@ func (s *ChatStream) Open() error { } // Recv receives a chat stream chunk from the ChatStream and returns a ChatStreamChunk object. -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -124,16 +130,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, Provider: providerName, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: responseChunk.Delta.Role, Content: responseChunk.Delta.Content, }, @@ -155,7 +161,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) 
{ // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -171,7 +177,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/providers/azureopenai/chat_stream_test.go b/pkg/provider/azureopenai/chat_stream_test.go similarity index 89% rename from pkg/providers/azureopenai/chat_stream_test.go rename to pkg/provider/azureopenai/chat_stream_test.go index 5aade1f5..f056d599 100644 --- a/pkg/providers/azureopenai/chat_stream_test.go +++ b/pkg/provider/azureopenai/chat_stream_test.go @@ -10,18 +10,18 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/telemetry" + clients2 "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) func TestAzureOpenAIClient_ChatStreamSupported(t *testing.T) { providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) @@ -64,14 +64,14 @@ func TestAzureOpenAIClient_ChatStreamRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = AzureopenAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := 
schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -132,14 +132,14 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} @@ -153,7 +153,7 @@ func TestAzureOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { for { chunk, err := stream.Recv() if err != nil { - require.ErrorIs(t, err, clients.ErrProviderUnavailable) + require.ErrorIs(t, err, clients2.ErrProviderUnavailable) return } diff --git a/pkg/providers/azureopenai/client.go b/pkg/provider/azureopenai/client.go similarity index 88% rename from pkg/providers/azureopenai/client.go rename to pkg/provider/azureopenai/client.go index 0f594805..6ec90469 100644 --- a/pkg/providers/azureopenai/client.go +++ b/pkg/provider/azureopenai/client.go @@ -5,11 +5,13 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/clients" + + "github.com/EinStack/glide/pkg/provider/openai" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" ) const ( @@ -28,6 +30,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Azure OpenAI client for the OpenAI API. 
func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL := fmt.Sprintf( diff --git a/pkg/providers/azureopenai/client_test.go b/pkg/provider/azureopenai/client_test.go similarity index 91% rename from pkg/providers/azureopenai/client_test.go rename to pkg/provider/azureopenai/client_test.go index 1700bca0..accca38d 100644 --- a/pkg/providers/azureopenai/client_test.go +++ b/pkg/provider/azureopenai/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestAzureOpenAIClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -88,7 +88,7 @@ func TestAzureOpenAIClient_ChatError(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -115,7 +115,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealio?", }}} diff --git a/pkg/providers/azureopenai/config.go b/pkg/provider/azureopenai/config.go similarity index 100% rename 
from pkg/providers/azureopenai/config.go rename to pkg/provider/azureopenai/config.go diff --git a/pkg/providers/azureopenai/errors.go b/pkg/provider/azureopenai/errors.go similarity index 96% rename from pkg/providers/azureopenai/errors.go rename to pkg/provider/azureopenai/errors.go index 6a30e989..d659c027 100644 --- a/pkg/providers/azureopenai/errors.go +++ b/pkg/provider/azureopenai/errors.go @@ -6,9 +6,10 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/provider/azureopenai/schemas.go b/pkg/provider/azureopenai/schemas.go new file mode 100644 index 00000000..2ce12eb5 --- /dev/null +++ b/pkg/provider/azureopenai/schemas.go @@ -0,0 +1,68 @@ +package azureopenai + +import "github.com/EinStack/glide/pkg/api/schema" + +// ChatRequest is an Azure openai-specific request schema +type ChatRequest struct { + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 `json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` +} + +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { + r.Messages = params.Messages +} + +// ChatCompletion +// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions +type ChatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + 
Created int `json:"created"` + ModelName string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens float64 `json:"prompt_tokens"` + CompletionTokens float64 `json:"completion_tokens"` + TotalTokens float64 `json:"total_tokens"` +} + +// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming +// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions +type ChatCompletionChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + ModelName string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []StreamChoice `json:"choices"` +} + +type StreamChoice struct { + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + FinishReason string `json:"finish_reason"` +} diff --git a/pkg/providers/azureopenai/testdata/chat.req.json b/pkg/provider/azureopenai/testdata/chat.req.json similarity index 100% rename from pkg/providers/azureopenai/testdata/chat.req.json rename to pkg/provider/azureopenai/testdata/chat.req.json diff --git a/pkg/providers/azureopenai/testdata/chat.success.json b/pkg/provider/azureopenai/testdata/chat.success.json similarity index 100% rename from pkg/providers/azureopenai/testdata/chat.success.json rename to pkg/provider/azureopenai/testdata/chat.success.json diff --git a/pkg/providers/azureopenai/testdata/chat_stream.empty.txt b/pkg/provider/azureopenai/testdata/chat_stream.empty.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.empty.txt rename to pkg/provider/azureopenai/testdata/chat_stream.empty.txt diff --git 
a/pkg/providers/azureopenai/testdata/chat_stream.nodone.txt b/pkg/provider/azureopenai/testdata/chat_stream.nodone.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.nodone.txt rename to pkg/provider/azureopenai/testdata/chat_stream.nodone.txt diff --git a/pkg/providers/azureopenai/testdata/chat_stream.success.txt b/pkg/provider/azureopenai/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/azureopenai/testdata/chat_stream.success.txt rename to pkg/provider/azureopenai/testdata/chat_stream.success.txt diff --git a/pkg/providers/bedrock/chat.go b/pkg/provider/bedrock/chat.go similarity index 88% rename from pkg/providers/bedrock/chat.go rename to pkg/provider/bedrock/chat.go index 658c1769..cd51027b 100644 --- a/pkg/providers/bedrock/chat.go +++ b/pkg/provider/bedrock/chat.go @@ -6,7 +6,7 @@ import ( "fmt" "time" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "go.uber.org/zap" @@ -22,7 +22,7 @@ type ChatRequest struct { TextGenerationConfig TextGenerationConfig `json:"textGenerationConfig"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // message history not yet supported for AWS models // TODO: do something about lack of message history. Maybe just concatenate all messages? // in any case, this is not a way to go to ignore message history @@ -51,7 +51,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified bedrock model. 
-func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -65,7 +65,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { rawPayload, err := json.Marshal(payload) if err != nil { return nil, fmt.Errorf("unable to marshal chat request payload: %w", err) @@ -96,18 +96,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche return nil, ErrEmptyResponse } - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: "assistant", Content: modelResult.OutputText, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ // TODO: what would happen if there is a few responses? 
We need to sum that up PromptTokens: modelResult.TokenCount, ResponseTokens: -1, diff --git a/pkg/provider/bedrock/chat_stream.go b/pkg/provider/bedrock/chat_stream.go new file mode 100644 index 00000000..6bb87905 --- /dev/null +++ b/pkg/provider/bedrock/chat_stream.go @@ -0,0 +1,17 @@ +package bedrock + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" +) + +func (c *Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/bedrock/client.go b/pkg/provider/bedrock/client.go similarity index 92% rename from pkg/providers/bedrock/client.go rename to pkg/provider/bedrock/client.go index 0567b9fc..aa3905fd 100644 --- a/pkg/providers/bedrock/client.go +++ b/pkg/provider/bedrock/client.go @@ -7,9 +7,11 @@ import ( "net/url" "time" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials" @@ -36,6 +38,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. 
func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint, providerConfig.ModelName, "/invoke") diff --git a/pkg/providers/bedrock/client_test.go b/pkg/provider/bedrock/client_test.go similarity index 90% rename from pkg/providers/bedrock/client_test.go rename to pkg/provider/bedrock/client_test.go index cdae1f68..f6081966 100644 --- a/pkg/providers/bedrock/client_test.go +++ b/pkg/provider/bedrock/client_test.go @@ -11,9 +11,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -61,7 +61,7 @@ func TestBedrockClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/bedrock/config.go b/pkg/provider/bedrock/config.go similarity index 100% rename from pkg/providers/bedrock/config.go rename to pkg/provider/bedrock/config.go diff --git a/pkg/providers/bedrock/schemas.go b/pkg/provider/bedrock/schemas.go similarity index 100% rename from pkg/providers/bedrock/schemas.go rename to pkg/provider/bedrock/schemas.go diff --git a/pkg/providers/bedrock/testdata/chat.req.json b/pkg/provider/bedrock/testdata/chat.req.json similarity index 100% rename from pkg/providers/bedrock/testdata/chat.req.json rename to pkg/provider/bedrock/testdata/chat.req.json diff --git a/pkg/providers/bedrock/testdata/chat.success.json b/pkg/provider/bedrock/testdata/chat.success.json similarity index 100% rename from 
pkg/providers/bedrock/testdata/chat.success.json rename to pkg/provider/bedrock/testdata/chat.success.json diff --git a/pkg/providers/cohere/chat.go b/pkg/provider/cohere/chat.go similarity index 89% rename from pkg/providers/cohere/chat.go rename to pkg/provider/cohere/chat.go index ddf75680..7b2ebbb9 100644 --- a/pkg/providers/cohere/chat.go +++ b/pkg/provider/cohere/chat.go @@ -9,9 +9,9 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "go.uber.org/zap" ) @@ -30,7 +30,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified cohere model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate @@ -44,7 +44,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -115,22 +115,22 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: cohereCompletion.ResponseID, Created: int(time.Now().UTC().Unix()), // Cohere doesn't provide this - Provider: providerName, + Provider: ProviderID, ModelName: c.config.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: 
schema.ModelResponse{ Metadata: map[string]string{ "generationId": cohereCompletion.GenerationID, "responseId": cohereCompletion.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", Content: cohereCompletion.Text, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: cohereCompletion.TokenCount.PromptTokens, ResponseTokens: cohereCompletion.TokenCount.ResponseTokens, TotalTokens: cohereCompletion.TokenCount.TotalTokens, diff --git a/pkg/providers/cohere/chat_stream.go b/pkg/provider/cohere/chat_stream.go similarity index 85% rename from pkg/providers/cohere/chat_stream.go rename to pkg/provider/cohere/chat_stream.go index 1d8ed243..46b07598 100644 --- a/pkg/providers/cohere/chat_stream.go +++ b/pkg/provider/cohere/chat_stream.go @@ -8,13 +8,13 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/clients" - "go.uber.org/zap" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" + "go.uber.org/zap" ) // SupportedEventType Cohere has other types too: @@ -41,6 +41,11 @@ type ChatStream struct { tel *telemetry.Telemetry } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( tel *telemetry.Telemetry, client *http.Client, @@ -78,7 +83,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { if s.streamFinished { return nil, io.EOF } @@ -90,7 +95,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if err != nil { s.tel.L().Warn( "Chat stream is unexpectedly disconnected", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), ) @@ -101,7 +106,7 @@ func (s *ChatStream) Recv() 
(*schemas.ChatStreamChunk, error) { s.tel.L().Debug( "Raw chat stream chunk", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("rawChunk", rawChunk), ) @@ -119,7 +124,7 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { if responseChunk.EventType != TextGenEvent && responseChunk.EventType != StreamEndEvent { s.tel.L().Debug( "Unsupported stream chunk type, skipping it", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.ByteString("chunk", rawChunk), ) @@ -130,16 +135,16 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { s.streamFinished = true // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, "response_id": responseChunk.Response.ResponseID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -149,15 +154,15 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { } // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: s.modelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "generation_id": s.generationID, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "model", Content: responseChunk.Text, }, @@ -178,7 +183,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, 
params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -195,7 +200,7 @@ func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate chatReq.ApplyParams(params) diff --git a/pkg/providers/cohere/chat_stream_test.go b/pkg/provider/cohere/chat_stream_test.go similarity index 94% rename from pkg/providers/cohere/chat_stream_test.go rename to pkg/provider/cohere/chat_stream_test.go index 7deb5b88..e40eed14 100644 --- a/pkg/providers/cohere/chat_stream_test.go +++ b/pkg/provider/cohere/chat_stream_test.go @@ -10,11 +10,11 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) @@ -71,7 +71,7 @@ func TestCohere_ChatStreamRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -138,7 +138,7 @@ func TestCohere_ChatStreamRequestInterrupted(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: 
[]schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/providers/cohere/client.go b/pkg/provider/cohere/client.go similarity index 87% rename from pkg/providers/cohere/client.go rename to pkg/provider/cohere/client.go index c13ff64b..c3e43eb0 100644 --- a/pkg/providers/cohere/client.go +++ b/pkg/provider/cohere/client.go @@ -5,13 +5,15 @@ import ( "net/url" "time" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" ) const ( - providerName = "cohere" + ProviderID = "cohere" ) // Client is a client for accessing Cohere API @@ -26,6 +28,11 @@ type Client struct { tel *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new Cohere client for the Cohere API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) @@ -54,7 +61,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/cohere/client_test.go b/pkg/provider/cohere/client_test.go similarity index 90% rename from pkg/providers/cohere/client_test.go rename to pkg/provider/cohere/client_test.go index 2e5ab487..721ceda7 100644 --- a/pkg/providers/cohere/client_test.go +++ b/pkg/provider/cohere/client_test.go @@ -11,11 +11,11 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + 
"github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) @@ -55,7 +55,7 @@ func TestCohereClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} diff --git a/pkg/providers/cohere/config.go b/pkg/provider/cohere/config.go similarity index 90% rename from pkg/providers/cohere/config.go rename to pkg/provider/cohere/config.go index 8e7b8b1d..aedcbffa 100644 --- a/pkg/providers/cohere/config.go +++ b/pkg/provider/cohere/config.go @@ -1,7 +1,10 @@ package cohere import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines Cohere-specific model params with the specific validation of values @@ -58,6 +61,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/providers/cohere/errors.go b/pkg/provider/cohere/errors.go similarity index 91% rename from pkg/providers/cohere/errors.go rename to pkg/provider/cohere/errors.go index 118ef719..5f8ea045 100644 --- a/pkg/providers/cohere/errors.go +++ b/pkg/provider/cohere/errors.go @@ -6,9 +6,10 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) @@ -27,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - 
zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -37,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/cohere/finish_reason.go b/pkg/provider/cohere/finish_reason.go similarity index 73% rename from pkg/providers/cohere/finish_reason.go rename to pkg/provider/cohere/finish_reason.go index 139498e6..4d156875 100644 --- a/pkg/providers/cohere/finish_reason.go +++ b/pkg/provider/cohere/finish_reason.go @@ -3,9 +3,10 @@ package cohere import ( "strings" + "github.com/EinStack/glide/pkg/api/schema" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -27,27 +28,27 @@ type FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason *string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason *string) *schema.FinishReason { if finishReason == nil || len(*finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch strings.ToLower(*finishReason) { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", *finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/cohere/register.go b/pkg/provider/cohere/register.go new file mode 100644 index 
00000000..3845e24a --- /dev/null +++ b/pkg/provider/cohere/register.go @@ -0,0 +1,7 @@ +package cohere + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} diff --git a/pkg/providers/cohere/schemas.go b/pkg/provider/cohere/schemas.go similarity index 74% rename from pkg/providers/cohere/schemas.go rename to pkg/provider/cohere/schemas.go index 9dc9bb09..c224ec0e 100644 --- a/pkg/providers/cohere/schemas.go +++ b/pkg/provider/cohere/schemas.go @@ -1,6 +1,6 @@ package cohere -import "github.com/EinStack/glide/pkg/api/schemas" +import "github.com/EinStack/glide/pkg/api/schema" // Cohere Chat Response type ChatCompletion struct { @@ -90,25 +90,25 @@ type FinalResponse struct { // ChatRequest is a request to complete a chat completion // Ref: https://docs.cohere.com/reference/chat type ChatRequest struct { - Model string `json:"model"` - Message string `json:"message"` - ChatHistory []schemas.ChatMessage `json:"chat_history"` - Temperature float64 `json:"temperature,omitempty"` - Preamble string `json:"preamble,omitempty"` - PromptTruncation *string `json:"prompt_truncation,omitempty"` - Connectors []string `json:"connectors,omitempty"` - SearchQueriesOnly bool `json:"search_queries_only,omitempty"` - Stream bool `json:"stream,omitempty"` - Seed *int `json:"seed,omitempty"` - MaxTokens *int `json:"max_tokens,omitempty"` - K int `json:"k"` - P float32 `json:"p"` - FrequencyPenalty float32 `json:"frequency_penalty"` - PresencePenalty float32 `json:"presence_penalty"` - StopSequences []string `json:"stop_sequences"` + Model string `json:"model"` + Message string `json:"message"` + ChatHistory []schema.ChatMessage `json:"chat_history"` + Temperature float64 `json:"temperature,omitempty"` + Preamble string `json:"preamble,omitempty"` + PromptTruncation *string `json:"prompt_truncation,omitempty"` + Connectors []string `json:"connectors,omitempty"` + SearchQueriesOnly bool 
`json:"search_queries_only,omitempty"` + Stream bool `json:"stream,omitempty"` + Seed *int `json:"seed,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + K int `json:"k"` + P float32 `json:"p"` + FrequencyPenalty float32 `json:"frequency_penalty"` + PresencePenalty float32 `json:"presence_penalty"` + StopSequences []string `json:"stop_sequences"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { message := params.Messages[len(params.Messages)-1] messageHistory := params.Messages[:len(params.Messages)-1] diff --git a/pkg/providers/cohere/stream_reader.go b/pkg/provider/cohere/stream_reader.go similarity index 100% rename from pkg/providers/cohere/stream_reader.go rename to pkg/provider/cohere/stream_reader.go diff --git a/pkg/providers/cohere/testdata/chat.req.json b/pkg/provider/cohere/testdata/chat.req.json similarity index 100% rename from pkg/providers/cohere/testdata/chat.req.json rename to pkg/provider/cohere/testdata/chat.req.json diff --git a/pkg/providers/cohere/testdata/chat.success.json b/pkg/provider/cohere/testdata/chat.success.json similarity index 100% rename from pkg/providers/cohere/testdata/chat.success.json rename to pkg/provider/cohere/testdata/chat.success.json diff --git a/pkg/providers/cohere/testdata/chat_stream.interrupted.txt b/pkg/provider/cohere/testdata/chat_stream.interrupted.txt similarity index 100% rename from pkg/providers/cohere/testdata/chat_stream.interrupted.txt rename to pkg/provider/cohere/testdata/chat_stream.interrupted.txt diff --git a/pkg/providers/cohere/testdata/chat_stream.success.txt b/pkg/provider/cohere/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/cohere/testdata/chat_stream.success.txt rename to pkg/provider/cohere/testdata/chat_stream.success.txt diff --git a/pkg/provider/config.go b/pkg/provider/config.go new file mode 100644 index 00000000..23633e0b --- /dev/null +++ 
b/pkg/provider/config.go @@ -0,0 +1,134 @@ +package provider + +import ( + "errors" + "fmt" + "strings" + + "github.com/go-playground/validator/v10" + + "gopkg.in/yaml.v3" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" +) + +var ErrNoProviderConfigured = errors.New("exactly one provider must be configured, none is configured") + +var validate *validator.Validate + +func init() { + validate = validator.New() +} + +// TODO: Configurer should be more generic, not tied to LangProviders +type Configurer interface { + UnmarshalYAML(unmarshal func(interface{}) error) error + ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) +} + +type Config map[ID]interface{} + +var _ Configurer = (*Config)(nil) + +func (p Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (LangProvider, error) { + for providerID, configValue := range p { + if configValue == nil { + continue + } + + providerConfig, found := LangRegistry.Get(providerID) + + if !found { + return nil, fmt.Errorf( + "provider %s is not supported (available providers: %v)", + providerID, + strings.Join(LangRegistry.Available(), ", "), + ) + } + + providerConfigUnmarshaller := func(providerConfig interface{}) error { + providerConfigBytes, err := yaml.Marshal(configValue) + if err != nil { + return err + } + + return yaml.Unmarshal(providerConfigBytes, providerConfig) + } + + err := providerConfig.UnmarshalYAML(providerConfigUnmarshaller) + if err != nil { + return nil, err + } + + return providerConfig.ToClient(tel, clientConfig) + } + + return nil, ErrProviderNotFound +} + +// validate ensure there is only one provider configured and it's supported by Glide +func (p Config) validate() error { + configuredProviders := make([]ID, 0, len(p)) + + for providerID, config := range p { + if config != nil { + configuredProviders = append(configuredProviders, providerID) + } + } + + if len(configuredProviders) == 0 { + 
return ErrNoProviderConfigured + } + + if len(configuredProviders) > 1 { + return fmt.Errorf( + "exactly one provider must be configured, but %v are configured (%v)", + len(configuredProviders), + strings.Join(configuredProviders, ", "), + ) + } + + providerID := configuredProviders[0] + providerConfig, found := LangRegistry.Get(providerID) + + if !found { + return fmt.Errorf( + "provider %s is not supported (available providers: %v)", + providerID, + strings.Join(LangRegistry.Available(), ", "), + ) + } + + providerConfigUnmarshaller := func(providerConfig interface{}) error { + configValue := p[providerID] + + providerConfigBytes, err := yaml.Marshal(configValue) + if err != nil { + return err + } + + err = yaml.Unmarshal(providerConfigBytes, providerConfig) + if err != nil { + return err + } + + return validate.Struct(providerConfig) + } + + return providerConfig.UnmarshalYAML(providerConfigUnmarshaller) +} + +func (p *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain Config // to avoid recursion + + temp := plain{} + + if err := unmarshal(&temp); err != nil { + return err + } + + *p = Config(temp) + + return p.validate() +} diff --git a/pkg/provider/config_test.go b/pkg/provider/config_test.go new file mode 100644 index 00000000..451f92ea --- /dev/null +++ b/pkg/provider/config_test.go @@ -0,0 +1,29 @@ +package provider + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v3" +) + +func TestDynLangProvider(t *testing.T) { + LangRegistry.Register(ProviderTest, &TestConfig{}) + + type ProviderConfig struct { + Provider *Config `yaml:"provider"` + } + + prConfig := make(Config) + providerConfig := ProviderConfig{ + Provider: &prConfig, + } + + config, err := os.ReadFile(filepath.Clean("./config_test.yaml")) + require.NoError(t, err) + + err = yaml.Unmarshal(config, &providerConfig) + require.NoError(t, err) +} diff --git a/pkg/provider/config_test.yaml 
b/pkg/provider/config_test.yaml new file mode 100644 index 00000000..5bc1fb1c --- /dev/null +++ b/pkg/provider/config_test.yaml @@ -0,0 +1,9 @@ +provider: + testprovider: + base_url: "https://api.example.com" + chat_endpoint: "/chat/completions" + model: "example-model" + api_key: "example_api_key" + default_params: + param1: "value1" + param2: "value2" diff --git a/pkg/provider/interface.go b/pkg/provider/interface.go new file mode 100644 index 00000000..1171f486 --- /dev/null +++ b/pkg/provider/interface.go @@ -0,0 +1,35 @@ +package provider + +import ( + "context" + "errors" + + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" +) + +var ErrProviderNotFound = errors.New("provider not found") + +type ID = string + +// ModelProvider exposes provider context +type ModelProvider interface { + Provider() ID + ModelName() string +} + +// LangProvider defines an interface a provider should fulfill to be able to serve language chat requests +type LangProvider interface { + ModelProvider + SupportChatStream() bool + Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) + ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) +} + +// EmbeddingProvider defines an interface a provider should fulfill to be able to generate embeddings +type EmbeddingProvider interface { + ModelProvider + SupportEmbedding() bool + Embed(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) +} diff --git a/pkg/providers/octoml/chat.go b/pkg/provider/octoml/chat.go similarity index 76% rename from pkg/providers/octoml/chat.go rename to pkg/provider/octoml/chat.go index 92f20fbf..3648cd95 100644 --- a/pkg/providers/octoml/chat.go +++ b/pkg/provider/octoml/chat.go @@ -8,27 +8,27 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/providers/openai" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + 
"github.com/EinStack/glide/pkg/provider/openai" "go.uber.org/zap" ) // ChatRequest is an octoml-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -47,7 +47,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified octoml model. 
-func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -63,7 +63,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -119,21 +119,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: completion.ID, Created: completion.Created, Provider: providerName, ModelName: completion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": completion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: completion.Usage.PromptTokens, ResponseTokens: completion.Usage.CompletionTokens, TotalTokens: completion.Usage.TotalTokens, diff --git a/pkg/provider/octoml/chat_stream.go b/pkg/provider/octoml/chat_stream.go new file mode 100644 index 00000000..22ead76a --- /dev/null +++ b/pkg/provider/octoml/chat_stream.go @@ -0,0 +1,17 @@ +package octoml + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" +) + +func (c 
*Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/octoml/client.go b/pkg/provider/octoml/client.go similarity index 90% rename from pkg/providers/octoml/client.go rename to pkg/provider/octoml/client.go index 07f889bb..30ab7794 100644 --- a/pkg/providers/octoml/client.go +++ b/pkg/provider/octoml/client.go @@ -6,9 +6,11 @@ import ( "net/url" "time" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" ) const ( @@ -31,6 +33,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OctoML client for the OctoML API. func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/providers/octoml/client_test.go b/pkg/provider/octoml/client_test.go similarity index 91% rename from pkg/providers/octoml/client_test.go rename to pkg/provider/octoml/client_test.go index f35de1f7..fcc266c1 100644 --- a/pkg/providers/octoml/client_test.go +++ b/pkg/provider/octoml/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -55,7 +55,7 @@ func TestOctoMLClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: 
[]schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -88,7 +88,7 @@ func TestOctoMLClient_Chat_Error(t *testing.T) { require.NoError(t, err) // Create a chat request - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -120,7 +120,7 @@ func TestDoChatRequest_ErrorResponse(t *testing.T) { require.NoError(t, err) // Create a chat request payload - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the dealeo?", }}} diff --git a/pkg/providers/octoml/config.go b/pkg/provider/octoml/config.go similarity index 100% rename from pkg/providers/octoml/config.go rename to pkg/provider/octoml/config.go diff --git a/pkg/providers/octoml/errors.go b/pkg/provider/octoml/errors.go similarity index 96% rename from pkg/providers/octoml/errors.go rename to pkg/provider/octoml/errors.go index 97f16840..fe1d1198 100644 --- a/pkg/providers/octoml/errors.go +++ b/pkg/provider/octoml/errors.go @@ -6,9 +6,10 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) diff --git a/pkg/providers/octoml/testdata/chat.req.json b/pkg/provider/octoml/testdata/chat.req.json similarity index 100% rename from pkg/providers/octoml/testdata/chat.req.json rename to pkg/provider/octoml/testdata/chat.req.json diff --git a/pkg/providers/octoml/testdata/chat.success.json b/pkg/provider/octoml/testdata/chat.success.json similarity index 100% rename from pkg/providers/octoml/testdata/chat.success.json rename to pkg/provider/octoml/testdata/chat.success.json diff --git a/pkg/providers/ollama/chat.go b/pkg/provider/ollama/chat.go 
similarity index 73% rename from pkg/providers/ollama/chat.go rename to pkg/provider/ollama/chat.go index b93f5b10..463899c8 100644 --- a/pkg/providers/ollama/chat.go +++ b/pkg/provider/ollama/chat.go @@ -9,38 +9,38 @@ import ( "net/http" "time" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/google/uuid" + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/google/uuid" "go.uber.org/zap" ) // ChatRequest is an ollama-specific request schema type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Microstat int `json:"microstat,omitempty"` - MicrostatEta float64 `json:"microstat_eta,omitempty"` - MicrostatTau float64 `json:"microstat_tau,omitempty"` - NumCtx int `json:"num_ctx,omitempty"` - NumGqa int `json:"num_gqa,omitempty"` - NumGpu int `json:"num_gpu,omitempty"` - NumThread int `json:"num_thread,omitempty"` - RepeatLastN int `json:"repeat_last_n,omitempty"` - Temperature float64 `json:"temperature,omitempty"` - Seed int `json:"seed,omitempty"` - StopWords []string `json:"stop,omitempty"` - Tfsz float64 `json:"tfs_z,omitempty"` - NumPredict int `json:"num_predict,omitempty"` - TopK int `json:"top_k,omitempty"` - TopP float64 `json:"top_p,omitempty"` - Stream bool `json:"stream"` + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Microstat int `json:"microstat,omitempty"` + MicrostatEta float64 `json:"microstat_eta,omitempty"` + MicrostatTau float64 `json:"microstat_tau,omitempty"` + NumCtx int `json:"num_ctx,omitempty"` + NumGqa int `json:"num_gqa,omitempty"` + NumGpu int `json:"num_gpu,omitempty"` + NumThread int `json:"num_thread,omitempty"` + RepeatLastN int `json:"repeat_last_n,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Seed int `json:"seed,omitempty"` + StopWords []string `json:"stop,omitempty"` + Tfsz float64 `json:"tfs_z,omitempty"` + 
NumPredict int `json:"num_predict,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream"` } -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { // TODO(185): set other params r.Messages = params.Messages } @@ -68,7 +68,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified ollama model. -func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -84,7 +84,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { //nolint:cyclop +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { //nolint:cyclop // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -164,18 +164,18 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to UnifiedChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: uuid.NewString(), Created: int(time.Now().Unix()), Provider: providerName, ModelName: ollamaCompletion.Model, Cached: false, - ModelResponse: schemas.ModelResponse{ - Message: schemas.ChatMessage{ + ModelResponse: schema.ModelResponse{ + Message: schema.ChatMessage{ Role: ollamaCompletion.Message.Role, Content: ollamaCompletion.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: ollamaCompletion.EvalCount, 
ResponseTokens: ollamaCompletion.EvalCount, TotalTokens: ollamaCompletion.EvalCount, diff --git a/pkg/provider/ollama/chat_stream.go b/pkg/provider/ollama/chat_stream.go new file mode 100644 index 00000000..d43f88c0 --- /dev/null +++ b/pkg/provider/ollama/chat_stream.go @@ -0,0 +1,17 @@ +package ollama + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" +) + +func (c *Client) SupportChatStream() bool { + return false +} + +func (c *Client) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { + return nil, clients.ErrChatStreamNotImplemented +} diff --git a/pkg/providers/ollama/client.go b/pkg/provider/ollama/client.go similarity index 89% rename from pkg/providers/ollama/client.go rename to pkg/provider/ollama/client.go index 5a61898e..df624cd5 100644 --- a/pkg/providers/ollama/client.go +++ b/pkg/provider/ollama/client.go @@ -5,9 +5,11 @@ import ( "net/url" "time" - "github.com/EinStack/glide/pkg/telemetry" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" ) const ( @@ -24,6 +26,11 @@ type Client struct { telemetry *telemetry.Telemetry } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. 
func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) diff --git a/pkg/providers/ollama/client_test.go b/pkg/provider/ollama/client_test.go similarity index 91% rename from pkg/providers/ollama/client_test.go rename to pkg/provider/ollama/client_test.go index e6c584cf..e371c39d 100644 --- a/pkg/providers/ollama/client_test.go +++ b/pkg/provider/ollama/client_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -56,7 +56,7 @@ func TestOllamaClient_ChatRequest(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the biggest animal?", }}} @@ -84,7 +84,7 @@ func TestOllamaClient_ChatRequest_Non200Response(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -121,7 +121,7 @@ func TestOllamaClient_ChatRequest_SuccessfulResponse(t *testing.T) { client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} diff --git a/pkg/providers/ollama/config.go 
b/pkg/provider/ollama/config.go similarity index 100% rename from pkg/providers/ollama/config.go rename to pkg/provider/ollama/config.go diff --git a/pkg/providers/ollama/schemas.go b/pkg/provider/ollama/schemas.go similarity index 100% rename from pkg/providers/ollama/schemas.go rename to pkg/provider/ollama/schemas.go diff --git a/pkg/providers/ollama/testdata/chat.req.json b/pkg/provider/ollama/testdata/chat.req.json similarity index 100% rename from pkg/providers/ollama/testdata/chat.req.json rename to pkg/provider/ollama/testdata/chat.req.json diff --git a/pkg/providers/ollama/testdata/chat.success.json b/pkg/provider/ollama/testdata/chat.success.json similarity index 100% rename from pkg/providers/ollama/testdata/chat.success.json rename to pkg/provider/ollama/testdata/chat.success.json diff --git a/pkg/providers/openai/chat.go b/pkg/provider/openai/chat.go similarity index 90% rename from pkg/providers/openai/chat.go rename to pkg/provider/openai/chat.go index 519d7d43..be2bcbf3 100644 --- a/pkg/providers/openai/chat.go +++ b/pkg/provider/openai/chat.go @@ -8,9 +8,10 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -36,7 +37,7 @@ func NewChatRequestFromConfig(cfg *Config) *ChatRequest { } // Chat sends a chat request to the specified OpenAI model. 
-func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas.ChatResponse, error) { +func (c *Client) Chat(ctx context.Context, params *schema.ChatParams) (*schema.ChatResponse, error) { // Create a new chat request // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template @@ -52,7 +53,7 @@ func (c *Client) Chat(ctx context.Context, params *schemas.ChatParams) (*schemas return chatResponse, nil } -func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schemas.ChatResponse, error) { +func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*schema.ChatResponse, error) { // Build request payload rawPayload, err := json.Marshal(payload) if err != nil { @@ -123,21 +124,21 @@ func (c *Client) doChatRequest(ctx context.Context, payload *ChatRequest) (*sche } // Map response to ChatResponse schema - response := schemas.ChatResponse{ + response := schema.ChatResponse{ ID: chatCompletion.ID, Created: chatCompletion.Created, - Provider: providerName, + Provider: ProviderID, ModelName: chatCompletion.ModelName, Cached: false, - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "system_fingerprint": chatCompletion.SystemFingerprint, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: modelChoice.Message.Role, Content: modelChoice.Message.Content, }, - TokenUsage: schemas.TokenUsage{ + TokenUsage: schema.TokenUsage{ PromptTokens: chatCompletion.Usage.PromptTokens, ResponseTokens: chatCompletion.Usage.CompletionTokens, TotalTokens: chatCompletion.Usage.TotalTokens, diff --git a/pkg/providers/openai/chat_stream.go b/pkg/provider/openai/chat_stream.go similarity index 88% rename from pkg/providers/openai/chat_stream.go rename to pkg/provider/openai/chat_stream.go index 08ca2b21..9d50295b 100644 --- a/pkg/providers/openai/chat_stream.go +++ 
b/pkg/provider/openai/chat_stream.go @@ -8,11 +8,12 @@ import ( "io" "net/http" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/clients" + "github.com/r3labs/sse/v2" "go.uber.org/zap" - - "github.com/EinStack/glide/pkg/api/schemas" ) var StreamDoneMarker = []byte("[DONE]") @@ -28,6 +29,11 @@ type ChatStream struct { logger *zap.Logger } +// ensure interface +var ( + _ clients.ChatStream = (*ChatStream)(nil) +) + func NewChatStream( client *http.Client, req *http.Request, @@ -60,7 +66,7 @@ func (s *ChatStream) Open() error { return nil } -func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { +func (s *ChatStream) Recv() (*schema.ChatStreamChunk, error) { var completionChunk ChatCompletionChunk for { @@ -109,17 +115,17 @@ func (s *ChatStream) Recv() (*schemas.ChatStreamChunk, error) { responseChunk := completionChunk.Choices[0] // TODO: use objectpool here - return &schemas.ChatStreamChunk{ + return &schema.ChatStreamChunk{ Cached: false, - Provider: providerName, + Provider: ProviderID, ModelName: completionChunk.ModelName, - ModelResponse: schemas.ModelChunkResponse{ - Metadata: &schemas.Metadata{ + ModelResponse: schema.ModelChunkResponse{ + Metadata: &schema.Metadata{ "response_id": completionChunk.ID, "system_fingerprint": completionChunk.SystemFingerprint, "generated_at": completionChunk.Created, }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Role: "assistant", // doesn't present in all chunks Content: responseChunk.Delta.Content, }, @@ -141,7 +147,7 @@ func (c *Client) SupportChatStream() bool { return true } -func (c *Client) ChatStream(ctx context.Context, params *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Client) ChatStream(ctx context.Context, params *schema.ChatParams) (clients.ChatStream, error) { // Create a new chat request httpRequest, err := c.makeStreamReq(ctx, params) if err != nil { @@ -157,7 +163,7 @@ func (c *Client) 
ChatStream(ctx context.Context, params *schemas.ChatParams) (cl ), nil } -func (c *Client) makeStreamReq(ctx context.Context, params *schemas.ChatParams) (*http.Request, error) { +func (c *Client) makeStreamReq(ctx context.Context, params *schema.ChatParams) (*http.Request, error) { // TODO: consider using objectpool to optimize memory allocation chatReq := *c.chatRequestTemplate // hoping to get a copy of the template chatReq.ApplyParams(params) diff --git a/pkg/providers/openai/chat_stream_test.go b/pkg/provider/openai/chat_stream_test.go similarity index 89% rename from pkg/providers/openai/chat_stream_test.go rename to pkg/provider/openai/chat_stream_test.go index 1ab8483b..2934df3f 100644 --- a/pkg/providers/openai/chat_stream_test.go +++ b/pkg/provider/openai/chat_stream_test.go @@ -10,18 +10,18 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/telemetry" + clients2 "github.com/EinStack/glide/pkg/clients" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) func TestOpenAIClient_ChatStreamSupported(t *testing.T) { providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) @@ -64,14 +64,14 @@ func TestOpenAIClient_ChatStreamRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the 
capital of the United Kingdom?", }}} @@ -132,14 +132,14 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", Content: "What's the capital of the United Kingdom?", }}} @@ -153,7 +153,7 @@ func TestOpenAIClient_ChatStreamRequestInterrupted(t *testing.T) { for { chunk, err := stream.Recv() if err != nil { - require.ErrorIs(t, err, clients.ErrProviderUnavailable) + require.ErrorIs(t, err, clients2.ErrProviderUnavailable) return } diff --git a/pkg/providers/openai/chat_test.go b/pkg/provider/openai/chat_test.go similarity index 85% rename from pkg/providers/openai/chat_test.go rename to pkg/provider/openai/chat_test.go index 3109f150..65dde4f6 100644 --- a/pkg/providers/openai/chat_test.go +++ b/pkg/provider/openai/chat_test.go @@ -10,9 +10,9 @@ import ( "path/filepath" "testing" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + clients2 "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/telemetry" @@ -49,14 +49,14 @@ func TestOpenAIClient_ChatRequest(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "user", 
Content: "What's the capital of the United Kingdom?", }}} @@ -78,14 +78,14 @@ func TestOpenAIClient_RateLimit(t *testing.T) { ctx := context.Background() providerCfg := DefaultConfig() - clientCfg := clients.DefaultClientConfig() + clientCfg := clients2.DefaultClientConfig() providerCfg.BaseURL = openAIServer.URL client, err := NewClient(providerCfg, clientCfg, telemetry.NewTelemetryMock()) require.NoError(t, err) - chatParams := schemas.ChatParams{Messages: []schemas.ChatMessage{{ + chatParams := schema.ChatParams{Messages: []schema.ChatMessage{{ Role: "human", Content: "What's the biggest animal?", }}} @@ -93,5 +93,5 @@ func TestOpenAIClient_RateLimit(t *testing.T) { _, err = client.Chat(ctx, &chatParams) require.Error(t, err) - require.IsType(t, &clients.RateLimitError{}, err) + require.IsType(t, &clients2.RateLimitError{}, err) } diff --git a/pkg/providers/openai/client.go b/pkg/provider/openai/client.go similarity index 86% rename from pkg/providers/openai/client.go rename to pkg/provider/openai/client.go index 832ade57..8567e26c 100644 --- a/pkg/providers/openai/client.go +++ b/pkg/provider/openai/client.go @@ -5,15 +5,17 @@ import ( "net/url" "time" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" + "go.uber.org/zap" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/clients" ) const ( - providerName = "openai" + ProviderID = "openai" ) // Client is a client for accessing OpenAI API @@ -29,6 +31,11 @@ type Client struct { logger *zap.Logger } +// ensure interfaces +var ( + _ provider.LangProvider = (*Client)(nil) +) + // NewClient creates a new OpenAI client for the OpenAI API. 
func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel *telemetry.Telemetry) (*Client, error) { chatURL, err := url.JoinPath(providerConfig.BaseURL, providerConfig.ChatEndpoint) @@ -37,7 +44,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } logger := tel.L().With( - zap.String("provider", providerName), + zap.String("provider", ProviderID), ) c := &Client{ @@ -62,7 +69,7 @@ func NewClient(providerConfig *Config, clientConfig *clients.ClientConfig, tel * } func (c *Client) Provider() string { - return providerName + return ProviderID } func (c *Client) ModelName() string { diff --git a/pkg/providers/openai/config.go b/pkg/provider/openai/config.go similarity index 87% rename from pkg/providers/openai/config.go rename to pkg/provider/openai/config.go index 8342db41..c6500c1f 100644 --- a/pkg/providers/openai/config.go +++ b/pkg/provider/openai/config.go @@ -1,7 +1,10 @@ package openai import ( + "github.com/EinStack/glide/pkg/clients" "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" ) // Params defines OpenAI-specific model params with the specific validation of values @@ -49,6 +52,11 @@ type Config struct { DefaultParams *Params `yaml:"default_params,omitempty" json:"default_params"` } +// ensure interfaces +var ( + _ provider.Configurer = (*Config)(nil) +) + // DefaultConfig for OpenAI models func DefaultConfig() *Config { defaultParams := DefaultParams() @@ -61,6 +69,10 @@ func DefaultConfig() *Config { } } +func (c *Config) ToClient(tel *telemetry.Telemetry, clientConfig *clients.ClientConfig) (provider.LangProvider, error) { + return NewClient(c, clientConfig, tel) +} + func (c *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { *c = *DefaultConfig() diff --git a/pkg/provider/openai/embed.go b/pkg/provider/openai/embed.go new file mode 100644 index 00000000..ba054adc --- /dev/null +++ 
b/pkg/provider/openai/embed.go @@ -0,0 +1,13 @@ +package openai + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" +) + +// Embed sends an embedding request to the specified OpenAI model. +func (c *Client) Embed(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { + // TODO: implement + return nil, nil +} diff --git a/pkg/providers/openai/errors.go b/pkg/provider/openai/errors.go similarity index 91% rename from pkg/providers/openai/errors.go rename to pkg/provider/openai/errors.go index 14978f8c..d0389cbe 100644 --- a/pkg/providers/openai/errors.go +++ b/pkg/provider/openai/errors.go @@ -6,9 +6,10 @@ import ( "net/http" "time" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/providers/clients" "go.uber.org/zap" ) @@ -27,7 +28,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { if err != nil { m.tel.Logger.Error( "Failed to unmarshal chat response error", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Error(err), zap.ByteString("rawResponse", bodyBytes), ) @@ -37,7 +38,7 @@ func (m *ErrorMapper) Map(resp *http.Response) error { m.tel.Logger.Error( "Chat request failed", - zap.String("provider", providerName), + zap.String("provider", ProviderID), zap.Int("statusCode", resp.StatusCode), zap.String("response", string(bodyBytes)), zap.Any("headers", resp.Header), diff --git a/pkg/providers/openai/finish_reasons.go b/pkg/provider/openai/finish_reasons.go similarity index 71% rename from pkg/providers/openai/finish_reasons.go rename to pkg/provider/openai/finish_reasons.go index 28b5f675..65196946 100644 --- a/pkg/providers/openai/finish_reasons.go +++ b/pkg/provider/openai/finish_reasons.go @@ -1,9 +1,9 @@ package openai import ( + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" - "github.com/EinStack/glide/pkg/api/schemas" "go.uber.org/zap" ) @@ -25,27 +25,27 @@ type 
FinishReasonMapper struct { tel *telemetry.Telemetry } -func (m *FinishReasonMapper) Map(finishReason string) *schemas.FinishReason { +func (m *FinishReasonMapper) Map(finishReason string) *schema.FinishReason { if len(finishReason) == 0 { return nil } - var reason *schemas.FinishReason + var reason *schema.FinishReason switch finishReason { case CompleteReason: - reason = &schemas.ReasonComplete + reason = &schema.ReasonComplete case MaxTokensReason: - reason = &schemas.ReasonMaxTokens + reason = &schema.ReasonMaxTokens case FilteredReason: - reason = &schemas.ReasonContentFiltered + reason = &schema.ReasonContentFiltered default: m.tel.Logger.Warn( "Unknown finish reason, other is going to used", zap.String("unknown_reason", finishReason), ) - reason = &schemas.ReasonOther + reason = &schema.ReasonOther } return reason diff --git a/pkg/provider/openai/register.go b/pkg/provider/openai/register.go new file mode 100644 index 00000000..b79b77b1 --- /dev/null +++ b/pkg/provider/openai/register.go @@ -0,0 +1,7 @@ +package openai + +import "github.com/EinStack/glide/pkg/provider" + +func init() { + provider.LangRegistry.Register(ProviderID, &Config{}) +} diff --git a/pkg/provider/openai/schemas.go b/pkg/provider/openai/schemas.go new file mode 100644 index 00000000..31af6f9c --- /dev/null +++ b/pkg/provider/openai/schemas.go @@ -0,0 +1,71 @@ +package openai + +import "github.com/EinStack/glide/pkg/api/schema" + +// ChatRequest is an OpenAI-specific request schema +type ChatRequest struct { + Model string `json:"model"` + Messages []schema.ChatMessage `json:"messages"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + N int `json:"n,omitempty"` + StopWords []string `json:"stop,omitempty"` + Stream bool `json:"stream,omitempty"` + FrequencyPenalty int `json:"frequency_penalty,omitempty"` + PresencePenalty int `json:"presence_penalty,omitempty"` + LogitBias *map[int]float64 
`json:"logit_bias,omitempty"` + User *string `json:"user,omitempty"` + Seed *int `json:"seed,omitempty"` + Tools []string `json:"tools,omitempty"` + ToolChoice interface{} `json:"tool_choice,omitempty"` + ResponseFormat interface{} `json:"response_format,omitempty"` +} + +func (r *ChatRequest) ApplyParams(params *schema.ChatParams) { + // TODO(185): set other params + r.Messages = params.Messages +} + +// ChatCompletion +// Ref: https://platform.openai.com/docs/api-reference/chat/object +type ChatCompletion struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + ModelName string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message schema.ChatMessage `json:"message"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming +// Ref: https://platform.openai.com/docs/api-reference/chat/streaming +type ChatCompletionChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int `json:"created"` + ModelName string `json:"model"` + SystemFingerprint string `json:"system_fingerprint"` + Choices []StreamChoice `json:"choices"` +} + +type StreamChoice struct { + Index int `json:"index"` + Delta schema.ChatMessage `json:"delta"` + Logprobs interface{} `json:"logprobs"` + FinishReason string `json:"finish_reason"` +} diff --git a/pkg/providers/openai/testdata/chat.req.json b/pkg/provider/openai/testdata/chat.req.json similarity index 100% rename from pkg/providers/openai/testdata/chat.req.json rename to pkg/provider/openai/testdata/chat.req.json diff --git 
a/pkg/providers/openai/testdata/chat.success.json b/pkg/provider/openai/testdata/chat.success.json similarity index 100% rename from pkg/providers/openai/testdata/chat.success.json rename to pkg/provider/openai/testdata/chat.success.json diff --git a/pkg/providers/openai/testdata/chat_stream.empty.txt b/pkg/provider/openai/testdata/chat_stream.empty.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.empty.txt rename to pkg/provider/openai/testdata/chat_stream.empty.txt diff --git a/pkg/providers/openai/testdata/chat_stream.nodone.txt b/pkg/provider/openai/testdata/chat_stream.nodone.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.nodone.txt rename to pkg/provider/openai/testdata/chat_stream.nodone.txt diff --git a/pkg/providers/openai/testdata/chat_stream.success.txt b/pkg/provider/openai/testdata/chat_stream.success.txt similarity index 100% rename from pkg/providers/openai/testdata/chat_stream.success.txt rename to pkg/provider/openai/testdata/chat_stream.success.txt diff --git a/pkg/provider/registry.go b/pkg/provider/registry.go new file mode 100644 index 00000000..8862ba88 --- /dev/null +++ b/pkg/provider/registry.go @@ -0,0 +1,41 @@ +package provider + +import ( + "fmt" +) + +var LangRegistry = NewRegistry() + +type Registry struct { + providers map[ID]Configurer +} + +func NewRegistry() *Registry { + return &Registry{ + providers: make(map[ID]Configurer), + } +} + +func (r *Registry) Register(name ID, config Configurer) { + if _, ok := r.Get(name); ok { + panic(fmt.Sprintf("provider %s is already registered", name)) + } + + r.providers[name] = config +} + +func (r *Registry) Get(name ID) (Configurer, bool) { + config, ok := r.providers[name] + + return config, ok +} + +func (r *Registry) Available() []ID { + available := make([]ID, 0, len(r.providers)) + + for providerID := range r.providers { + available = append(available, providerID) + } + + return available +} diff --git 
a/pkg/providers/testing/lang.go b/pkg/provider/testing.go similarity index 52% rename from pkg/providers/testing/lang.go rename to pkg/provider/testing.go index 0f7f1f4e..f7bc64f0 100644 --- a/pkg/providers/testing/lang.go +++ b/pkg/provider/testing.go @@ -1,38 +1,61 @@ -package testing +package provider import ( "context" "io" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/config/fields" + "github.com/EinStack/glide/pkg/telemetry" ) +const ( + ProviderTest = "testprovider" +) + +type TestConfig struct { + BaseURL string `yaml:"base_url" json:"base_url" validate:"required"` + ChatEndpoint string `yaml:"chat_endpoint" json:"chat_endpoint" validate:"required"` + ModelName string `yaml:"model" json:"model" validate:"required"` + APIKey fields.Secret `yaml:"api_key" json:"-" validate:"required"` +} + +func (c *TestConfig) ToClient(_ *telemetry.Telemetry, _ *clients.ClientConfig) (LangProvider, error) { + return NewMock(nil, []RespMock{}), nil +} + +func (c *TestConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + type plain TestConfig // to avoid recursion + + return unmarshal((*plain)(c)) +} + // RespMock mocks a chat response or a streaming chat chunk type RespMock struct { Msg string Err error } -func (m *RespMock) Resp() *schemas.ChatResponse { - return &schemas.ChatResponse{ +func (m *RespMock) Resp() *schema.ChatResponse { + return &schema.ChatResponse{ ID: "rsp0001", - ModelResponse: schemas.ModelResponse{ + ModelResponse: schema.ModelResponse{ Metadata: map[string]string{ "ID": "0001", }, - Message: schemas.ChatMessage{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, } } -func (m *RespMock) RespChunk() *schemas.ChatStreamChunk { - return &schemas.ChatStreamChunk{ - ModelResponse: schemas.ModelChunkResponse{ - Message: schemas.ChatMessage{ +func (m *RespMock) RespChunk() 
*schema.ChatStreamChunk { + return &schema.ChatStreamChunk{ + ModelResponse: schema.ModelChunkResponse{ + Message: schema.ChatMessage{ Content: m.Msg, }, }, @@ -46,6 +69,11 @@ type RespStreamMock struct { Chunks *[]RespMock } +// ensure interface +var ( + _ clients.ChatStream = (*RespStreamMock)(nil) +) + func NewRespStreamMock(chunk *[]RespMock) RespStreamMock { return RespStreamMock{ idx: 0, @@ -70,7 +98,7 @@ func (m *RespStreamMock) Open() error { return nil } -func (m *RespStreamMock) Recv() (*schemas.ChatStreamChunk, error) { +func (m *RespStreamMock) Recv() (*schema.ChatStreamChunk, error) { if m.Chunks != nil && m.idx >= len(*m.Chunks) { return nil, io.EOF } @@ -91,8 +119,8 @@ func (m *RespStreamMock) Close() error { return nil } -// ProviderMock mocks a model provider -type ProviderMock struct { +// Mock mocks a model provider +type Mock struct { idx int chatResps *[]RespMock chatStreams *[]RespStreamMock @@ -100,8 +128,13 @@ type ProviderMock struct { modelName *string } -func NewProviderMock(modelName *string, responses []RespMock) *ProviderMock { - return &ProviderMock{ +// ensure interfaces +var ( + _ LangProvider = (*Mock)(nil) +) + +func NewMock(modelName *string, responses []RespMock) *Mock { + return &Mock{ idx: 0, chatResps: &responses, supportStreaming: false, @@ -109,8 +142,8 @@ func NewProviderMock(modelName *string, responses []RespMock) *ProviderMock { } } -func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *ProviderMock { - return &ProviderMock{ +func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *Mock { + return &Mock{ idx: 0, modelName: modelName, chatStreams: &chatStreams, @@ -118,11 +151,11 @@ func NewStreamProviderMock(modelName *string, chatStreams []RespStreamMock) *Pro } } -func (c *ProviderMock) SupportChatStream() bool { +func (c *Mock) SupportChatStream() bool { return c.supportStreaming } -func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas.ChatResponse, 
error) { +func (c *Mock) Chat(_ context.Context, _ *schema.ChatParams) (*schema.ChatResponse, error) { if c.chatResps == nil { return nil, clients.ErrProviderUnavailable } @@ -139,7 +172,7 @@ func (c *ProviderMock) Chat(_ context.Context, _ *schemas.ChatParams) (*schemas. return response.Resp(), nil } -func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { +func (c *Mock) ChatStream(_ context.Context, _ *schema.ChatParams) (clients.ChatStream, error) { if c.chatStreams == nil || c.idx >= len(*c.chatStreams) { return nil, clients.ErrProviderUnavailable } @@ -152,11 +185,11 @@ func (c *ProviderMock) ChatStream(_ context.Context, _ *schemas.ChatParams) (cli return &stream, nil } -func (c *ProviderMock) Provider() string { +func (c *Mock) Provider() string { return "provider_mock" } -func (c *ProviderMock) ModelName() string { +func (c *Mock) ModelName() string { if c.modelName == nil { return "model_mock" } diff --git a/pkg/providers/anthropic/chat_stream.go b/pkg/providers/anthropic/chat_stream.go deleted file mode 100644 index 5a6f2112..00000000 --- a/pkg/providers/anthropic/chat_stream.go +++ /dev/null @@ -1,17 +0,0 @@ -package anthropic - -import ( - "context" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/api/schemas" -) - -func (c *Client) SupportChatStream() bool { - return false -} - -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented -} diff --git a/pkg/providers/azureopenai/schemas.go b/pkg/providers/azureopenai/schemas.go deleted file mode 100644 index 5940648c..00000000 --- a/pkg/providers/azureopenai/schemas.go +++ /dev/null @@ -1,68 +0,0 @@ -package azureopenai - -import "github.com/EinStack/glide/pkg/api/schemas" - -// ChatRequest is an Azure openai-specific request schema -type ChatRequest struct { - Messages []schemas.ChatMessage `json:"messages"` - Temperature 
float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string `json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` -} - -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { - r.Messages = params.Messages -} - -// ChatCompletion -// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions -type ChatCompletion struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - ModelName string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` -} - -type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} - -type Usage struct { - PromptTokens float64 `json:"prompt_tokens"` - CompletionTokens float64 `json:"completion_tokens"` - TotalTokens float64 `json:"total_tokens"` -} - -// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming -// Ref: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions -type ChatCompletionChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - ModelName string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []StreamChoice `json:"choices"` -} - -type StreamChoice struct { - Index int 
`json:"index"` - Delta schemas.ChatMessage `json:"delta"` - FinishReason string `json:"finish_reason"` -} diff --git a/pkg/providers/bedrock/chat_stream.go b/pkg/providers/bedrock/chat_stream.go deleted file mode 100644 index bb07da7d..00000000 --- a/pkg/providers/bedrock/chat_stream.go +++ /dev/null @@ -1,17 +0,0 @@ -package bedrock - -import ( - "context" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/api/schemas" -) - -func (c *Client) SupportChatStream() bool { - return false -} - -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented -} diff --git a/pkg/providers/clients/stream.go b/pkg/providers/clients/stream.go deleted file mode 100644 index 913bbddc..00000000 --- a/pkg/providers/clients/stream.go +++ /dev/null @@ -1,31 +0,0 @@ -package clients - -import ( - "github.com/EinStack/glide/pkg/api/schemas" -) - -type ChatStream interface { - Open() error - Recv() (*schemas.ChatStreamChunk, error) - Close() error -} - -type ChatStreamResult struct { - chunk *schemas.ChatStreamChunk - err error -} - -func (r *ChatStreamResult) Chunk() *schemas.ChatStreamChunk { - return r.chunk -} - -func (r *ChatStreamResult) Error() error { - return r.err -} - -func NewChatStreamResult(chunk *schemas.ChatStreamChunk, err error) *ChatStreamResult { - return &ChatStreamResult{ - chunk: chunk, - err: err, - } -} diff --git a/pkg/providers/config.go b/pkg/providers/config.go deleted file mode 100644 index 206be273..00000000 --- a/pkg/providers/config.go +++ /dev/null @@ -1,146 +0,0 @@ -package providers - -import ( - "errors" - "fmt" - - "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/providers/ollama" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/providers/bedrock" - - "github.com/EinStack/glide/pkg/routers/health" - - "github.com/EinStack/glide/pkg/providers/openai" - 
- "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/providers/octoml" - - "github.com/EinStack/glide/pkg/providers/cohere" - - "github.com/EinStack/glide/pkg/providers/azureopenai" - - "github.com/EinStack/glide/pkg/providers/anthropic" -) - -var ErrProviderNotFound = errors.New("provider not found") - -type LangModelConfig struct { - ID string `yaml:"id" json:"id" validate:"required"` // Model instance ID (unique in scope of the router) - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is the model enabled? - ErrorBudget *health.ErrorBudget `yaml:"error_budget" json:"error_budget" swaggertype:"primitive,string"` - Latency *latency.Config `yaml:"latency" json:"latency"` - Weight int `yaml:"weight" json:"weight"` - Client *clients.ClientConfig `yaml:"client" json:"client"` - // Add other providers like - OpenAI *openai.Config `yaml:"openai,omitempty" json:"openai,omitempty"` - AzureOpenAI *azureopenai.Config `yaml:"azureopenai,omitempty" json:"azureopenai,omitempty"` - Cohere *cohere.Config `yaml:"cohere,omitempty" json:"cohere,omitempty"` - OctoML *octoml.Config `yaml:"octoml,omitempty" json:"octoml,omitempty"` - Anthropic *anthropic.Config `yaml:"anthropic,omitempty" json:"anthropic,omitempty"` - Bedrock *bedrock.Config `yaml:"bedrock,omitempty" json:"bedrock,omitempty"` - Ollama *ollama.Config `yaml:"ollama,omitempty" json:"ollama,omitempty"` -} - -func DefaultLangModelConfig() *LangModelConfig { - return &LangModelConfig{ - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - Weight: 1, - } -} - -func (c *LangModelConfig) ToModel(tel *telemetry.Telemetry) (*LanguageModel, error) { - client, err := c.initClient(tel) - if err != nil { - return nil, fmt.Errorf("error initializing client: %v", err) - } - - return NewLangModel(c.ID, client, c.ErrorBudget, *c.Latency, c.Weight), nil -} - -// initClient initializes the language model 
client based on the provided configuration. -// It takes a telemetry object as input and returns a LangModelProvider and an error. -func (c *LangModelConfig) initClient(tel *telemetry.Telemetry) (LangProvider, error) { - switch { - case c.OpenAI != nil: - return openai.NewClient(c.OpenAI, c.Client, tel) - case c.AzureOpenAI != nil: - return azureopenai.NewClient(c.AzureOpenAI, c.Client, tel) - case c.Cohere != nil: - return cohere.NewClient(c.Cohere, c.Client, tel) - case c.OctoML != nil: - return octoml.NewClient(c.OctoML, c.Client, tel) - case c.Anthropic != nil: - return anthropic.NewClient(c.Anthropic, c.Client, tel) - case c.Bedrock != nil: - return bedrock.NewClient(c.Bedrock, c.Client, tel) - default: - return nil, ErrProviderNotFound - } -} - -func (c *LangModelConfig) validateOneProvider() error { - providersConfigured := 0 - - if c.OpenAI != nil { - providersConfigured++ - } - - if c.AzureOpenAI != nil { - providersConfigured++ - } - - if c.Cohere != nil { - providersConfigured++ - } - - if c.OctoML != nil { - providersConfigured++ - } - - if c.Anthropic != nil { - providersConfigured++ - } - - if c.Bedrock != nil { - providersConfigured++ - } - - if c.Ollama != nil { - providersConfigured++ - } - - // check other providers here - if providersConfigured == 0 { - return fmt.Errorf("exactly one provider must be configured for model \"%v\", none is configured", c.ID) - } - - if providersConfigured > 1 { - return fmt.Errorf( - "exactly one provider must be configured for model \"%v\", %v are configured", - c.ID, - providersConfigured, - ) - } - - return nil -} - -func (c *LangModelConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = *DefaultLangModelConfig() - - type plain LangModelConfig // to avoid recursion - - if err := unmarshal((*plain)(c)); err != nil { - return err - } - - return c.validateOneProvider() -} diff --git a/pkg/providers/octoml/chat_stream.go b/pkg/providers/octoml/chat_stream.go deleted file mode 100644 index 
d0e33420..00000000 --- a/pkg/providers/octoml/chat_stream.go +++ /dev/null @@ -1,17 +0,0 @@ -package octoml - -import ( - "context" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/api/schemas" -) - -func (c *Client) SupportChatStream() bool { - return false -} - -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented -} diff --git a/pkg/providers/ollama/chat_stream.go b/pkg/providers/ollama/chat_stream.go deleted file mode 100644 index 31075ca1..00000000 --- a/pkg/providers/ollama/chat_stream.go +++ /dev/null @@ -1,17 +0,0 @@ -package ollama - -import ( - "context" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/api/schemas" -) - -func (c *Client) SupportChatStream() bool { - return false -} - -func (c *Client) ChatStream(_ context.Context, _ *schemas.ChatParams) (clients.ChatStream, error) { - return nil, clients.ErrChatStreamNotImplemented -} diff --git a/pkg/providers/openai/schemas.go b/pkg/providers/openai/schemas.go deleted file mode 100644 index bde0ba81..00000000 --- a/pkg/providers/openai/schemas.go +++ /dev/null @@ -1,71 +0,0 @@ -package openai - -import "github.com/EinStack/glide/pkg/api/schemas" - -// ChatRequest is an OpenAI-specific request schema -type ChatRequest struct { - Model string `json:"model"` - Messages []schemas.ChatMessage `json:"messages"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"top_p,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` - N int `json:"n,omitempty"` - StopWords []string `json:"stop,omitempty"` - Stream bool `json:"stream,omitempty"` - FrequencyPenalty int `json:"frequency_penalty,omitempty"` - PresencePenalty int `json:"presence_penalty,omitempty"` - LogitBias *map[int]float64 `json:"logit_bias,omitempty"` - User *string `json:"user,omitempty"` - Seed *int `json:"seed,omitempty"` - Tools []string 
`json:"tools,omitempty"` - ToolChoice interface{} `json:"tool_choice,omitempty"` - ResponseFormat interface{} `json:"response_format,omitempty"` -} - -func (r *ChatRequest) ApplyParams(params *schemas.ChatParams) { - // TODO(185): set other params - r.Messages = params.Messages -} - -// ChatCompletion -// Ref: https://platform.openai.com/docs/api-reference/chat/object -type ChatCompletion struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - ModelName string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []Choice `json:"choices"` - Usage Usage `json:"usage"` -} - -type Choice struct { - Index int `json:"index"` - Message schemas.ChatMessage `json:"message"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} - -type Usage struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - -// ChatCompletionChunk represents SSEvent a chat response is broken down on chat streaming -// Ref: https://platform.openai.com/docs/api-reference/chat/streaming -type ChatCompletionChunk struct { - ID string `json:"id"` - Object string `json:"object"` - Created int `json:"created"` - ModelName string `json:"model"` - SystemFingerprint string `json:"system_fingerprint"` - Choices []StreamChoice `json:"choices"` -} - -type StreamChoice struct { - Index int `json:"index"` - Delta schemas.ChatMessage `json:"delta"` - Logprobs interface{} `json:"logprobs"` - FinishReason string `json:"finish_reason"` -} diff --git a/pkg/providers/provider.go b/pkg/providers/provider.go deleted file mode 100644 index 91aded44..00000000 --- a/pkg/providers/provider.go +++ /dev/null @@ -1,19 +0,0 @@ -package providers - -import ( - "github.com/EinStack/glide/pkg/config/fields" -) - -// ModelProvider exposes provider context -type ModelProvider interface { - Provider() string - ModelName() string -} - -// Model 
represent a configured external modality-agnostic model with its routing properties and status -type Model interface { - ID() string - Healthy() bool - LatencyUpdateInterval() *fields.Duration - Weight() int -} diff --git a/pkg/routers/health/buckets.go b/pkg/resiliency/health/buckets.go similarity index 100% rename from pkg/routers/health/buckets.go rename to pkg/resiliency/health/buckets.go diff --git a/pkg/routers/health/buckets_test.go b/pkg/resiliency/health/buckets_test.go similarity index 100% rename from pkg/routers/health/buckets_test.go rename to pkg/resiliency/health/buckets_test.go diff --git a/pkg/routers/health/error_budget.go b/pkg/resiliency/health/error_budget.go similarity index 100% rename from pkg/routers/health/error_budget.go rename to pkg/resiliency/health/error_budget.go diff --git a/pkg/routers/health/error_budget_test.go b/pkg/resiliency/health/error_budget_test.go similarity index 100% rename from pkg/routers/health/error_budget_test.go rename to pkg/resiliency/health/error_budget_test.go diff --git a/pkg/routers/health/ratelimit.go b/pkg/resiliency/health/ratelimit.go similarity index 100% rename from pkg/routers/health/ratelimit.go rename to pkg/resiliency/health/ratelimit.go diff --git a/pkg/routers/health/ratelimit_test.go b/pkg/resiliency/health/ratelimit_test.go similarity index 100% rename from pkg/routers/health/ratelimit_test.go rename to pkg/resiliency/health/ratelimit_test.go diff --git a/pkg/routers/health/tracker.go b/pkg/resiliency/health/tracker.go similarity index 94% rename from pkg/routers/health/tracker.go rename to pkg/resiliency/health/tracker.go index 8cba6e65..3d4a313b 100644 --- a/pkg/routers/health/tracker.go +++ b/pkg/resiliency/health/tracker.go @@ -3,7 +3,7 @@ package health import ( "errors" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/clients" ) // Tracker tracks errors and general health of model provider diff --git a/pkg/routers/health/tracker_test.go 
b/pkg/resiliency/health/tracker_test.go similarity index 93% rename from pkg/routers/health/tracker_test.go rename to pkg/resiliency/health/tracker_test.go index 8927a041..279bd378 100644 --- a/pkg/routers/health/tracker_test.go +++ b/pkg/resiliency/health/tracker_test.go @@ -4,7 +4,8 @@ import ( "testing" "time" - "github.com/EinStack/glide/pkg/providers/clients" + "github.com/EinStack/glide/pkg/clients" + "github.com/stretchr/testify/require" ) diff --git a/pkg/routers/retry/config.go b/pkg/resiliency/retry/config.go similarity index 100% rename from pkg/routers/retry/config.go rename to pkg/resiliency/retry/config.go diff --git a/pkg/routers/retry/config_test.go b/pkg/resiliency/retry/config_test.go similarity index 100% rename from pkg/routers/retry/config_test.go rename to pkg/resiliency/retry/config_test.go diff --git a/pkg/routers/retry/exp.go b/pkg/resiliency/retry/exp.go similarity index 100% rename from pkg/routers/retry/exp.go rename to pkg/resiliency/retry/exp.go diff --git a/pkg/routers/retry/exp_test.go b/pkg/resiliency/retry/exp_test.go similarity index 100% rename from pkg/routers/retry/exp_test.go rename to pkg/resiliency/retry/exp_test.go diff --git a/pkg/router/config.go b/pkg/router/config.go new file mode 100644 index 00000000..0641a369 --- /dev/null +++ b/pkg/router/config.go @@ -0,0 +1,30 @@ +package router + +import ( + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/router/routing" +) + +// TODO: how to specify other backoff strategies? +// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 + +type Config struct { + ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID + Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? 
+ Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router + RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request +} + +func DefaultConfig() Config { + return Config{ + Enabled: true, + RoutingStrategy: routing.Priority, + Retry: retry.DefaultExpRetryConfig(), + } +} + +// RoutersConfig defines a config for a set of supported router types +type RoutersConfig struct { + LanguageRouters LangRoutersConfig `yaml:"language" validate:"required,dive"` // the list of language routers + EmbeddingRouters EmbedRoutersConfig `yaml:"embedding" validate:"required,dive"` +} diff --git a/pkg/router/embed_config.go b/pkg/router/embed_config.go new file mode 100644 index 00000000..7eb442a4 --- /dev/null +++ b/pkg/router/embed_config.go @@ -0,0 +1,59 @@ +package router + +import ( + "fmt" + + "github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/provider" + "github.com/EinStack/glide/pkg/telemetry" + "go.uber.org/multierr" + "go.uber.org/zap" +) + +type ( + EmbedModelConfig = extmodel.Config[*provider.Config] + EmbedModelPoolConfig = []EmbedModelConfig +) + +type EmbedRouterConfig struct { + Config + Models EmbedModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} + +type EmbedRoutersConfig []EmbedRouterConfig + +func (c EmbedRoutersConfig) Build(tel *telemetry.Telemetry) ([]*EmbedRouter, error) { + seenIDs := make(map[string]bool, len(c)) + routers := make([]*EmbedRouter, 0, len(c)) + + var errs error + + for idx, routerConfig := range c { + if _, ok := seenIDs[routerConfig.ID]; ok { + return nil, fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) + } + + seenIDs[routerConfig.ID] = true + + if !routerConfig.Enabled { + 
 tel.L().Info(fmt.Sprintf("Embed router \"%v\" is disabled, skipping", routerConfig.ID)) + continue + } + + tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) + + r, err := NewEmbedRouter(&c[idx], tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + routers = append(routers, r) + } + + if errs != nil { + return nil, errs + } + + return routers, nil +} diff --git a/pkg/router/embed_router.go b/pkg/router/embed_router.go new file mode 100644 index 00000000..b2ad8a59 --- /dev/null +++ b/pkg/router/embed_router.go @@ -0,0 +1,27 @@ +package router + +import ( + "context" + + "github.com/EinStack/glide/pkg/api/schema" + + "github.com/EinStack/glide/pkg/telemetry" +) + +type EmbedRouter struct { + // routerID lang.RouterID + // Config *LangRouterConfig + // retry *retry.ExpRetry + // tel *telemetry.Telemetry + // logger *zap.Logger +} + +func NewEmbedRouter(_ *EmbedRouterConfig, _ *telemetry.Telemetry) (*EmbedRouter, error) { + // TODO: implement + return &EmbedRouter{}, nil +} + +func (r *EmbedRouter) Embed(ctx context.Context, req *schema.EmbedRequest) (*schema.EmbedResponse, error) { + // TODO: implement + return nil, nil +} diff --git a/pkg/routers/config.go b/pkg/router/lang_config.go similarity index 63% rename from pkg/routers/config.go rename to pkg/router/lang_config.go index c7651f6d..007a3769 100644 --- a/pkg/routers/config.go +++ b/pkg/router/lang_config.go @@ -1,78 +1,60 @@ -package routers +package router import ( "fmt" "time" - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/provider" - "github.com/EinStack/glide/pkg/routers/retry" + "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/router/routing" + "github.com/EinStack/glide/pkg/telemetry" "go.uber.org/multierr" "go.uber.org/zap" ) -type Config struct { - LanguageRouters 
[]LangRouterConfig `yaml:"language" validate:"required,gte=1,dive"` // the list of language routers -} - -func (c *Config) BuildLangRouters(tel *telemetry.Telemetry) ([]*LangRouter, error) { - seenIDs := make(map[string]bool, len(c.LanguageRouters)) - routers := make([]*LangRouter, 0, len(c.LanguageRouters)) - - var errs error - - for idx, routerConfig := range c.LanguageRouters { - if _, ok := seenIDs[routerConfig.ID]; ok { - return nil, fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) - } - - seenIDs[routerConfig.ID] = true - - if !routerConfig.Enabled { - tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID)) - continue - } +type ( + LangModelConfig = extmodel.Config[*provider.Config] + LangModelPoolConfig = []LangModelConfig +) - tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) +// LangRouterConfig +type LangRouterConfig struct { + Config + Models LangModelPoolConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests +} - router, err := NewLangRouter(&c.LanguageRouters[idx], tel) - if err != nil { - errs = multierr.Append(errs, err) - continue - } +type ConfigOption = func(*LangRouterConfig) - routers = append(routers, router) +func WithModels(models LangModelPoolConfig) ConfigOption { + return func(c *LangRouterConfig) { + c.Models = models } +} - if errs != nil { - return nil, errs +func NewRouterConfig(RouterID string, opt ...ConfigOption) *LangRouterConfig { + cfg := &LangRouterConfig{ + Config: DefaultConfig(), } - return routers, nil -} + cfg.ID = RouterID -// TODO: how to specify other backoff strategies? 
-// TODO: Had to keep RoutingStrategy because of https://github.com/swaggo/swag/issues/1738 -// LangRouterConfig -type LangRouterConfig struct { - ID string `yaml:"id" json:"routers" validate:"required"` // Unique router ID - Enabled bool `yaml:"enabled" json:"enabled" validate:"required"` // Is router enabled? - Retry *retry.ExpRetryConfig `yaml:"retry" json:"retry" validate:"required"` // retry when no healthy model is available to router - RoutingStrategy routing.Strategy `yaml:"strategy" json:"strategy" swaggertype:"primitive,string" validate:"required"` // strategy on picking the next model to serve the request - Models []providers.LangModelConfig `yaml:"models" json:"models" validate:"required,min=1,dive"` // the list of models that could handle requests + for _, o := range opt { + o(cfg) + } + + return cfg } // BuildModels creates LanguageModel slice out of the given config -func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*providers.LanguageModel, []*providers.LanguageModel, error) { //nolint: cyclop +func (c *LangRouterConfig) BuildModels(tel *telemetry.Telemetry) ([]*extmodel.LanguageModel, []*extmodel.LanguageModel, error) { //nolint: cyclop var errs error seenIDs := make(map[string]bool, len(c.Models)) - chatModels := make([]*providers.LanguageModel, 0, len(c.Models)) - chatStreamModels := make([]*providers.LanguageModel, 0, len(c.Models)) + chatModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) + chatStreamModels := make([]*extmodel.LanguageModel, 0, len(c.Models)) for _, modelConfig := range c.Models { if _, ok := seenIDs[modelConfig.ID]; ok { @@ -176,11 +158,11 @@ func (c *LangRouterConfig) BuildRetry() *retry.ExpRetry { } func (c *LangRouterConfig) BuildRouting( - chatModels []*providers.LanguageModel, - chatStreamModels []*providers.LanguageModel, + chatModels []*extmodel.LanguageModel, + chatStreamModels []*extmodel.LanguageModel, ) (routing.LangModelRouting, routing.LangModelRouting, error) { - chatModelPool := 
make([]providers.Model, 0, len(chatModels)) - chatStreamModelPool := make([]providers.Model, 0, len(chatStreamModels)) + chatModelPool := make([]extmodel.Interface, 0, len(chatModels)) + chatStreamModelPool := make([]extmodel.Interface, 0, len(chatStreamModels)) for _, model := range chatModels { chatModelPool = append(chatModelPool, model) @@ -198,26 +180,62 @@ func (c *LangRouterConfig) BuildRouting( case routing.WeightedRoundRobin: return routing.NewWeightedRoundRobin(chatModelPool), routing.NewWeightedRoundRobin(chatStreamModelPool), nil case routing.LeastLatency: - return routing.NewLeastLatencyRouting(providers.ChatLatency, chatModelPool), - routing.NewLeastLatencyRouting(providers.ChatStreamLatency, chatStreamModelPool), + return routing.NewLeastLatencyRouting(extmodel.ChatLatency, chatModelPool), + routing.NewLeastLatencyRouting(extmodel.ChatStreamLatency, chatStreamModelPool), nil } return nil, nil, fmt.Errorf("routing strategy \"%v\" is not supported, please make sure there is no typo", c.RoutingStrategy) } -func DefaultLangRouterConfig() LangRouterConfig { - return LangRouterConfig{ - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), +func DefaultRouterConfig() *LangRouterConfig { + return &LangRouterConfig{ + Config: DefaultConfig(), } } -func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - *c = DefaultLangRouterConfig() +func (c *LangRouterConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + *c = *DefaultRouterConfig() type plain LangRouterConfig // to avoid recursion - return unmarshal((*plain)(c)) + return unmarshal((*plain)(c)) +} + +type LangRoutersConfig []LangRouterConfig + +func (c LangRoutersConfig) Build(tel *telemetry.Telemetry) ([]*LangRouter, error) { + seenIDs := make(map[string]bool, len(c)) + langRouters := make([]*LangRouter, 0, len(c)) + + var errs error + + for idx, routerConfig := range c { + if _, ok := seenIDs[routerConfig.ID]; ok { + return nil, 
fmt.Errorf("ID \"%v\" is specified for more than one router while each ID should be unique", routerConfig.ID) + } + + seenIDs[routerConfig.ID] = true + + if !routerConfig.Enabled { + tel.L().Info(fmt.Sprintf("Router \"%v\" is disabled, skipping", routerConfig.ID)) + continue + } + + tel.L().Debug("Init router", zap.String("routerID", routerConfig.ID)) + + router, err := NewLangRouter(&c[idx], tel) + if err != nil { + errs = multierr.Append(errs, err) + continue + } + + langRouters = append(langRouters, router) + } + + if errs != nil { + return nil, errs + } + + return langRouters, nil } diff --git a/pkg/router/lang_config_test.go b/pkg/router/lang_config_test.go new file mode 100644 index 00000000..71d8d829 --- /dev/null +++ b/pkg/router/lang_config_test.go @@ -0,0 +1,214 @@ +package router + +import ( + "testing" + + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/provider/cohere" + "github.com/EinStack/glide/pkg/provider/openai" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/router/latency" + "github.com/EinStack/glide/pkg/router/routing" + "github.com/EinStack/glide/pkg/telemetry" + "github.com/stretchr/testify/require" +) + +func TestRouterConfig_BuildModels(t *testing.T) { + defaultParams := openai.DefaultParams() + + cfg := LangRoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }), + ), + *NewRouterConfig( + "second_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + 
Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }), + ), + } + + routers, err := cfg.Build(telemetry.NewTelemetryMock()) + + require.NoError(t, err) + require.Len(t, routers, 2) + require.Len(t, routers[0].chatModels, 1) + require.IsType(t, &routing.PriorityRouting{}, routers[0].chatRouting) + require.Len(t, routers[1].chatModels, 1) + require.IsType(t, &routing.PriorityRouting{}, routers[1].chatRouting) +} + +func TestRouterConfig_BuildModelsPerType(t *testing.T) { + tel := telemetry.NewTelemetryMock() + openAIParams := openai.DefaultParams() + cohereParams := cohere.DefaultParams() + + cfg := NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &openAIParams, + }, + }, + }, + { + ID: "second_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + cohere.ProviderID: &cohere.Config{ + APIKey: "ABC", + DefaultParams: &cohereParams, + }, + }, + }, + }), + ) + + chatModels, streamChatModels, err := cfg.BuildModels(tel) + + require.Len(t, chatModels, 2) + require.Len(t, streamChatModels, 2) + require.NoError(t, err) +} + +func TestRouterConfig_InvalidSetups(t *testing.T) { + defaultParams := openai.DefaultParams() + + tests := []struct { + name string + config LangRoutersConfig + }{ + { + "duplicated router IDs", + LangRoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: 
&provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }), + ), + *NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }), + ), + }, + }, + { + "duplicated model IDs", + LangRoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{ + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + { + ID: "first_model", + Enabled: true, + Client: clients.DefaultClientConfig(), + ErrorBudget: health.DefaultErrorBudget(), + Latency: latency.DefaultConfig(), + Provider: &provider.Config{ + openai.ProviderID: &openai.Config{ + APIKey: "ABC", + DefaultParams: &defaultParams, + }, + }, + }, + }), + ), + }, + }, + { + "no models", + LangRoutersConfig{ + *NewRouterConfig( + "first_router", + WithModels(LangModelPoolConfig{}), + ), + }, + }, + } + + for _, test := range tests { + _, err := test.config.Build(telemetry.NewTelemetryMock()) + + require.Error(t, err) + } +} diff --git a/pkg/routers/router.go b/pkg/router/lang_router.go similarity index 81% rename from pkg/routers/router.go rename to pkg/router/lang_router.go index 4a7d0d0f..ec2d9113 100644 --- a/pkg/routers/router.go +++ b/pkg/router/lang_router.go @@ -1,30 +1,28 @@ -package routers +package router import ( "context" "errors" - "github.com/EinStack/glide/pkg/routers/retry" - "go.uber.org/zap" + "github.com/EinStack/glide/pkg/api/schema" - "github.com/EinStack/glide/pkg/providers" + 
"github.com/EinStack/glide/pkg/extmodel" + "github.com/EinStack/glide/pkg/resiliency/retry" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" - - "github.com/EinStack/glide/pkg/api/schemas" + "go.uber.org/zap" ) var ErrNoModels = errors.New("no models configured for router") -type RouterID = string +type ID = string type LangRouter struct { - routerID RouterID + routerID ID Config *LangRouterConfig - chatModels []*providers.LanguageModel - chatStreamModels []*providers.LanguageModel + chatModels []*extmodel.LanguageModel + chatStreamModels []*extmodel.LanguageModel chatRouting routing.LangModelRouting chatStreamRouting routing.LangModelRouting retry *retry.ExpRetry @@ -58,11 +56,11 @@ func NewLangRouter(cfg *LangRouterConfig, tel *telemetry.Telemetry) (*LangRouter return router, err } -func (r *LangRouter) ID() RouterID { +func (r *LangRouter) ID() ID { return r.routerID } -func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schemas.ChatResponse, error) { +func (r *LangRouter) Chat(ctx context.Context, req *schema.ChatRequest) (*schema.ChatResponse, error) { if len(r.chatModels) == 0 { return nil, ErrNoModels } @@ -80,7 +78,7 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem break } - langModel := model.(providers.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) @@ -115,22 +113,22 @@ func (r *LangRouter) Chat(ctx context.Context, req *schemas.ChatRequest) (*schem // if we reach this part, then we are in trouble r.logger.Error("No model was available to handle chat request") - return nil, &schemas.ErrNoModelAvailable + return nil, &schema.ErrNoModelAvailable } func (r *LangRouter) ChatStream( ctx context.Context, - req *schemas.ChatStreamRequest, - respC chan<- *schemas.ChatStreamMessage, + req *schema.ChatStreamRequest, + respC chan<- 
*schema.ChatStreamMessage, ) { if len(r.chatStreamModels) == 0 { - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.NoModelConfigured, + schema.NoModelConfigured, ErrNoModels.Error(), req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) return @@ -150,7 +148,7 @@ func (r *LangRouter) ChatStream( break } - langModel := model.(providers.LangModel) + langModel := model.(extmodel.LangModel) chatParams := req.Params(langModel.ID(), langModel.ModelName()) modelRespC, err := langModel.ChatStream(ctx, chatParams) @@ -178,10 +176,10 @@ func (r *LangRouter) ChatStream( // It's challenging to hide an error in case of streaming chat as consumer apps // may have already used all chunks we streamed this far (e.g. showed them to their users like OpenAI UI does), // so we cannot easily restart that process from scratch - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ModelUnavailable, + schema.ModelUnavailable, err.Error(), req.Metadata, nil, @@ -192,7 +190,7 @@ func (r *LangRouter) ChatStream( chunk := chunkResult.Chunk() - respC <- schemas.NewChatStreamChunk( + respC <- schema.NewChatStreamChunk( req.ID, r.routerID, req.Metadata, @@ -210,10 +208,10 @@ func (r *LangRouter) ChatStream( err := retryIterator.WaitNext(ctx) if err != nil { // something has cancelled the context - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.UnknownError, + schema.UnknownError, err.Error(), req.Metadata, nil, @@ -229,12 +227,12 @@ func (r *LangRouter) ChatStream( "Try to configure more fallback models to avoid this", ) - respC <- schemas.NewChatStreamError( + respC <- schema.NewChatStreamError( req.ID, r.routerID, - schemas.ErrNoModelAvailable.Name, - schemas.ErrNoModelAvailable.Message, + schema.ErrNoModelAvailable.Name, + schema.ErrNoModelAvailable.Message, req.Metadata, - &schemas.ReasonError, + &schema.ReasonError, ) 
} diff --git a/pkg/routers/router_test.go b/pkg/router/lang_router_test.go similarity index 56% rename from pkg/routers/router_test.go rename to pkg/router/lang_router_test.go index f56216e3..68c31cab 100644 --- a/pkg/routers/router_test.go +++ b/pkg/router/lang_router_test.go @@ -1,18 +1,20 @@ -package routers +package router import ( "context" "testing" "time" - "github.com/EinStack/glide/pkg/api/schemas" - "github.com/EinStack/glide/pkg/providers" - "github.com/EinStack/glide/pkg/providers/clients" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - "github.com/EinStack/glide/pkg/routers/health" - "github.com/EinStack/glide/pkg/routers/latency" - "github.com/EinStack/glide/pkg/routers/retry" - "github.com/EinStack/glide/pkg/routers/routing" + "github.com/EinStack/glide/pkg/provider" + + "github.com/EinStack/glide/pkg/extmodel" + + "github.com/EinStack/glide/pkg/clients" + "github.com/EinStack/glide/pkg/resiliency/health" + "github.com/EinStack/glide/pkg/resiliency/retry" + + "github.com/EinStack/glide/pkg/router/latency" + "github.com/EinStack/glide/pkg/router/routing" "github.com/EinStack/glide/pkg/telemetry" "github.com/stretchr/testify/require" ) @@ -21,40 +23,39 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}}), budget, *latConfig, 1, ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for 
_, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for i := 0; i < 2; i++ { resp, err := router.Chat(ctx, req) @@ -68,43 +69,42 @@ func TestLangRouter_Chat_PickFistHealthy(t *testing.T) { func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "4"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "4"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "third", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } expectedModels := []string{"third", "third"} router 
:= LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), - chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), @@ -112,7 +112,7 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { } ctx := context.Background() - req := schemas.NewChatFromStr("tell me a dad joke") + req := schema.NewChatFromStr("tell me a dad joke") for _, modelID := range expectedModels { resp, err := router.Chat(ctx, req) @@ -126,41 +126,40 @@ func TestLangRouter_Chat_PickThirdHealthy(t *testing.T) { func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { budget := health.NewErrorBudget(1, health.MILLI) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "2"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Msg: "1"}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Msg: "1"}}), budget, *latConfig, 1, ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), - 
chatStreamRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), + chatStreamRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "first", resp.ModelID) @@ -170,42 +169,41 @@ func TestLangRouter_Chat_SuccessOnRetry(t *testing.T) { func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { budget := health.NewErrorBudget(1, health.MIN) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), + provider.NewMock(nil, []provider.RespMock{{Err: clients.ErrProviderUnavailable}, {Msg: "3"}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Msg: "1"}, {Msg: "2"}}), + provider.NewMock(nil, []provider.RespMock{{Msg: "1"}, {Msg: "2"}}), budget, *latConfig, 1, ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: 
telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } for i := 0; i < 2; i++ { - resp, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + resp, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.NoError(t, err) require.Equal(t, "second", resp.ModelID) @@ -216,41 +214,40 @@ func TestLangRouter_Chat_UnhealthyModelInThePool(t *testing.T) { func TestLangRouter_Chat_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewProviderMock(nil, []ptesting.RespMock{{Err: &schemas.ErrNoModelAvailable}, {Err: &schemas.ErrNoModelAvailable}}), + provider.NewMock(nil, []provider.RespMock{{Err: &schema.ErrNoModelAvailable}, {Err: &schema.ErrNoModelAvailable}}), budget, *latConfig, 1, ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: 
telemetry.NewLoggerMock(), } - _, err := router.Chat(context.Background(), schemas.NewChatFromStr("tell me a dad joke")) + _, err := router.Chat(context.Background(), schema.NewChatFromStr("tell me a dad joke")) require.Error(t, err) } @@ -259,11 +256,11 @@ func TestLangRouter_ChatStream(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Msg: "Bill"}, {Msg: "Gates"}, {Msg: "entered"}, @@ -275,10 +272,10 @@ func TestLangRouter_ChatStream(t *testing.T) { *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Msg: "Knock"}, {Msg: "Knock"}, {Msg: "joke"}, @@ -290,26 +287,25 @@ func TestLangRouter_ChatStream(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), 
} ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -335,19 +331,19 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { budget := health.NewErrorBudget(3, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, nil), + provider.NewStreamProviderMock(nil, nil), budget, *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock( - &[]ptesting.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock( + &[]provider.RespMock{ {Msg: "Knock"}, {Msg: "knock"}, {Msg: "joke"}, @@ -360,26 +356,25 @@ func TestLangRouter_ChatStream_FailOnFirst(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_stream_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(3, 2, 1*time.Second, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), chatStreamModels: langModels, tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } ctx := context.Background() - req := schemas.NewChatStreamFromStr("tell me a dad joke") - respC := make(chan *schemas.ChatStreamMessage) + req := schema.NewChatStreamFromStr("tell me a dad 
joke") + respC := make(chan *schema.ChatStreamMessage) defer close(respC) @@ -405,11 +400,11 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { budget := health.NewErrorBudget(1, health.SEC) latConfig := latency.DefaultConfig() - langModels := []*providers.LanguageModel{ - providers.NewLangModel( + langModels := []*extmodel.LanguageModel{ + extmodel.NewLangModel( "first", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -417,10 +412,10 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { *latConfig, 1, ), - providers.NewLangModel( + extmodel.NewLangModel( "second", - ptesting.NewStreamProviderMock(nil, []ptesting.RespStreamMock{ - ptesting.NewRespStreamMock(&[]ptesting.RespMock{ + provider.NewStreamProviderMock(nil, []provider.RespStreamMock{ + provider.NewRespStreamMock(&[]provider.RespMock{ {Err: clients.ErrProviderUnavailable}, }), }), @@ -430,27 +425,26 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { ), } - models := make([]providers.Model, 0, len(langModels)) + modelPool := make([]extmodel.Interface, 0, len(langModels)) for _, model := range langModels { - models = append(models, model) + modelPool = append(modelPool, model) } router := LangRouter{ routerID: "test_router", - Config: &LangRouterConfig{}, retry: retry.NewExpRetry(1, 2, 1*time.Millisecond, nil), - chatRouting: routing.NewPriority(models), + chatRouting: routing.NewPriority(modelPool), chatModels: langModels, chatStreamModels: langModels, - chatStreamRouting: routing.NewPriority(models), + chatStreamRouting: routing.NewPriority(modelPool), tel: telemetry.NewTelemetryMock(), logger: telemetry.NewLoggerMock(), } - respC := make(chan *schemas.ChatStreamMessage) + respC := make(chan *schema.ChatStreamMessage) 
defer close(respC) - go router.ChatStream(context.Background(), schemas.NewChatStreamFromStr("tell me a dad joke"), respC) + go router.ChatStream(context.Background(), schema.NewChatStreamFromStr("tell me a dad joke"), respC) errs := make([]string, 0, 3) @@ -462,5 +456,5 @@ func TestLangRouter_ChatStream_AllModelsUnavailable(t *testing.T) { errs = append(errs, result.Error.Name) } - require.Equal(t, []string{schemas.ModelUnavailable, schemas.ModelUnavailable, schemas.AllModelsUnavailable}, errs) + require.Equal(t, []string{schema.ModelUnavailable, schema.ModelUnavailable, schema.AllModelsUnavailable}, errs) } diff --git a/pkg/routers/latency/config.go b/pkg/router/latency/config.go similarity index 100% rename from pkg/routers/latency/config.go rename to pkg/router/latency/config.go diff --git a/pkg/routers/latency/config_test.go b/pkg/router/latency/config_test.go similarity index 100% rename from pkg/routers/latency/config_test.go rename to pkg/router/latency/config_test.go diff --git a/pkg/routers/latency/moving_average.go b/pkg/router/latency/moving_average.go similarity index 100% rename from pkg/routers/latency/moving_average.go rename to pkg/router/latency/moving_average.go diff --git a/pkg/routers/latency/moving_average_test.go b/pkg/router/latency/moving_average_test.go similarity index 100% rename from pkg/routers/latency/moving_average_test.go rename to pkg/router/latency/moving_average_test.go diff --git a/pkg/routers/manager.go b/pkg/router/manager.go similarity index 62% rename from pkg/routers/manager.go rename to pkg/router/manager.go index 123ea09e..b0d2fd69 100644 --- a/pkg/routers/manager.go +++ b/pkg/router/manager.go @@ -1,20 +1,20 @@ -package routers +package router import ( - "github.com/EinStack/glide/pkg/api/schemas" + "github.com/EinStack/glide/pkg/api/schema" "github.com/EinStack/glide/pkg/telemetry" ) -type RouterManager struct { - Config *Config +type Manager struct { + Config *RoutersConfig tel *telemetry.Telemetry langRouterMap 
*map[string]*LangRouter langRouters []*LangRouter } // NewManager creates a new instance of Router Manager that creates, holds and returns all routers -func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { - langRouters, err := cfg.BuildLangRouters(tel) +func NewManager(cfg *RoutersConfig, tel *telemetry.Telemetry) (*Manager, error) { + langRouters, err := cfg.LanguageRouters.Build(tel) if err != nil { return nil, err } @@ -25,7 +25,7 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { langRouterMap[router.ID()] = router } - manager := RouterManager{ + manager := Manager{ Config: cfg, tel: tel, langRouters: langRouters, @@ -35,15 +35,15 @@ func NewManager(cfg *Config, tel *telemetry.Telemetry) (*RouterManager, error) { return &manager, err } -func (r *RouterManager) GetLangRouters() []*LangRouter { +func (r *Manager) GetLangRouters() []*LangRouter { return r.langRouters } // GetLangRouter returns a router by type and ID -func (r *RouterManager) GetLangRouter(routerID string) (*LangRouter, error) { +func (r *Manager) GetLangRouter(routerID string) (*LangRouter, error) { if router, found := (*r.langRouterMap)[routerID]; found { return router, nil } - return nil, &schemas.ErrRouterNotFound + return nil, &schema.ErrRouterNotFound } diff --git a/pkg/routers/routing/least_latency.go b/pkg/router/routing/least_latency.go similarity index 91% rename from pkg/routers/routing/least_latency.go rename to pkg/router/routing/least_latency.go index 015c044e..e233c20e 100644 --- a/pkg/routers/routing/least_latency.go +++ b/pkg/router/routing/least_latency.go @@ -5,9 +5,9 @@ import ( "sync/atomic" "time" - "github.com/EinStack/glide/pkg/routers/latency" + "github.com/EinStack/glide/pkg/extmodel" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/router/latency" ) const ( @@ -15,16 +15,16 @@ const ( ) // LatencyGetter defines where to find latency for the specific model action -type LatencyGetter = 
func(model providers.Model) *latency.MovingAverage +type LatencyGetter = func(model extmodel.Interface) *latency.MovingAverage // ModelSchedule defines latency update schedule for models type ModelSchedule struct { mu sync.RWMutex - model providers.Model + model extmodel.Interface expireAt time.Time } -func NewSchedule(model providers.Model) *ModelSchedule { +func NewSchedule(model extmodel.Interface) *ModelSchedule { schedule := &ModelSchedule{ model: model, } @@ -67,7 +67,7 @@ type LeastLatencyRouting struct { schedules []*ModelSchedule } -func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []providers.Model) *LeastLatencyRouting { +func NewLeastLatencyRouting(latencyGetter LatencyGetter, models []extmodel.Interface) *LeastLatencyRouting { schedules := make([]*ModelSchedule, 0, len(models)) for _, model := range models { @@ -95,7 +95,7 @@ func (r *LeastLatencyRouting) Iterator() LangModelIterator { // other model latencies that might have improved over time). // For that, we introduced expiration time after which the model receives a request // even if it was not the fastest to respond -func (r *LeastLatencyRouting) Next() (providers.Model, error) { //nolint:cyclop +func (r *LeastLatencyRouting) Next() (extmodel.Interface, error) { //nolint:cyclop coldSchedules := r.getColdModelSchedules() if len(coldSchedules) > 0 { diff --git a/pkg/routers/routing/least_latency_test.go b/pkg/router/routing/least_latency_test.go similarity index 86% rename from pkg/routers/routing/least_latency_test.go rename to pkg/router/routing/least_latency_test.go index 0ed9c51b..0e6618f2 100644 --- a/pkg/routers/routing/least_latency_test.go +++ b/pkg/router/routing/least_latency_test.go @@ -5,9 +5,7 @@ import ( "testing" "time" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -33,13 +31,13 @@ func TestLeastLatencyRouting_Warmup(t 
*testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, model.latency, 1)) } - routing := NewLeastLatencyRouting(ptesting.ChatMockLatency, models) + routing := NewLeastLatencyRouting(extmodel.ChatMockLatency, modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -107,7 +105,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { for _, model := range tc.models { schedules = append(schedules, &ModelSchedule{ - model: ptesting.NewLangModelMock( + model: extmodel.NewLangModelMock( model.modelID, model.healthy, model.latency, @@ -118,7 +116,7 @@ func TestLeastLatencyRouting_Routing(t *testing.T) { } routing := LeastLatencyRouting{ - latencyGetter: ptesting.ChatMockLatency, + latencyGetter: extmodel.ChatMockLatency, schedules: schedules, } @@ -144,13 +142,13 @@ func TestLeastLatencyRouting_NoHealthyModels(t *testing.T) { for name, latencies := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(latencies)) + modelPool := make([]extmodel.Interface, 0, len(latencies)) for idx, latency := range latencies { - models = append(models, ptesting.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(strconv.Itoa(idx), false, latency, 1)) } - routing := NewLeastLatencyRouting(providers.ChatLatency, models) + routing := NewLeastLatencyRouting(extmodel.ChatLatency, modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/priority.go b/pkg/router/routing/priority.go similarity index 68% rename from 
pkg/routers/routing/priority.go rename to pkg/router/routing/priority.go index f895458c..7cf5ceeb 100644 --- a/pkg/routers/routing/priority.go +++ b/pkg/router/routing/priority.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -15,10 +15,10 @@ const ( // Priority of models are defined as position of the model on the list // (e.g. the first model definition has the highest priority, then the second model definition and so on) type PriorityRouting struct { - models []providers.Model + models []extmodel.Interface } -func NewPriority(models []providers.Model) *PriorityRouting { +func NewPriority(models []extmodel.Interface) *PriorityRouting { return &PriorityRouting{ models: models, } @@ -35,14 +35,14 @@ func (r *PriorityRouting) Iterator() LangModelIterator { type PriorityIterator struct { idx *atomic.Uint64 - models []providers.Model + models []extmodel.Interface } -func (r PriorityIterator) Next() (providers.Model, error) { - models := r.models +func (r PriorityIterator) Next() (extmodel.Interface, error) { + modelPool := r.models - for idx := int(r.idx.Load()); idx < len(models); idx = int(r.idx.Add(1)) { - model := models[idx] + for idx := int(r.idx.Load()); idx < len(modelPool); idx = int(r.idx.Add(1)) { + model := modelPool[idx] if !model.Healthy() { continue diff --git a/pkg/routers/routing/priority_test.go b/pkg/router/routing/priority_test.go similarity index 71% rename from pkg/routers/routing/priority_test.go rename to pkg/router/routing/priority_test.go index cee98c60..98e27e7d 100644 --- a/pkg/routers/routing/priority_test.go +++ b/pkg/router/routing/priority_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -29,13 +27,13 @@ func 
TestPriorityRouting_PickModelsInOrder(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() // loop three times over the whole pool to check if we return back to the begging of the list @@ -49,13 +47,13 @@ func TestPriorityRouting_PickModelsInOrder(t *testing.T) { } func TestPriorityRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } - routing := NewPriority(models) + routing := NewPriority(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/round_robin.go b/pkg/router/routing/round_robin.go similarity index 78% rename from pkg/routers/routing/round_robin.go rename to pkg/router/routing/round_robin.go index e5a0f927..7582cbcb 100644 --- a/pkg/routers/routing/round_robin.go +++ b/pkg/router/routing/round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync/atomic" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -13,10 +13,10 @@ const ( // RoundRobinRouting routes request to the next model in the list in cycle type RoundRobinRouting struct { idx atomic.Uint64 - models []providers.Model + models []extmodel.Interface } -func NewRoundRobinRouting(models 
[]providers.Model) *RoundRobinRouting { +func NewRoundRobinRouting(models []extmodel.Interface) *RoundRobinRouting { return &RoundRobinRouting{ models: models, } @@ -26,7 +26,7 @@ func (r *RoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *RoundRobinRouting) Next() (providers.Model, error) { +func (r *RoundRobinRouting) Next() (extmodel.Interface, error) { modelLen := len(r.models) // in order to avoid infinite loop in case of no healthy model is available, diff --git a/pkg/routers/routing/round_robin_test.go b/pkg/router/routing/round_robin_test.go similarity index 72% rename from pkg/routers/routing/round_robin_test.go rename to pkg/router/routing/round_robin_test.go index fc34ec48..7287f468 100644 --- a/pkg/routers/routing/round_robin_test.go +++ b/pkg/router/routing/round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -30,13 +28,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 100, 1)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 100, 1)) } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() for i := 0; i < 3; i++ { @@ -52,13 +50,13 @@ func TestRoundRobinRouting_PickModelsSequentially(t *testing.T) { } func TestRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 1), - 
ptesting.NewLangModelMock("third", false, 0, 1), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + extmodel.NewLangModelMock("second", false, 0, 1), + extmodel.NewLangModelMock("third", false, 0, 1), } - routing := NewRoundRobinRouting(models) + routing := NewRoundRobinRouting(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/routing/strategies.go b/pkg/router/routing/strategies.go similarity index 79% rename from pkg/routers/routing/strategies.go rename to pkg/router/routing/strategies.go index 56f03676..48d18ab6 100644 --- a/pkg/routers/routing/strategies.go +++ b/pkg/router/routing/strategies.go @@ -3,7 +3,7 @@ package routing import ( "errors" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" ) var ErrNoHealthyModels = errors.New("no healthy models found") @@ -16,5 +16,5 @@ type LangModelRouting interface { } type LangModelIterator interface { - Next() (providers.Model, error) + Next() (extmodel.Interface, error) } diff --git a/pkg/routers/routing/weighted_round_robin.go b/pkg/router/routing/weighted_round_robin.go similarity index 85% rename from pkg/routers/routing/weighted_round_robin.go rename to pkg/router/routing/weighted_round_robin.go index 2e028408..418add91 100644 --- a/pkg/routers/routing/weighted_round_robin.go +++ b/pkg/router/routing/weighted_round_robin.go @@ -3,7 +3,7 @@ package routing import ( "sync" - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" ) const ( @@ -11,7 +11,7 @@ const ( ) type Weighter struct { - model providers.Model + model extmodel.Interface currentWeight int } @@ -36,7 +36,7 @@ type WRoundRobinRouting struct { weights []*Weighter } -func NewWeightedRoundRobin(models []providers.Model) *WRoundRobinRouting { +func NewWeightedRoundRobin(models []extmodel.Interface) *WRoundRobinRouting { weights := make([]*Weighter, 0, len(models)) for _, model := range models { @@ -55,7 
+55,7 @@ func (r *WRoundRobinRouting) Iterator() LangModelIterator { return r } -func (r *WRoundRobinRouting) Next() (providers.Model, error) { +func (r *WRoundRobinRouting) Next() (extmodel.Interface, error) { r.mu.Lock() defer r.mu.Unlock() diff --git a/pkg/routers/routing/weighted_round_robin_test.go b/pkg/router/routing/weighted_round_robin_test.go similarity index 81% rename from pkg/routers/routing/weighted_round_robin_test.go rename to pkg/router/routing/weighted_round_robin_test.go index f4b59bb3..7ec9b24c 100644 --- a/pkg/routers/routing/weighted_round_robin_test.go +++ b/pkg/router/routing/weighted_round_robin_test.go @@ -3,9 +3,7 @@ package routing import ( "testing" - ptesting "github.com/EinStack/glide/pkg/providers/testing" - - "github.com/EinStack/glide/pkg/providers" + "github.com/EinStack/glide/pkg/extmodel" "github.com/stretchr/testify/require" ) @@ -116,13 +114,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { - models := make([]providers.Model, 0, len(tc.models)) + modelPool := make([]extmodel.Interface, 0, len(tc.models)) for _, model := range tc.models { - models = append(models, ptesting.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) + modelPool = append(modelPool, extmodel.NewLangModelMock(model.modelID, model.healthy, 0, model.weight)) } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() actualDistribution := make(map[string]int, len(tc.models)) @@ -142,13 +140,13 @@ func TestWRoundRobinRouting_RoutingDistribution(t *testing.T) { } func TestWRoundRobinRouting_NoHealthyModels(t *testing.T) { - models := []providers.Model{ - ptesting.NewLangModelMock("first", false, 0, 1), - ptesting.NewLangModelMock("second", false, 0, 2), - ptesting.NewLangModelMock("third", false, 0, 3), + modelPool := []extmodel.Interface{ + extmodel.NewLangModelMock("first", false, 0, 1), + 
extmodel.NewLangModelMock("second", false, 0, 2), + extmodel.NewLangModelMock("third", false, 0, 3), } - routing := NewWeightedRoundRobin(models) + routing := NewWeightedRoundRobin(modelPool) iterator := routing.Iterator() _, err := iterator.Next() diff --git a/pkg/routers/config_test.go b/pkg/routers/config_test.go deleted file mode 100644 index d740df2c..00000000 --- a/pkg/routers/config_test.go +++ /dev/null @@ -1,236 +0,0 @@ -package routers - -import ( - "testing" - - "github.com/EinStack/glide/pkg/providers/cohere" - - "github.com/EinStack/glide/pkg/telemetry" - - "github.com/EinStack/glide/pkg/routers/routing" - - "github.com/EinStack/glide/pkg/routers/retry" - - "github.com/EinStack/glide/pkg/routers/latency" - - "github.com/EinStack/glide/pkg/routers/health" - - "github.com/EinStack/glide/pkg/providers/openai" - - "github.com/EinStack/glide/pkg/providers/clients" - - "github.com/EinStack/glide/pkg/providers" - - "github.com/stretchr/testify/require" -) - -func TestRouterConfig_BuildModels(t *testing.T) { - defaultParams := openai.DefaultParams() - - cfg := Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &defaultParams, - }, - }, - }, - }, - { - ID: "second_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &defaultParams, - }, - }, - }, - }, - }, - } - - routers, err := 
cfg.BuildLangRouters(telemetry.NewTelemetryMock()) - - require.NoError(t, err) - require.Len(t, routers, 2) - require.Len(t, routers[0].chatModels, 1) - require.IsType(t, &routing.PriorityRouting{}, routers[0].chatRouting) - require.Len(t, routers[1].chatModels, 1) - require.IsType(t, &routing.LeastLatencyRouting{}, routers[1].chatRouting) -} - -func TestRouterConfig_BuildModelsPerType(t *testing.T) { - tel := telemetry.NewTelemetryMock() - openAIParams := openai.DefaultParams() - cohereParams := cohere.DefaultParams() - - cfg := LangRouterConfig{ - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &openAIParams, - }, - }, - { - ID: "second_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - Cohere: &cohere.Config{ - APIKey: "ABC", - DefaultParams: &cohereParams, - }, - }, - }, - } - - chatModels, streamChatModels, err := cfg.BuildModels(tel) - - require.Len(t, chatModels, 2) - require.Len(t, streamChatModels, 2) - require.NoError(t, err) -} - -func TestRouterConfig_InvalidSetups(t *testing.T) { - defaultParams := openai.DefaultParams() - - tests := []struct { - name string - config Config - }{ - { - "duplicated router IDs", - Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: 
&defaultParams, - }, - }, - }, - }, - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.LeastLatency, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &defaultParams, - }, - }, - }, - }, - }, - }, - }, - { - "duplicated model IDs", - Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{ - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &defaultParams, - }, - }, - { - ID: "first_model", - Enabled: true, - Client: clients.DefaultClientConfig(), - ErrorBudget: health.DefaultErrorBudget(), - Latency: latency.DefaultConfig(), - OpenAI: &openai.Config{ - APIKey: "ABC", - DefaultParams: &defaultParams, - }, - }, - }, - }, - }, - }, - }, - { - "no models", - Config{ - LanguageRouters: []LangRouterConfig{ - { - ID: "first_router", - Enabled: true, - RoutingStrategy: routing.Priority, - Retry: retry.DefaultExpRetryConfig(), - Models: []providers.LangModelConfig{}, - }, - }, - }, - }, - } - - for _, test := range tests { - _, err := test.config.BuildLangRouters(telemetry.NewTelemetryMock()) - - require.Error(t, err) - } -}