diff --git a/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go index e801bcf69..194d218cf 100644 --- a/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go +++ b/_mocks/opencsg.com/csghub-server/aigateway/token/mock_CounterFactory.go @@ -20,9 +20,9 @@ func (_m *MockCounterFactory) EXPECT() *MockCounterFactory_Expecter { return &MockCounterFactory_Expecter{mock: &_m.Mock} } -// NewChat provides a mock function with given fields: config -func (_m *MockCounterFactory) NewChat(config token.CreateParam) token.ChatTokenCounter { - ret := _m.Called(config) +// NewChat provides a mock function with given fields: param +func (_m *MockCounterFactory) NewChat(param token.CreateParam) token.ChatTokenCounter { + ret := _m.Called(param) if len(ret) == 0 { panic("no return value specified for NewChat") @@ -30,7 +30,7 @@ func (_m *MockCounterFactory) NewChat(config token.CreateParam) token.ChatTokenC var r0 token.ChatTokenCounter if rf, ok := ret.Get(0).(func(token.CreateParam) token.ChatTokenCounter); ok { - r0 = rf(config) + r0 = rf(param) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(token.ChatTokenCounter) @@ -46,12 +46,12 @@ type MockCounterFactory_NewChat_Call struct { } // NewChat is a helper method to define mock.On call -// - config token.Config -func (_e *MockCounterFactory_Expecter) NewChat(config interface{}) *MockCounterFactory_NewChat_Call { - return &MockCounterFactory_NewChat_Call{Call: _e.mock.On("NewChat", config)} +// - param token.CreateParam +func (_e *MockCounterFactory_Expecter) NewChat(param interface{}) *MockCounterFactory_NewChat_Call { + return &MockCounterFactory_NewChat_Call{Call: _e.mock.On("NewChat", param)} } -func (_c *MockCounterFactory_NewChat_Call) Run(run func(config token.CreateParam)) *MockCounterFactory_NewChat_Call { +func (_c *MockCounterFactory_NewChat_Call) Run(run func(param token.CreateParam)) *MockCounterFactory_NewChat_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(token.CreateParam)) }) @@ -68,9 +68,9 @@ func (_c *MockCounterFactory_NewChat_Call) RunAndReturn(run func(token.CreatePar return _c } -// NewEmbedding provides a mock function with given fields: config -func (_m *MockCounterFactory) NewEmbedding(config token.CreateParam) *token.EmbeddingTokenCounter { - ret := _m.Called(config) +// NewEmbedding provides a mock function with given fields: param +func (_m *MockCounterFactory) NewEmbedding(param token.CreateParam) *token.EmbeddingTokenCounter { + ret := _m.Called(param) if len(ret) == 0 { panic("no return value specified for NewEmbedding") @@ -78,7 +78,7 @@ func (_m *MockCounterFactory) NewEmbedding(config token.CreateParam) *token.Embe var r0 *token.EmbeddingTokenCounter if rf, ok := ret.Get(0).(func(token.CreateParam) *token.EmbeddingTokenCounter); ok { - r0 = rf(config) + r0 = rf(param) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(*token.EmbeddingTokenCounter) @@ -94,12 +94,12 @@ type MockCounterFactory_NewEmbedding_Call struct { } // NewEmbedding is a helper method to define mock.On call -// - config token.Config -func (_e *MockCounterFactory_Expecter) NewEmbedding(config interface{}) *MockCounterFactory_NewEmbedding_Call { - return &MockCounterFactory_NewEmbedding_Call{Call: _e.mock.On("NewEmbedding", config)} +// - param token.CreateParam +func (_e *MockCounterFactory_Expecter) NewEmbedding(param interface{}) *MockCounterFactory_NewEmbedding_Call { + return &MockCounterFactory_NewEmbedding_Call{Call: _e.mock.On("NewEmbedding", param)} } -func (_c *MockCounterFactory_NewEmbedding_Call) Run(run func(config token.CreateParam)) *MockCounterFactory_NewEmbedding_Call { +func (_c *MockCounterFactory_NewEmbedding_Call) Run(run func(param token.CreateParam)) *MockCounterFactory_NewEmbedding_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(token.CreateParam)) }) diff --git a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepoStore.go b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepoStore.go index 632594249..1735164b7 100644 --- a/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepoStore.go +++ b/_mocks/opencsg.com/csghub-server/builder/store/database/mock_RepoStore.go @@ -1776,6 +1776,75 @@ func (_c *MockRepoStore_GetRepoWithoutRuntimeByID_Call) RunAndReturn(run func(co return _c } +// GetReposBySearch provides a mock function with given fields: ctx, search, repoType, page, pageSize +func (_m *MockRepoStore) GetReposBySearch(ctx context.Context, search string, repoType types.RepositoryType, page int, pageSize int) ([]*database.Repository, int, error) { + ret := _m.Called(ctx, search, repoType, page, pageSize) + + if len(ret) == 0 { + panic("no return value specified for GetReposBySearch") + } + + var r0 []*database.Repository + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, types.RepositoryType, int, int) ([]*database.Repository, int, error)); ok { + return rf(ctx, search, repoType, page, pageSize) + } + if rf, ok := ret.Get(0).(func(context.Context, string, types.RepositoryType, int, int) []*database.Repository); ok { + r0 = rf(ctx, search, repoType, page, pageSize) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*database.Repository) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, types.RepositoryType, int, int) int); ok { + r1 = rf(ctx, search, repoType, page, pageSize) + } else { + r1 = ret.Get(1).(int) + } + + if rf, ok := ret.Get(2).(func(context.Context, string, types.RepositoryType, int, int) error); ok { + r2 = rf(ctx, search, repoType, page, pageSize) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MockRepoStore_GetReposBySearch_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetReposBySearch' +type MockRepoStore_GetReposBySearch_Call struct { + *mock.Call +} + +// GetReposBySearch is a helper method to define mock.On call +// - ctx context.Context +// - search string +// - repoType types.RepositoryType +// - page int +// - pageSize int +func (_e *MockRepoStore_Expecter) GetReposBySearch(ctx interface{}, search interface{}, repoType interface{}, page interface{}, pageSize interface{}) *MockRepoStore_GetReposBySearch_Call { + return &MockRepoStore_GetReposBySearch_Call{Call: _e.mock.On("GetReposBySearch", ctx, search, repoType, page, pageSize)} +} + +func (_c *MockRepoStore_GetReposBySearch_Call) Run(run func(ctx context.Context, search string, repoType types.RepositoryType, page int, pageSize int)) *MockRepoStore_GetReposBySearch_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(types.RepositoryType), args[3].(int), args[4].(int)) + }) + return _c +} + +func (_c *MockRepoStore_GetReposBySearch_Call) Return(_a0 []*database.Repository, _a1 int, _a2 error) *MockRepoStore_GetReposBySearch_Call { + _c.Call.Return(_a0, _a1, _a2) + return _c +} + +func (_c *MockRepoStore_GetReposBySearch_Call) RunAndReturn(run func(context.Context, string, types.RepositoryType, int, int) ([]*database.Repository, int, error)) *MockRepoStore_GetReposBySearch_Call { + _c.Call.Return(run) + return _c +} + // IsMirrorRepo provides a mock function with given fields: ctx, repoType, namespace, name func (_m *MockRepoStore) IsMirrorRepo(ctx context.Context, repoType types.RepositoryType, namespace string, name string) (bool, error) { ret := _m.Called(ctx, repoType, namespace, name) diff --git a/_mocks/opencsg.com/csghub-server/component/mock_RepoComponent.go b/_mocks/opencsg.com/csghub-server/component/mock_RepoComponent.go index 43c08d473..2651044aa 100644 --- a/_mocks/opencsg.com/csghub-server/component/mock_RepoComponent.go +++ b/_mocks/opencsg.com/csghub-server/component/mock_RepoComponent.go @@ -2500,6 +2500,67 @@ func (_c *MockRepoComponent_GetNameSpaceInfo_Call) RunAndReturn(run func(context return _c } +// GetRepos provides a mock function with given fields: ctx, search, currentUser, repoType +func (_m *MockRepoComponent) GetRepos(ctx context.Context, search string, currentUser string, repoType types.RepositoryType) ([]string, error) { + ret := _m.Called(ctx, search, currentUser, repoType) + + if len(ret) == 0 { + panic("no return value specified for GetRepos") + } + + var r0 []string + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, types.RepositoryType) ([]string, error)); ok { + return rf(ctx, search, currentUser, repoType) + } + if rf, ok := ret.Get(0).(func(context.Context, string, string, types.RepositoryType) []string); ok { + r0 = rf(ctx, search, currentUser, repoType) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, string, types.RepositoryType) error); ok { + r1 = rf(ctx, search, currentUser, repoType) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MockRepoComponent_GetRepos_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetRepos' +type MockRepoComponent_GetRepos_Call struct { + *mock.Call +} + +// GetRepos is a helper method to define mock.On call +// - ctx context.Context +// - search string +// - currentUser string +// - repoType types.RepositoryType +func (_e *MockRepoComponent_Expecter) GetRepos(ctx interface{}, search interface{}, currentUser interface{}, repoType interface{}) *MockRepoComponent_GetRepos_Call { + return &MockRepoComponent_GetRepos_Call{Call: _e.mock.On("GetRepos", ctx, search, currentUser, repoType)} +} + +func (_c *MockRepoComponent_GetRepos_Call) Run(run func(ctx context.Context, search string, currentUser string, repoType types.RepositoryType)) *MockRepoComponent_GetRepos_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(string), args[3].(types.RepositoryType)) + }) + return _c +} + +func (_c *MockRepoComponent_GetRepos_Call) Return(_a0 []string, _a1 error) *MockRepoComponent_GetRepos_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *MockRepoComponent_GetRepos_Call) RunAndReturn(run func(context.Context, string, string, types.RepositoryType) ([]string, error)) *MockRepoComponent_GetRepos_Call { + _c.Call.Return(run) + return _c +} + // GetUserRepoPermission provides a mock function with given fields: ctx, userName, repo func (_m *MockRepoComponent) GetUserRepoPermission(ctx context.Context, userName string, repo *database.Repository) (*types.UserRepoPermission, error) { ret := _m.Called(ctx, userName, repo) diff --git a/api/handler/repo.go b/api/handler/repo.go index 6469d22bf..7ceedebe8 100644 --- a/api/handler/repo.go +++ b/api/handler/repo.go @@ -3015,3 +3015,40 @@ func (h *RepoHandler) ChangePath(ctx *gin.Context) { } httpbase.OK(ctx, nil) } + +// GetRepos godoc +// @Security ApiKey +// @Summary Get repo paths with search query +// @Tags Repository +// @Accept json +// @Produce json +// @Param current_user query string false "current user name" +// @Param search query string true "search query" +// @Param type query string true "repository type query" enums(model, dataset, code, space, mcpserver) +// @Success 200 {object} types.Response{data=[]string} "OK" +// @Failure 400 {object} types.APIBadRequest "Bad request" +// @Failure 500 {object} types.APIInternalServerError "Internal server error" +// @Router /repos [get] +func (h *RepoHandler) GetRepos(ctx *gin.Context) { + currentUser := httpbase.GetCurrentUser(ctx) + search := ctx.Query("search") + repositoryType := ctx.Query("type") + repoType := types.RepositoryType(repositoryType) + if repoType == types.UnknownRepo { + httpbase.BadRequest(ctx, "Unknown repository type") + return + } + + repos, err := h.c.GetRepos(ctx.Request.Context(), search, currentUser, repoType) + if err != nil { + slog.Error( + "Failed to get repos", + slog.Any("error", err)) + httpbase.ServerError(ctx, err) + return + } + slog.Debug( + "Get repos succeed", + slog.String("search", search)) + httpbase.OK(ctx, repos) +} diff --git a/api/handler/repo_test.go b/api/handler/repo_test.go index 82bd85ffd..2c5de3b66 100644 --- a/api/handler/repo_test.go +++ b/api/handler/repo_test.go @@ -1528,3 +1528,15 @@ func TestRepoHandler_CommitFiles(t *testing.T) { t, 200, tester.OKText, nil, ) } + +func TestRepoHandler_GetRepos(t *testing.T) { + tester := NewRepoTester(t).WithHandleFunc(func(rp *RepoHandler) gin.HandlerFunc { + return rp.GetRepos + }) + tester.mocks.repo.EXPECT().GetRepos(mock.Anything, "search", "u", types.ModelRepo).Return([]string{}, nil).Once() + tester.WithQuery("type", "model").WithQuery("search", "search").WithUser().Execute() + + tester.ResponseEq( + t, 200, tester.OKText, []string{}, + ) +} diff --git a/api/router/api.go b/api/router/api.go index 92f448bfd..809facc1d 100644 --- a/api/router/api.go +++ b/api/router/api.go @@ -188,6 +188,9 @@ func NewRouter(config *config.Config, enableSwagger bool) (*gin.Engine, error) { versionHandler := handler.NewVersionHandler() apiGroup.GET("/version", versionHandler.Version) + // Admin user get repo path list + apiGroup.GET("/repos", middlewareCollection.Auth.NeedAdmin, repoCommonHandler.GetRepos) + // TODO:use middleware to handle common response // memoryStore := persist.NewMemoryStore(1 * time.Minute) diff --git a/builder/store/database/repository.go b/builder/store/database/repository.go index 4cfff0b07..c19394834 100644 --- a/builder/store/database/repository.go +++ b/builder/store/database/repository.go @@ -95,6 +95,7 @@ type RepoStore interface { FindByRepoTypeAndPaths(ctx context.Context, repoType types.RepositoryType, path []string) ([]Repository, error) FindUnhashedRepos(ctx context.Context, batchSize int, lastID int64) ([]Repository, error) UpdateRepoSensitiveCheckStatus(ctx context.Context, repoID int64, status types.SensitiveCheckStatus) error + GetReposBySearch(ctx context.Context, search string, repoType types.RepositoryType, page, pageSize int) ([]*Repository, int, error) } func (s *repoStoreImpl) UpdateRepoSensitiveCheckStatus(ctx context.Context, repoID int64, status types.SensitiveCheckStatus) error { @@ -1395,3 +1396,18 @@ func (s *repoStoreImpl) FindUnhashedRepos(ctx context.Context, batchSize int, la Scan(ctx) return res, err } + +func (s *repoStoreImpl) GetReposBySearch(ctx context.Context, search string, repoType types.RepositoryType, page, pageSize int) ([]*Repository, int, error) { + var ( + res []*Repository + count int + err error + ) + count, err = s.db.Operator.Core.NewSelect(). + Model(&res). + Where("path like ? and repository_type = ?", fmt.Sprintf("%%%s%%", search), repoType). + Offset((page - 1) * pageSize). + Limit(pageSize). + ScanAndCount(ctx) + return res, count, err +} diff --git a/builder/store/database/repository_test.go b/builder/store/database/repository_test.go index ea78495ff..a484c439d 100644 --- a/builder/store/database/repository_test.go +++ b/builder/store/database/repository_test.go @@ -1821,3 +1821,33 @@ func TestRepoStore_UpdateRepoSensitiveCheckStatus(t *testing.T) { require.Nil(t, err) require.Equal(t, types.SensitiveCheckPass, rp.SensitiveCheckStatus) } + +func TestRepoStore_GetReposBySearch(t *testing.T) { + db := tests.InitTestDB() + defer db.Close() + ctx := context.TODO() + + // insert a new repo + rs := database.NewRepoStoreWithDB(db) + for i := 0; i < 10; i++ { + _, err := db.Operator.Core.NewInsert().Model(&database.Repository{ + UserID: 1, + Path: fmt.Sprintf("%s/%d", "path", i), + GitPath: fmt.Sprintf("datasets_%s/%d", "path", i), + Name: fmt.Sprintf("name_%d", i), + DefaultBranch: "main", + Nickname: "ww", + Description: "ww", + Private: false, + RepositoryType: types.DatasetRepo, + Source: "opencsg", + Hashed: false, + }).Exec(ctx) + require.Nil(t, err) + } + + repos, total, err := rs.GetReposBySearch(ctx, "path/1", types.DatasetRepo, 1, 10) + require.Nil(t, err) + require.NotNil(t, repos) + require.Equal(t, 1, total) +} diff --git a/component/repo.go b/component/repo.go index f3f4e4b82..e71162720 100644 --- a/component/repo.go +++ b/component/repo.go @@ -190,6 +190,7 @@ type RepoComponent interface { BatchMigrateRepoToHashedPath(ctx context.Context, auto bool, batchSize int, lastID int64) (int64, error) GetMirrorTaskStatusAndSyncStatus(repo *database.Repository) (types.MirrorTaskStatus, types.RepositorySyncStatus) CheckDeployPermissionForUser(ctx context.Context, deployReq types.DeployActReq) (*database.User, *database.Deploy, error) + GetRepos(ctx context.Context, search, currentUser string, repoType types.RepositoryType) ([]string, error) IsXnetEnabled(ctx context.Context, repoType types.RepositoryType, namespace, name, username string) (*types.XetEnabled, error) } @@ -4073,3 +4074,15 @@ func (c *repoComponentImpl) GetMirrorTaskStatusAndSyncStatus(repo *database.Repo func (c *repoComponentImpl) RandomPath() []string { return strings.SplitN(uuid.NewString(), "-", 2) } + +func (c *repoComponentImpl) GetRepos(ctx context.Context, search, currentUser string, repoType types.RepositoryType) ([]string, error) { + var repoPaths []string + repos, _, err := c.repoStore.GetReposBySearch(ctx, search, repoType, 1, 10) + if err != nil { + return repoPaths, fmt.Errorf("failed to get repos, error: %w", err) + } + for _, repo := range repos { + repoPaths = append(repoPaths, repo.Path) + } + return repoPaths, nil +} diff --git a/component/repo_test.go b/component/repo_test.go index 9d173b593..e928b42e5 100644 --- a/component/repo_test.go +++ b/component/repo_test.go @@ -3244,3 +3244,16 @@ func TestRepoComponent_SendAssetManagementMsg(t *testing.T) { require.Nil(t, err) wg.Wait() } + +func TestRepoComponent_GetRepos(t *testing.T) { + ctx := context.Background() + + repoComp := initializeTestRepoComponent(ctx, t) + + repoComp.mocks.stores.RepoMock().EXPECT().GetReposBySearch(ctx, "search", types.ModelRepo, 1, 10). + Return([]*database.Repository{{ID: 1, Path: "ns/name"}}, 1, nil).Once() + paths, err := repoComp.GetRepos(ctx, "search", "u", types.ModelRepo) + require.NoError(t, err) + require.Equal(t, 1, len(paths)) + require.Equal(t, "ns/name", paths[0]) +}