Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion runner/server/handler/completion.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/openai/openai-go"
"github.com/openai/openai-go/shared/constant"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/store"
"github.com/NexaAI/nexa-sdk/runner/internal/types"
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
Expand All @@ -29,7 +30,8 @@ type ChatCompletionNewParams openai.ChatCompletionNewParams
// ChatCompletionRequest defines the request body for the chat completions API.
// example: { "model": "nexaml/nexaml-models", "messages": [ { "role": "user", "content": "why is the sky blue?" } ] }
type ChatCompletionRequest struct {
Stream bool `json:"stream"`
Stream bool `json:"stream"`
KeepAlive *int64 `json:"keep_alive"`

EnableThink bool `json:"enable_think"`
TopK int32 `json:"top_k"`
Expand Down Expand Up @@ -124,11 +126,16 @@ func chatCompletionsLLM(c *gin.Context, param ChatCompletionRequest) {

samplerConfig := parseSamplerConfig(param)

keepAlive := config.Get().KeepAlive
if param.KeepAlive != nil {
keepAlive = *param.KeepAlive
}
// Get LLM instance
p, err := service.KeepAliveGet[nexa_sdk.LLM](
string(param.Model),
types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt},
c.GetHeader("Nexa-KeepCache") != "true",
keepAlive,
)
if errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
Expand Down Expand Up @@ -353,11 +360,16 @@ func chatCompletionsVLM(c *gin.Context, param ChatCompletionRequest) {

samplerConfig := parseSamplerConfig(param)

keepAlive := config.Get().KeepAlive
if param.KeepAlive != nil {
keepAlive = *param.KeepAlive
}
// Get VLM instance
p, err := service.KeepAliveGet[nexa_sdk.VLM](
string(param.Model),
types.ModelParam{NCtx: 4096, NGpuLayers: 999, SystemPrompt: systemPrompt},
c.GetHeader("Nexa-KeepCache") != "true",
keepAlive,
)
if errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, map[string]any{"error": "model not found"})
Expand Down
12 changes: 11 additions & 1 deletion runner/server/handler/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/openai/openai-go"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/types"
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
"github.com/NexaAI/nexa-sdk/runner/server/service"
Expand All @@ -17,16 +18,25 @@ import (
// @Accept json
// @Param request body openai.EmbeddingNewParams true "Embedding request"
func Embeddings(c *gin.Context) {
param := openai.EmbeddingNewParams{}
param := struct {
openai.EmbeddingNewParams
KeepAlive *int64 `json:"keep_alive"`
}{}

if err := c.ShouldBindJSON(&param); err != nil {
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
return
}

keepAlive := config.Get().KeepAlive
if param.KeepAlive != nil {
keepAlive = *param.KeepAlive
}
p, err := service.KeepAliveGet[nexa_sdk.Embedder](
string(param.Model),
types.ModelParam{},
false,
keepAlive,
)
if err != nil {
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
Expand Down
11 changes: 10 additions & 1 deletion runner/server/handler/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/openai/openai-go"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/types"
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
"github.com/NexaAI/nexa-sdk/runner/server/service"
Expand All @@ -30,7 +31,10 @@ import (
// @Failure 404 {object} map[string]any "Model not found"
// @Failure 500 {object} map[string]any "Internal server error"
func ImageGenerations(c *gin.Context) {
param := openai.ImageGenerateParams{}
param := struct {
openai.ImageGenerateParams
KeepAlive *int64 `json:"keep_alive"`
}{}
if err := c.ShouldBindJSON(&param); err != nil {
slog.Error("Failed to bind JSON request", "error", err)
c.JSON(http.StatusBadRequest, map[string]any{"error": err.Error()})
Expand All @@ -53,10 +57,15 @@ func ImageGenerations(c *gin.Context) {
param.ResponseFormat = openai.ImageGenerateParamsResponseFormatURL
}

keepAlive := config.Get().KeepAlive
if param.KeepAlive != nil {
keepAlive = *param.KeepAlive
}
imageGen, err := service.KeepAliveGet[nexa_sdk.ImageGen](
param.Model,
types.ModelParam{},
c.GetHeader("Nexa-KeepCache") != "true",
keepAlive,
)
if err != nil {
c.JSON(http.StatusInternalServerError, map[string]any{"error": err.Error()})
Expand Down
11 changes: 6 additions & 5 deletions runner/server/service/keepalive.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,15 @@ import (
"sync"
"time"

"github.com/NexaAI/nexa-sdk/runner/internal/config"
"github.com/NexaAI/nexa-sdk/runner/internal/store"
"github.com/NexaAI/nexa-sdk/runner/internal/types"
nexa_sdk "github.com/NexaAI/nexa-sdk/runner/nexa-sdk"
)

// KeepAliveGet retrieves a model from the keepalive cache, creating it if it is not already cached.
// This avoids the overhead of repeatedly loading and unloading models from disk.
// timeout is the per-model keep-alive duration in seconds.
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool) (*T, error) {
t, err := keepAliveGet[T](name, param, reset)
func KeepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (*T, error) {
t, err := keepAliveGet[T](name, param, reset, timeout)
if err != nil {
return nil, err
}
Expand All @@ -37,6 +36,7 @@ type modelKeepInfo struct {
model keepable
param types.ModelParam
lastTime time.Time
timeout int64 // timeout in seconds for this specific model
}

// keepable interface defines objects that can be managed by the keepalive service
Expand Down Expand Up @@ -70,7 +70,7 @@ func (keepAlive *keepAliveService) start() {
case <-t.C:
keepAlive.Lock()
for name, model := range keepAlive.models {
if time.Since(model.lastTime).Milliseconds()/1000 > config.Get().KeepAlive {
if int64(time.Since(model.lastTime).Seconds()) > model.timeout {
model.model.Destroy()
delete(keepAlive.models, name)
}
Expand All @@ -83,7 +83,7 @@ func (keepAlive *keepAliveService) start() {

// keepAliveGet retrieves a cached model or creates a new one if not found.
// It ensures only one model is kept in memory at a time by clearing others.
func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any, error) {
func keepAliveGet[T any](name string, param types.ModelParam, reset bool, timeout int64) (any, error) {
keepAlive.Lock()
defer keepAlive.Unlock()

Expand Down Expand Up @@ -195,6 +195,7 @@ func keepAliveGet[T any](name string, param types.ModelParam, reset bool) (any,
model: t,
param: param,
lastTime: time.Now(),
timeout: timeout,
}
keepAlive.models[name] = model

Expand Down