Skip to content

Commit ac3d47e

Browse files
authored
Merge pull request router-for-me#173 from tobwen/feature/dynamic-model-routing
Add support for dynamic model providers
2 parents 847c250 + e5ed2cb commit ac3d47e

File tree

4 files changed

+97
-18
lines changed

4 files changed

+97
-18
lines changed

examples/custom-provider/main.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
146146
return ch, nil
147147
}
148148

149+
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
150+
return clipexec.Response{}, errors.New("not implemented")
151+
}
152+
149153
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
150154
return a, nil
151155
}

internal/api/server.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
225225
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
226226

227227
// Create server instance
228+
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
229+
for _, p := range cfg.OpenAICompatibility {
230+
providerNames = append(providerNames, p.Name)
231+
}
228232
s := &Server{
229233
engine: engine,
230-
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
234+
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
231235
cfg: cfg,
232236
accessManager: accessManager,
233237
requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
823827
managementasset.SetCurrentConfig(cfg)
824828
// Save YAML snapshot for next comparison
825829
s.oldConfigYaml, _ = yaml.Marshal(cfg)
830+
831+
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
832+
for _, p := range cfg.OpenAICompatibility {
833+
providerNames = append(providerNames, p.Name)
834+
}
835+
s.handlers.OpenAICompatProviders = providerNames
836+
826837
s.handlers.UpdateClients(&cfg.SDKConfig)
827838

828839
if !cfg.RemoteManagement.DisableControlPanel {
@@ -904,4 +915,4 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
904915
}
905916
}
906917

907-
// legacy clientsToSlice removed; handlers no longer consume legacy client slices
918+
// legacy clientsToSlice removed; handlers no longer consume legacy client slices

internal/translator/openai/openai/chat-completions/openai_openai_request.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package chat_completions
44

55
import (
66
"bytes"
7+
"github.com/tidwall/sjson"
78
)
89

910
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
1718
// Returns:
1819
// - []byte: The transformed request data in Gemini CLI API format
1920
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
20-
return bytes.Clone(inputRawJSON)
21+
// Update the "model" field in the JSON payload with the provided modelName
22+
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
23+
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
24+
if err != nil {
25+
// If there's an error, return the original JSON or handle the error appropriately.
26+
// For now, we'll return the original, but in a real scenario, logging or a more robust error
27+
// handling mechanism would be needed.
28+
return bytes.Clone(inputRawJSON)
29+
}
30+
return updatedJSON
2131
}

sdk/api/handlers/handlers.go

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package handlers
66
import (
77
"fmt"
88
"net/http"
9+
"strings"
910

1011
"github.com/gin-gonic/gin"
1112
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
4647

4748
// Cfg holds the current application configuration.
4849
Cfg *config.SDKConfig
50+
51+
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
52+
OpenAICompatProviders []string
4953
}
5054

5155
// NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
5761
//
5862
// Returns:
5963
// - *BaseAPIHandler: A new API handlers instance
60-
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
64+
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
6165
return &BaseAPIHandler{
62-
Cfg: cfg,
63-
AuthManager: authManager,
66+
Cfg: cfg,
67+
AuthManager: authManager,
68+
OpenAICompatProviders: openAICompatProviders,
6469
}
6570
}
6671

@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
133138
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
134139
// This path is the only supported execution route.
135140
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
136-
normalizedModel, metadata := normalizeModelMetadata(modelName)
137-
providers := util.GetProviderName(normalizedModel)
138-
if len(providers) == 0 {
139-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
141+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
142+
if errMsg != nil {
143+
return nil, errMsg
140144
}
141145
req := coreexecutor.Request{
142146
Model: normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
176180
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
177181
// This path is the only supported execution route.
178182
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
179-
normalizedModel, metadata := normalizeModelMetadata(modelName)
180-
providers := util.GetProviderName(normalizedModel)
181-
if len(providers) == 0 {
182-
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
183+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
184+
if errMsg != nil {
185+
return nil, errMsg
183186
}
184187
req := coreexecutor.Request{
185188
Model: normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
219222
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
220223
// This path is the only supported execution route.
221224
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
222-
normalizedModel, metadata := normalizeModelMetadata(modelName)
223-
providers := util.GetProviderName(normalizedModel)
224-
if len(providers) == 0 {
225+
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
226+
if errMsg != nil {
225227
errChan := make(chan *interfaces.ErrorMessage, 1)
226-
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
228+
errChan <- errMsg
227229
close(errChan)
228230
return nil, errChan
229231
}
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
292294
return dataChan, errChan
293295
}
294296

297+
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
298+
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
299+
300+
// First, normalize the model name to handle suffixes like "-thinking-128"
301+
// This needs to happen before determining the provider for non-dynamic models.
302+
normalizedModel, metadata = normalizeModelMetadata(modelName)
303+
304+
if isDynamic {
305+
providers = []string{providerName}
306+
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
307+
// so we use it as the final normalizedModel.
308+
normalizedModel = extractedModelName
309+
} else {
310+
// For non-dynamic models, use the normalizedModel to get the provider name.
311+
providers = util.GetProviderName(normalizedModel)
312+
}
313+
314+
if len(providers) == 0 {
315+
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
316+
}
317+
318+
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
319+
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
320+
// So, normalizedModel is already correctly set at this point.
321+
322+
return providers, normalizedModel, metadata, nil
323+
}
324+
325+
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
326+
var providerPart, modelPart string
327+
for _, sep := range []string{"://"} {
328+
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
329+
providerPart = parts[0]
330+
modelPart = parts[1]
331+
break
332+
}
333+
}
334+
335+
if providerPart == "" {
336+
return "", modelName, false
337+
}
338+
339+
// Check if the provider is a configured openai-compatibility provider
340+
for _, pName := range h.OpenAICompatProviders {
341+
if pName == providerPart {
342+
return providerPart, modelPart, true
343+
}
344+
}
345+
346+
return "", modelName, false
347+
}
348+
295349
func cloneBytes(src []byte) []byte {
296350
if len(src) == 0 {
297351
return nil

0 commit comments

Comments
 (0)