diff --git a/.tool-versions b/.tool-versions
index 9a05cbdba..9c24a0d45 100644
--- a/.tool-versions
+++ b/.tool-versions
@@ -1,2 +1,5 @@
 golang 1.25.5
-mockery 2.53.5
\ No newline at end of file
+mockery 2.53.5
+minikube 1.34.0
+kubectl 1.28.3
+argo 3.6.10
diff --git a/aigateway/handler/openai.go b/aigateway/handler/openai.go
index c54a3aa49..60df56c17 100644
--- a/aigateway/handler/openai.go
+++ b/aigateway/handler/openai.go
@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"net/http"
 	"net/url"
+	"strconv"
 	"strings"
 	"time"
 
@@ -95,10 +96,13 @@ type OpenAIHandlerImpl struct {
 // ListModels godoc
 // @Security ApiKey
 // @Summary List available models
-// @Description Returns a list of available models
+// @Description Returns a list of available models, supports fuzzy search by model_id query parameter
 // @Tags AIGateway
 // @Accept json
 // @Produce json
+// @Param model_id query string false "Model ID for fuzzy search"
+// @Param per query int false "Items per page (default 20, capped at 100)"
+// @Param page query int false "Page number, 1-based (default 1)"
 // @Success 200 {object} types.ModelList "OK"
 // @Failure 500 {object} error "Internal server error"
 // @Router /v1/models [get]
@@ -116,9 +120,77 @@ func (h *OpenAIHandlerImpl) ListModels(c *gin.Context) {
 		return
 	}
 
+	// Apply fuzzy search filter if model_id query parameter is provided
+	searchQuery := c.Query("model_id")
+	if searchQuery != "" {
+		// Lower-case the query once instead of on every loop iteration
+		needle := strings.ToLower(searchQuery)
+		filteredModels := make([]types.Model, 0)
+		for _, model := range models {
+			if strings.Contains(strings.ToLower(model.ID), needle) {
+				filteredModels = append(filteredModels, model)
+			}
+		}
+		models = filteredModels
+	}
+
+	// Parse pagination parameters
+	perStr := c.Query("per")
+	pageStr := c.Query("page")
+
+	// Set default values
+	per := 20 // default per page
+	page := 1 // default page (1-based)
+
+	if perStr != "" {
+		if parsedPerPage, err := strconv.Atoi(perStr); err == nil && parsedPerPage > 0 {
+			per = parsedPerPage
+			// Cap the page size to prevent excessive requests
+			if per > 100 {
+				per = 100
+			}
+		}
+	}
+
+	if pageStr != "" {
+		if parsedPage, err := strconv.Atoi(pageStr); err == nil && parsedPage > 0 {
+			page = parsedPage
+		}
+	}
+
+	totalCount := len(models)
+
+	// Apply pagination. The page index is compared against totalCount/per
+	// before multiplying, so a huge "page" value cannot overflow int into a
+	// negative offset (which would make the slice expression below panic).
+	startIndex := totalCount
+	if page-1 <= totalCount/per {
+		startIndex = (page - 1) * per
+	}
+
+	endIndex := startIndex + per
+	if endIndex > totalCount {
+		endIndex = totalCount
+	}
+
+	paginatedModels := models[startIndex:endIndex]
+
+	// Set pagination metadata
+	var firstID, lastID *string
+	if len(paginatedModels) > 0 {
+		firstID = &paginatedModels[0].ID
+		lastID = &paginatedModels[len(paginatedModels)-1].ID
+	}
+
+	hasMore := endIndex < totalCount
+
 	response := types.ModelList{
-		Object: "list",
-		Data:   models,
+		Object:     "list",
+		Data:       paginatedModels,
+		FirstID:    firstID,
+		LastID:     lastID,
+		HasMore:    hasMore,
+		TotalCount: totalCount,
 	}
 
 	c.PureJSON(http.StatusOK, response)
diff --git a/aigateway/handler/openai_test.go b/aigateway/handler/openai_test.go
index ffdef7986..6c71b475e 100644
--- a/aigateway/handler/openai_test.go
+++ b/aigateway/handler/openai_test.go
@@ -5,14 +5,17 @@ import (
 	"context"
 	"encoding/json"
 	"errors"
+	"fmt"
 	"io"
 	"net/http"
 	"net/http/httptest"
+	"net/url"
 	"sync"
 	"testing"
 
 	"github.com/gin-gonic/gin"
 	"github.com/openai/openai-go/v3"
+	"github.com/openai/openai-go/v3/option"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	mockcomp "opencsg.com/csghub-server/_mocks/opencsg.com/csghub-server/aigateway/component"
@@ -23,9 +26,11 @@ import (
 	"opencsg.com/csghub-server/api/httpbase"
 	"opencsg.com/csghub-server/builder/rpc"
 	"opencsg.com/csghub-server/builder/store/database"
+	"opencsg.com/csghub-server/builder/testutil"
 )
 
 type testerOpenAIHandler struct {
+	*testutil.GinTester
 	mocks struct {
 		openAIComp     *mockcomp.MockOpenAIComponent
 		moderationComp *mockcomp.MockModeration
@@ -45,17 +50,15 @@ func setupTest(t *testing.T) (*testerOpenAIHandler, *gin.Context, *httptest.Resp
 	mockTokenCounterFactory := mocktoken.NewMockCounterFactory(t)
 	handler := newOpenAIHandler(mockOpenAI, mockRepo, mockModeration,
mockClsComp, mockTokenCounterFactory) - w := httptest.NewRecorder() - c, _ := gin.CreateTestContext(w) - c.Request = &http.Request{ - Header: make(http.Header), - } // Set test user - httpbase.SetCurrentUser(c, "testuser") - httpbase.SetCurrentUserUUID(c, "testuuid") tester := &testerOpenAIHandler{ - handler: handler, + GinTester: testutil.NewGinTester(), + handler: handler, } + w := tester.GinTester.Response() + c := tester.GinTester.Gctx() + httpbase.SetCurrentUser(c, "testuser") + httpbase.SetCurrentUserUUID(c, "testuuid") tester.mocks.moderationComp = mockModeration tester.mocks.openAIComp = mockOpenAI tester.mocks.repoComp = mockRepo @@ -88,6 +91,363 @@ func TestOpenAIHandler_ListModels(t *testing.T) { assert.Equal(t, "list", response.Object) assert.Equal(t, models, response.Data) }) + + t.Run("fuzzy search with matching model_id", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + { + BaseModel: types.BaseModel{ + ID: "gpt-4:svc1", + Object: "model", + OwnedBy: "testuser", + }, + }, + { + BaseModel: types.BaseModel{ + ID: "claude-3:svc2", + Object: "model", + OwnedBy: "testuser", + }, + }, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + tester.WithQuery("model_id", "gpt") + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 1) + assert.Equal(t, "gpt-4:svc1", response.Data[0].ID) + }) + + t.Run("fuzzy search case insensitive", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + { + BaseModel: types.BaseModel{ + ID: "GPT-4:svc1", + Object: "model", + OwnedBy: "testuser", + }, + }, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + c.Request.URL, _ = url.Parse("/v1/models?model_id=gpt") + + 
tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Len(t, response.Data, 1) + assert.Equal(t, "GPT-4:svc1", response.Data[0].ID) + }) + + t.Run("fuzzy search with no matches", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + { + BaseModel: types.BaseModel{ + ID: "gpt-4:svc1", + Object: "model", + OwnedBy: "testuser", + }, + }, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + c.Request.URL, _ = url.Parse("/v1/models?model_id=nonexistent") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 0) + }) + + t.Run("no search query returns all models", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + { + BaseModel: types.BaseModel{ + ID: "model1:svc1", + Object: "model", + OwnedBy: "testuser", + }, + }, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + c.Request.URL, _ = url.Parse("/v1/models") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Equal(t, models, response.Data) + }) + + t.Run("pagination with default parameters", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model2:svc2", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model3:svc3", Object: "model", OwnedBy: "testuser"}}, + } + 
tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 3) + assert.Equal(t, 3, response.TotalCount) + assert.False(t, response.HasMore) + assert.Equal(t, "model1:svc1", *response.FirstID) + assert.Equal(t, "model3:svc3", *response.LastID) + }) + + t.Run("pagination with per_page parameter", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model2:svc2", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model3:svc3", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model4:svc4", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model5:svc5", Object: "model", OwnedBy: "testuser"}}, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("per", "2") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 2) + assert.Equal(t, 5, response.TotalCount) + assert.True(t, response.HasMore) + assert.Equal(t, "model1:svc1", *response.FirstID) + assert.Equal(t, "model2:svc2", *response.LastID) + }) + + t.Run("pagination with page and per_page parameters", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model2:svc2", Object: "model", 
OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model3:svc3", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model4:svc4", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model5:svc5", Object: "model", OwnedBy: "testuser"}}, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("page", "2").WithQuery("per", "2") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 2) + assert.Equal(t, 5, response.TotalCount) + assert.True(t, response.HasMore) + assert.Equal(t, "model3:svc3", *response.FirstID) + assert.Equal(t, "model4:svc4", *response.LastID) + }) + + t.Run("pagination last page", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model2:svc2", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model3:svc3", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model4:svc4", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model5:svc5", Object: "model", OwnedBy: "testuser"}}, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("page", "3").WithQuery("per", "2") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 1) + assert.Equal(t, 5, response.TotalCount) + assert.False(t, response.HasMore) + assert.Equal(t, "model5:svc5", 
*response.FirstID) + assert.Equal(t, "model5:svc5", *response.LastID) + }) + + t.Run("pagination with page beyond available data", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "model1:svc1", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "model2:svc2", Object: "model", OwnedBy: "testuser"}}, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("page", "3").WithQuery("per", "2") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 0) + assert.Equal(t, 2, response.TotalCount) + assert.False(t, response.HasMore) + assert.Nil(t, response.FirstID) + assert.Nil(t, response.LastID) + }) + + t.Run("pagination with per_page limit capped at 100", func(t *testing.T) { + tester, c, w := setupTest(t) + models := make([]types.Model, 150) + for i := 0; i < 150; i++ { + models[i] = types.Model{ + BaseModel: types.BaseModel{ + ID: fmt.Sprintf("model%d:svc%d", i+1, i+1), + Object: "model", + OwnedBy: "testuser", + }, + } + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("per", "200") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 100) // Should be capped at 100 + assert.Equal(t, 150, response.TotalCount) + assert.True(t, response.HasMore) + }) + + t.Run("pagination with search and per_page", func(t *testing.T) { + tester, c, w := setupTest(t) + models := []types.Model{ + {BaseModel: types.BaseModel{ID: "gpt-4:svc1", Object: "model", 
OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "gpt-3.5:svc2", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "claude-3:svc3", Object: "model", OwnedBy: "testuser"}}, + {BaseModel: types.BaseModel{ID: "gpt-3.5-turbo:svc4", Object: "model", OwnedBy: "testuser"}}, + } + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + tester.WithQuery("model_id", "gpt").WithQuery("per", "2") + + tester.handler.ListModels(c) + + assert.Equal(t, http.StatusOK, w.Code) + var response types.ModelList + err := json.Unmarshal(w.Body.Bytes(), &response) + assert.NoError(t, err) + assert.Equal(t, "list", response.Object) + assert.Len(t, response.Data, 2) // Only gpt models, limited to 2 + assert.Equal(t, 3, response.TotalCount) // 3 gpt models total + assert.True(t, response.HasMore) + assert.Equal(t, "gpt-4:svc1", *response.FirstID) + assert.Equal(t, "gpt-3.5:svc2", *response.LastID) + }) +} + +func TestOpenAIHandler_ListModels_OpenaiSDK(t *testing.T) { + // Setup test with mock data + tester, _, _ := setupTest(t) + + // Prepare mock models + models := []types.Model{ + { + BaseModel: types.BaseModel{ + ID: "gpt-4:svc1", + Object: "model", + OwnedBy: "testuser", + }, + }, + { + BaseModel: types.BaseModel{ + ID: "gpt-3.5-turbo:svc2", + Object: "model", + OwnedBy: "testuser", + }, + }, + } + + // Set up mock expectation + tester.mocks.openAIComp.EXPECT().GetAvailableModels(mock.Anything, "testuser").Return(models, nil) + + // Create gin router + gin.SetMode(gin.TestMode) + router := gin.New() + + // Add middleware to set current user (similar to how it's done in the actual router) + router.Use(func(c *gin.Context) { + httpbase.SetCurrentUser(c, "testuser") + httpbase.SetCurrentUserUUID(c, "testuuid") + c.Next() + }) + + // Set up the route + router.GET("/v1/models", tester.handler.ListModels) + + // Start test server + server := httptest.NewServer(router) + defer server.Close() + + // Create OpenAI 
client with the test server URL + client := openai.NewClient(option.WithAPIKey("test-api-key"), option.WithBaseURL(server.URL+"/v1")) + + // Call the ListModels endpoint using OpenAI SDK + ctx := context.Background() + modelList, err := client.Models.List(ctx) + + // Assertions + assert.NoError(t, err) + assert.NotNil(t, modelList) + assert.Equal(t, "list", modelList.Object) + assert.Len(t, modelList.Data, 2) + // Verify model IDs + modelIDs := make([]string, len(modelList.Data)) + for i, model := range modelList.Data { + modelIDs[i] = model.ID + } + assert.Contains(t, modelIDs, "gpt-4:svc1") + assert.Contains(t, modelIDs, "gpt-3.5-turbo:svc2") + + // get next page + nextPage, err := modelList.GetNextPage() + assert.NoError(t, err) + assert.Nil(t, nextPage) } func TestOpenAIHandler_GetModel(t *testing.T) { diff --git a/aigateway/types/openai.go b/aigateway/types/openai.go index f18ad3323..47cf856ba 100644 --- a/aigateway/types/openai.go +++ b/aigateway/types/openai.go @@ -144,4 +144,9 @@ func (m *Model) ForExternalResponse() *Model { type ModelList struct { Object string `json:"object"` Data []Model `json:"data"` + // Pagination metadata + FirstID *string `json:"first_id,omitempty"` + LastID *string `json:"last_id,omitempty"` + HasMore bool `json:"has_more"` + TotalCount int `json:"total_count"` }