diff --git a/runner/server/handler/completion.go b/runner/server/handler/completion.go index 6740eea5..091a4eab 100644 --- a/runner/server/handler/completion.go +++ b/runner/server/handler/completion.go @@ -13,6 +13,7 @@ import ( "github.com/openai/openai-go" "github.com/openai/openai-go/shared/constant" + "github.com/NexaAI/nexa-sdk/runner/internal/config" "github.com/NexaAI/nexa-sdk/runner/internal/store" "github.com/NexaAI/nexa-sdk/runner/internal/types" nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk" @@ -29,7 +30,8 @@ type ChatCompletionNewParams openai.ChatCompletionNewParams // ChatCompletionRequest defines the request body for the chat completions API. // example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] } type ChatCompletionRequest struct { - Stream bool `json:"stream"` + Stream bool `json:"stream"` + KeepAlive *int64 `json:"keep_alive"` EnableThink bool `json:"enable_think"` TopK int32 `json:"top_k"` @@ -124,11 +126,16 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) { samplerConfig := parseSamplerConfig(param) + keepAlive := config.Get().KeepAlive + if param.KeepAlive != nil { + keepAlive = *param.KeepAlive + } // Get LLM instance p, err := service.KeepAliveGet[nexa_sdk.LLM]( string(param.Model), types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt}, c.GetHeader("Nexa-KeepCache") != "true", + keepAlive, ) if errors.Is(err, os.ErrNotExist) { c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"}) @@ -353,11 +360,16 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) { samplerConfig := parseSamplerConfig(param) + keepAlive := config.Get().KeepAlive + if param.KeepAlive != nil { + keepAlive = *param.KeepAlive + } // Get VLM instance p, err := service.KeepAliveGet[nexa_sdk.VLM]( string(param.Model), types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt}, c.GetHeader("Nexa-KeepCache") != "true", 
+ keepAlive, ) if errors.Is(err, os.ErrNotExist) { c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"}) diff --git a/runner/server/handler/embedder.go b/runner/server/handler/embedder.go index 89f43f46..6de73890 100644 --- a/runner/server/handler/embedder.go +++ b/runner/server/handler/embedder.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" "github.com/openai/openai-go" + "github.com/NexaAI/nexa-sdk/runner/internal/config" "github.com/NexaAI/nexa-sdk/runner/internal/types" nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk" "github.com/NexaAI/nexa-sdk/runner/server/service" @@ -17,16 +18,25 @@ import ( // @Accept json // @Param request body openai.EmbeddingNewParams true "Embedding request" func Embeddings(c *gin.Context) { - param := openai.EmbeddingNewParams{} + param := struct { + openai.EmbeddingNewParams + KeepAlive *int64 `json:"keep_alive"` + }{} + if err := c.ShouldBindJSON(&param); err != nil { c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()}) return } + keepAlive := config.Get().KeepAlive + if param.KeepAlive != nil { + keepAlive = *param.KeepAlive + } p, err := service.KeepAliveGet[nexa_sdk.Embedder]( string(param.Model), types.ModelParam{}, false, + keepAlive, ) if err != nil { c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) diff --git a/runner/server/handler/image.go b/runner/server/handler/image.go index c7564c38..46aded8a 100644 --- a/runner/server/handler/image.go +++ b/runner/server/handler/image.go @@ -14,6 +14,7 @@ import ( "github.com/gin-gonic/gin" "github.com/openai/openai-go" + "github.com/NexaAI/nexa-sdk/runner/internal/config" "github.com/NexaAI/nexa-sdk/runner/internal/types" nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk" "github.com/NexaAI/nexa-sdk/runner/server/service" @@ -30,7 +31,10 @@ import ( // @Failure 404 {object} map[string]any "Model not found" // @Failure 500 {object} map[string]any "Internal server error" func ImageGenerations(c *gin.Context) { - 
param := openai.ImageGenerateParams{} + param := struct { + openai.ImageGenerateParams + KeepAlive *int64 `json:"keep_alive"` + }{} if err := c.ShouldBindJSON(&param); err != nil { slog.Error("Failed to bind JSON request", "error", err) c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()}) return } @@ -53,10 +57,15 @@ func ImageGenerations(c *gin.Context) { param.ResponseFormat = openai.ImageGenerateParamsResponseFormatURL } + keepAlive := config.Get().KeepAlive + if param.KeepAlive != nil { + keepAlive = *param.KeepAlive + } imageGen, err := service.KeepAliveGet[nexa_sdk.ImageGen]( param.Model, types.ModelParam{}, c.GetHeader("Nexa-KeepCache") != "true", + keepAlive, ) if err != nil { c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()}) diff --git a/runner/server/service/keepalive.go b/runner/server/service/keepalive.go index 5917567c..6a563b5b 100644 --- a/runner/server/service/keepalive.go +++ b/runner/server/service/keepalive.go @@ -6,7 +6,6 @@ import ( "sync" "time" - "github.com/NexaAI/nexa-sdk/runner/internal/config" "github.com/NexaAI/nexa-sdk/runner/internal/store" "github.com/NexaAI/nexa-sdk/runner/internal/types" nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk" @@ -14,8 +13,8 @@ import ( // KeepAliveGet retrieves a model from the keepalive cache or creates it if not found // This avoids the overhead of repeatedly loading/unloading models from disk -func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) { - t, err := keepAliveGet[T](name, param, reset) +func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (*T, error) { + t, err := keepAliveGet[T](name, param, reset, timeout) if err != nil { return nil, err } @@ -37,6 +36,7 @@ type modelKeepInfo struct { model keepable param types.ModelParam lastTime time.Time + timeout int64 // timeout in seconds for this specific model } // keepable interface defines objects that can be managed by the keepalive service @@ 
-70,7 +70,7 @@ func (keepAlive *keepAliveService) start() { case <-t.C: keepAlive.Lock() for name, model := range keepAlive.models { - if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive { + if int64(time.Since(model.lastTime).Seconds()) > model.timeout { model.model.Destroy() delete(keepAlive.models, name) } @@ -83,7 +83,7 @@ func (keepAlive *keepAliveService) start() { // keepAliveGet retrieves a cached model or creates a new one if not found // Ensures only one model is kept in memory at a time by clearing others -func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) { +func keepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (any, error) { keepAlive.Lock() defer keepAlive.Unlock() @@ -195,6 +195,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, model: t, param: param, lastTime: time.Now(), + timeout: timeout, } keepAlive.models[name] = model