Skip to content

Commit e5ed2cb

Browse files
committed
Add support for dynamic model providers
Implements functionality to parse model names with provider information in the format "provider://model" This allows dynamic provider selection rather than relying only on predefined mappings. The change affects all execution methods to properly handle these dynamic model specifications while maintaining compatibility with the existing approach for standard model names.
1 parent c7196ba commit e5ed2cb

File tree

4 files changed

+97
-18
lines changed

4 files changed

+97
-18
lines changed

examples/custom-provider/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
146146
return ch, nil
147147
}
148148

149+
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
150+
return clipexec.Response{}, errors.New("not implemented")
151+
}
152+
149153
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
150154
return a, nil
151155
}

internal/api/server.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
225225
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
226226

227227
// Create server instance
228+
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
229+
for _, p := range cfg.OpenAICompatibility {
230+
providerNames = append(providerNames, p.Name)
231+
}
228232
s := &Server{
229233
engine: engine,
230-
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
234+
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
231235
cfg: cfg,
232236
accessManager: accessManager,
233237
requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
823827
managementasset.SetCurrentConfig(cfg)
824828
// Save YAML snapshot for next comparison
825829
s.oldConfigYaml, _ = yaml.Marshal(cfg)
830+
831+
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
832+
for _, p := range cfg.OpenAICompatibility {
833+
providerNames = append(providerNames, p.Name)
834+
}
835+
s.handlers.OpenAICompatProviders = providerNames
836+
826837
s.handlers.UpdateClients(&cfg.SDKConfig)
827838

828839
if !cfg.RemoteManagement.DisableControlPanel {
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
904915
}
905916
}
906917

907-
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
918+
// legacy clientsToSlice removed; handlers no longer consume legacy client slices

internal/translator/openai/openai/chat-completions/openai_openai_request.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package chat_completions
44

55
import (
66
"bytes"
7+
"github.com/tidwall/sjson"
78
)
89

910
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
1718
// Returns:
1819
// - []byte: The transformed request data in Gemini CLI API format
1920
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
20-
return bytes.Clone(inputRawJSON)
21+
// Update the "model" field in the JSON payload with the provided modelName
22+
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
23+
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
24+
if err != nil {
25+
// If there's an error, return the original JSON or handle the error appropriately.
26+
// For now, we'll return the original, but in a real scenario, logging or a more robust error
27+
// handling mechanism would be needed.
28+
return bytes.Clone(inputRawJSON)
29+
}
30+
return updatedJSON
2131
}

sdk/api/handlers/handlers.go

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package handlers
66
import (
77
"fmt"
88
"net/http"
9+
"strings"
910

1011
"github.com/gin-gonic/gin"
1112
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
4647

4748
// Cfg holds the current application configuration.
4849
Cfg *config.SDKConfig
50+
51+
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
52+
OpenAICompatProviders []string
4953
}
5054

5155
// NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
5761
//
5862
// Returns:
5963
// - *BaseAPIHandler: A new API handlers instance
60-
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
64+
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
6165
return &BaseAPIHandler{
62-
Cfg: cfg,
63-
AuthManager: authManager,
66+
Cfg: cfg,
67+
AuthManager: authManager,
68+
OpenAICompatProviders: openAICompatProviders,
6469
}
6570
}
6671

@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
133138
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
134139
// This path is the only supported execution route.
135140
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
136-
normalizedModel, metadata := normalizeModelMetadata(modelName)
137-
providers := util.GetProviderName(normalizedModel)
138-
if len(providers) == 0 {
139-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
141+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
142+
if errMsg != nil {
143+
return nil, errMsg
140144
}
141145
req := coreexecutor.Request{
142146
Model: normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
176180
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
177181
// This path is the only supported execution route.
178182
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
179-
normalizedModel, metadata := normalizeModelMetadata(modelName)
180-
providers := util.GetProviderName(normalizedModel)
181-
if len(providers) == 0 {
182-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
183+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
184+
if errMsg != nil {
185+
return nil, errMsg
183186
}
184187
req := coreexecutor.Request{
185188
Model: normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
219222
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
220223
// This path is the only supported execution route.
221224
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
222-
normalizedModel, metadata := normalizeModelMetadata(modelName)
223-
providers := util.GetProviderName(normalizedModel)
224-
if len(providers) == 0 {
225+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
226+
if errMsg != nil {
225227
errChan := make(chan *interfaces.ErrorMessage, 1)
226-
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
228+
errChan <- errMsg
227229
close(errChan)
228230
return nil, errChan
229231
}
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
292294
return dataChan, errChan
293295
}
294296

297+
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
298+
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
299+
300+
// First, normalize the model name to handle suffixes like "-thinking-128"
301+
// This needs to happen before determining the provider for non-dynamic models.
302+
normalizedModel, metadata = normalizeModelMetadata(modelName)
303+
304+
if isDynamic {
305+
providers = []string{providerName}
306+
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
307+
// so we use it as the final normalizedModel.
308+
normalizedModel = extractedModelName
309+
} else {
310+
// For non-dynamic models, use the normalizedModel to get the provider name.
311+
providers = util.GetProviderName(normalizedModel)
312+
}
313+
314+
if len(providers) == 0 {
315+
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
316+
}
317+
318+
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
319+
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
320+
// So, normalizedModel is already correctly set at this point.
321+
322+
return providers, normalizedModel, metadata, nil
323+
}
324+
325+
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
326+
var providerPart, modelPart string
327+
for _, sep := range []string{"://"} {
328+
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
329+
providerPart = parts[0]
330+
modelPart = parts[1]
331+
break
332+
}
333+
}
334+
335+
if providerPart == "" {
336+
return "", modelName, false
337+
}
338+
339+
// Check if the provider is a configured openai-compatibility provider
340+
for _, pName := range h.OpenAICompatProviders {
341+
if pName == providerPart {
342+
return providerPart, modelPart, true
343+
}
344+
}
345+
346+
return "", modelName, false
347+
}
348+
295349
func cloneBytes(src []byte) []byte {
296350
if len(src) == 0 {
297351
return nil

0 commit comments

Comments
 (0)