diff --git a/Makefile b/Makefile
index 51c0b548..d879bf19 100644
--- a/Makefile
+++ b/Makefile
@@ -31,3 +31,18 @@ swag:
 migrate_local:
 	go run cmd/csghub-server/main.go migration migrate --config local.toml
 
+
+db_migrate:
+	@go run -tags "$(GO_TAGS)" cmd/csghub-server/main.go migration migrate --config local.toml
+
+db_rollback:
+	@go run -tags "$(GO_TAGS)" cmd/csghub-server/main.go migration rollback --config local.toml
+
+start_server:
+	@go run -tags "$(GO_TAGS)" cmd/csghub-server/main.go start server -l Info -f json --config local.toml
+
+start_user:
+	@go run -tags "$(GO_TAGS)" cmd/csghub-server/main.go user launch -l Info -f json --config local.toml
+
+error_doc:
+	@go run cmd/csghub-server/main.go errorx doc-gen
diff --git a/api/middleware/cache.go b/api/middleware/cache.go
new file mode 100644
index 00000000..1c1f294b
--- /dev/null
+++ b/api/middleware/cache.go
@@ -0,0 +1,35 @@
+package middleware
+
+import (
+	"time"
+
+	cache "github.com/chenyahui/gin-cache"
+	"github.com/gin-gonic/gin"
+	"opencsg.com/csghub-server/api/httpbase"
+)
+
+func CacheStrategyTrendingRepos() cache.Option {
+	return cache.WithCacheStrategyByRequest(getCacheStrategyTrendingReposByRequest)
+}
+
+func getCacheStrategyTrendingReposByRequest(c *gin.Context) (bool, cache.Strategy) {
+	// only cache anonymous users access the trending repositories
+	if httpbase.GetCurrentUser(c) != "" {
+		return false, cache.Strategy{}
+	}
+
+	sort := c.Query("sort")
+	if sort != "trending" {
+		return false, cache.Strategy{}
+	}
+
+	search := c.Query("search")
+	if search != "" {
+		return false, cache.Strategy{}
+	}
+
+	return true, cache.Strategy{
+		CacheKey:      c.Request.RequestURI,
+		CacheDuration: 2 * time.Minute,
+	}
+}
diff --git a/api/middleware/cache_test.go b/api/middleware/cache_test.go
new file mode 100644
index 00000000..34de8aa0
--- /dev/null
+++ b/api/middleware/cache_test.go
@@ -0,0 +1,203 @@
+package middleware
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+
+	cache "github.com/chenyahui/gin-cache"
+	"github.com/chenyahui/gin-cache/persist"
+	"github.com/gin-gonic/gin"
+	"github.com/stretchr/testify/assert"
+	"opencsg.com/csghub-server/api/httpbase"
+)
+
+func TestCacheStrategyTrendingRepos(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	tests := []struct {
+		name          string
+		user          string
+		sortParam     string
+		searchParam   string
+		expectedCache bool
+	}{
+		{
+			name:          "Anonymous user, trending sort, no search",
+			user:          "",
+			sortParam:     "trending",
+			searchParam:   "",
+			expectedCache: true,
+		},
+		{
+			name:          "Logged-in user, trending sort, no search",
+			user:          "testuser",
+			sortParam:     "trending",
+			searchParam:   "",
+			expectedCache: false,
+		},
+		{
+			name:          "Anonymous user, non-trending sort, no search",
+			user:          "",
+			sortParam:     "popular",
+			searchParam:   "",
+			expectedCache: false,
+		},
+		{
+			name:          "Anonymous user, trending sort, with search",
+			user:          "",
+			sortParam:     "trending",
+			searchParam:   "testquery",
+			expectedCache: false,
+		},
+		{
+			name:          "Anonymous user, no sort, no search",
+			user:          "",
+			sortParam:     "",
+			searchParam:   "",
+			expectedCache: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			// Create a memory store for testing
+			store := persist.NewMemoryStore(time.Minute)
+
+			// Create a test router with cache middleware
+			router := gin.New()
+
+			// Counter to track how many times the handler is called
+			callCount := 0
+			testHandler := gin.HandlerFunc(func(c *gin.Context) {
+				callCount++
+				c.JSON(200, gin.H{"message": "test", "call": callCount})
+			})
+
+			// Apply cache middleware with our strategy
+			cacheMiddleware := cache.Cache(store, 2*time.Minute, CacheStrategyTrendingRepos())
+			router.GET("/test", func(c *gin.Context) {
+				// Set user context if provided
+				if tt.user != "" {
+					c.Set(httpbase.CurrentUserCtxVar, tt.user)
+				}
+				// Call next to continue to cache middleware
+				c.Next()
+			}, cacheMiddleware, testHandler)
+
+			// Build the request URL
+			url := "/test"
+			if tt.sortParam != "" || tt.searchParam != "" {
+				url += "?sort=" + tt.sortParam + "&search=" + tt.searchParam
+			}
+
+			// First request
+			req1, _ := http.NewRequest(http.MethodGet, url, nil)
+			w1 := httptest.NewRecorder()
+			router.ServeHTTP(w1, req1)
+
+			assert.Equal(t, 200, w1.Code)
+			initialCallCount := callCount
+
+			if tt.expectedCache {
+				// For cached responses, make a second identical request
+				// The handler should NOT be called again if caching is working
+				req2, _ := http.NewRequest(http.MethodGet, url, nil)
+				w2 := httptest.NewRecorder()
+				router.ServeHTTP(w2, req2)
+
+				assert.Equal(t, 200, w2.Code)
+				// Call count should remain the same for cached response
+				assert.Equal(t, initialCallCount, callCount, "Handler should not be called again for cached response")
+			} else {
+				// For non-cached responses, make a second identical request
+				// The handler SHOULD be called again
+				req2, _ := http.NewRequest(http.MethodGet, url, nil)
+				w2 := httptest.NewRecorder()
+				router.ServeHTTP(w2, req2)
+
+				assert.Equal(t, 200, w2.Code)
+				// Call count should increase for non-cached response
+				assert.Greater(t, callCount, initialCallCount, "Handler should be called again for non-cached response")
+			}
+		})
+	}
+}
+
+// Test the cache strategy function directly by replicating its logic
+func TestCacheStrategyLogic(t *testing.T) {
+	gin.SetMode(gin.TestMode)
+
+	tests := []struct {
+		name          string
+		user          string
+		sortParam     string
+		searchParam   string
+		expectedCache bool
+		expectedKey   string
+	}{
+		{
+			name:          "Anonymous user, trending sort, no search",
+			user:          "",
+			sortParam:     "trending",
+			searchParam:   "",
+			expectedCache: true,
+			expectedKey:   "/test?sort=trending&search=",
+		},
+		{
+			name:          "Logged-in user, trending sort, no search",
+			user:          "testuser",
+			sortParam:     "trending",
+			searchParam:   "",
+			expectedCache: false,
+			expectedKey:   "",
+		},
+		{
+			name:          "Anonymous user, non-trending sort, no search",
+			user:          "",
+			sortParam:     "popular",
+			searchParam:   "",
+			expectedCache: false,
+			expectedKey:   "",
+		},
+		{
+			name:          "Anonymous user, trending sort, with search",
+			user:          "",
+			sortParam:     "trending",
+			searchParam:   "testquery",
+			expectedCache: false,
+			expectedKey:   "",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			w := httptest.NewRecorder()
+			c, _ := gin.CreateTestContext(w)
+
+			// Simulate request with query parameters
+			url := "/test?sort=" + tt.sortParam + "&search=" + tt.searchParam
+			req, _ := http.NewRequest(http.MethodGet, url, nil)
+			req.RequestURI = url // Manually set RequestURI since http.NewRequest doesn't set it
+			c.Request = req
+
+			// Set the current user in context if provided
+			if tt.user != "" {
+				c.Set(httpbase.CurrentUserCtxVar, tt.user)
+			}
+
+			// Test the cache strategy logic directly
+			shouldCache, strategy := getCacheStrategyTrendingReposByRequest(c)
+
+			assert.Equal(t, tt.expectedCache, shouldCache)
+			if tt.expectedCache {
+				assert.Equal(t, tt.expectedKey, strategy.CacheKey)
+				assert.Equal(t, 2*time.Minute, strategy.CacheDuration)
+			} else {
+				assert.Empty(t, strategy.CacheKey)
+				assert.Zero(t, strategy.CacheDuration)
+			}
+		})
+	}
+}
diff --git a/api/router/api.go b/api/router/api.go
index 419fc7fe..ef27cc8a 100644
--- a/api/router/api.go
+++ b/api/router/api.go
@@ -496,12 +496,14 @@ func createModelRoutes(config *config.Config,
 	modelHandler *handler.ModelHandler,
 	repoCommonHandler *handler.RepoHandler,
 	monitorHandler *handler.MonitorHandler) {
+	// gin cache
+	memoryStore := persist.NewMemoryStore(2 * time.Minute)
 	// Models routes
 	modelsGroup := apiGroup.Group("/models")
 	modelsGroup.Use(middleware.RepoType(types.ModelRepo), middlewareCollection.Repo.RepoExists)
 	{
 		modelsGroup.POST("", middlewareCollection.Auth.NeedLogin, modelHandler.Create)
-		modelsGroup.GET("", modelHandler.Index)
+		modelsGroup.GET("", cache.Cache(memoryStore, time.Minute, middleware.CacheStrategyTrendingRepos()), modelHandler.Index)
 		modelsGroup.PUT("/:namespace/:name", middlewareCollection.Auth.NeedLogin, modelHandler.Update)
 		modelsGroup.DELETE("/:namespace/:name", middlewareCollection.Auth.NeedLogin, modelHandler.Delete)
 		modelsGroup.GET("/:namespace/:name", modelHandler.Show)
diff --git a/builder/store/database/benchmark_recom_indexes_test.go b/builder/store/database/benchmark_recom_indexes_test.go
new file mode 100644
index 00000000..dc12f651
--- /dev/null
+++ b/builder/store/database/benchmark_recom_indexes_test.go
@@ -0,0 +1,336 @@
+package database_test
+
+import (
+	"context"
+	"fmt"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/require"
+	"opencsg.com/csghub-server/builder/store/database"
+	"opencsg.com/csghub-server/common/tests"
+	"opencsg.com/csghub-server/common/types"
+)
+
+const (
+	TestDataSize       = 500000 // 500,000 records (sufficient to test index performance)
+	BenchmarkRepoCount = 10000  // Simulate 10,000 repositories
+)
+
+// BenchmarkRecomIndexes_CoveringIndex tests covering index performance
+func BenchmarkRecomIndexes_CoveringIndex(b *testing.B) {
+	db := tests.InitTransactionTestDB()
+	defer db.Close()
+
+	insertTestData(b, db)
+	createCoveringIndex(b, db)
+
+	b.ResetTimer()
+	b.Run("SortQuery", func(b *testing.B) {
+		benchmarkSortQuery(b, db)
+	})
+
+	b.Run("CombinedQuery", func(b *testing.B) {
+		benchmarkCombinedQuery(b, db)
+	})
+
+	b.Run("CombinedQuery_Where", func(b *testing.B) {
+		benchmarkCombinedQuery_Where(b, db)
+	})
+}
+
+// BenchmarkRecomIndexes_SpecializedIndex tests specialized index performance
+func BenchmarkRecomIndexes_SpecializedIndex(b *testing.B) {
+	db := tests.InitTransactionTestDB()
+	defer db.Close()
+
+	insertTestData(b, db)
+	createSpecializedIndex(b, db)
+
+	b.ResetTimer()
+	b.Run("SortQuery", func(b *testing.B) {
+		benchmarkSortQuery(b, db)
+	})
+
+	b.Run("CombinedQuery", func(b *testing.B) {
+		benchmarkCombinedQuery(b, db)
+	})
+
+	b.Run("CombinedQuery_Where", func(b *testing.B) {
+		benchmarkCombinedQuery_Where(b, db)
+	})
+}
+
+// BenchmarkRecomIndexes_NoIndex tests performance without an index (as a baseline comparison)
+func BenchmarkRecomIndexes_NoIndex(b *testing.B) {
+	db := tests.InitTransactionTestDB()
+	defer db.Close()
+
+	insertTestData(b, db)
+	dropAllIndexes(b, db)
+
+	b.ResetTimer()
+
+	b.Run("SortQuery", func(b *testing.B) {
+		benchmarkSortQuery(b, db)
+	})
+
+	b.Run("CombinedQuery", func(b *testing.B) {
+		benchmarkCombinedQuery(b, db)
+	})
+
+	b.Run("CombinedQuery_Where", func(b *testing.B) {
+		benchmarkCombinedQuery_Where(b, db)
+	})
+}
+
+// insertTestData quickly inserts test data
+func insertTestData(b *testing.B, db *database.DB) {
+	ctx := context.Background()
+	start := time.Now()
+
+	// First, insert test users
+	insertTestUsers(b, db, ctx)
+
+	// Then, insert test repositories
+	insertTestRepositories(b, db, ctx)
+
+	// Finally, insert recommendation scores
+	insertTestRecomScores(b, db, ctx)
+
+	b.Logf("Test data insertion complete, total time: %v", time.Since(start))
+}
+
+// insertTestUsers inserts test user data
+func insertTestUsers(b *testing.B, db *database.DB, ctx context.Context) {
+	start := time.Now()
+
+	var userBatch []*database.User
+	for userID := int64(1); userID <= int64(BenchmarkRepoCount/100); userID++ { // Create fewer users than repos
+		userBatch = append(userBatch, &database.User{
+			ID:       userID,
+			GitID:    userID + 1000,
+			NickName: fmt.Sprintf("TestUser%d", userID),
+			Username: fmt.Sprintf("testuser%d", userID),
+			Email:    fmt.Sprintf("testuser%d@example.com", userID),
+			Password: "test_password",
+			UUID:     fmt.Sprintf("uuid-%d", userID),
+		})
+
+		if len(userBatch) >= 1000 {
+			_, err := db.Operator.Core.NewInsert().Model(&userBatch).Exec(ctx)
+			require.NoError(b, err)
+			userBatch = userBatch[:0]
+		}
+	}
+
+	if len(userBatch) > 0 {
+		_, err := db.Operator.Core.NewInsert().Model(&userBatch).Exec(ctx)
+		require.NoError(b, err)
+	}
+
+	b.Logf("User data insertion time: %v", time.Since(start))
+}
+
+// insertTestRepositories inserts test repository data
+func insertTestRepositories(b *testing.B, db *database.DB, ctx context.Context) {
+	start := time.Now()
+
+	repoTypes := []types.RepositoryType{
+		types.ModelRepo,
+		types.DatasetRepo,
+		types.SpaceRepo,
+		types.CodeRepo,
+		types.PromptRepo,
+	}
+
+	var repoBatch []*database.Repository
+	for repoID := int64(1); repoID <= int64(BenchmarkRepoCount); repoID++ {
+		userID := ((repoID - 1) % int64(BenchmarkRepoCount/100)) + 1 // Distribute repos among users
+		repoType := repoTypes[(repoID-1)%int64(len(repoTypes))]
+
+		repoBatch = append(repoBatch, &database.Repository{
+			ID:             repoID,
+			UserID:         userID,
+			Path:           fmt.Sprintf("testuser%d/testrepo%d", userID, repoID),
+			GitPath:        fmt.Sprintf("testuser%d/testrepo%d.git", userID, repoID),
+			Name:           fmt.Sprintf("testrepo%d", repoID),
+			Nickname:       fmt.Sprintf("Test Repository %d", repoID),
+			Description:    fmt.Sprintf("Test repository %d for benchmark testing", repoID),
+			Private:        rand.Intn(2) == 1, // Random private/public
+			DefaultBranch:  "main",
+			RepositoryType: repoType,
+			HTTPCloneURL:   fmt.Sprintf("https://example.com/testuser%d/testrepo%d.git", userID, repoID),
+			SSHCloneURL:    fmt.Sprintf("git@example.com:testuser%d/testrepo%d.git", userID, repoID),
+			Source:         types.LocalSource,
+			Likes:          rand.Int63n(1000),
+			DownloadCount:  rand.Int63n(10000),
+			StarCount:      rand.Intn(500),
+		})
+
+		if len(repoBatch) >= 1000 {
+			_, err := db.Operator.Core.NewInsert().Model(&repoBatch).Exec(ctx)
+			require.NoError(b, err)
+			repoBatch = repoBatch[:0]
+		}
+	}
+
+	if len(repoBatch) > 0 {
+		_, err := db.Operator.Core.NewInsert().Model(&repoBatch).Exec(ctx)
+		require.NoError(b, err)
+	}
+
+	b.Logf("Repository data insertion time: %v", time.Since(start))
+}
+
+// insertTestRecomScores inserts recommendation score data
+func insertTestRecomScores(b *testing.B, db *database.DB, ctx context.Context) {
+	start := time.Now()
+
+	weightNames := []database.RecomWeightName{
+		database.RecomWeightTotal,
+		database.RecomWeightFreshness,
+		database.RecomWeightDownloads,
+		database.RecomWeightQuality,
+		database.RecomWeightOp,
+	}
+
+	var batch []*database.RecomRepoScore
+	count := 0
+
+	for repoID := int64(1); repoID <= int64(BenchmarkRepoCount) && count < TestDataSize; repoID++ {
+		for _, weightName := range weightNames {
+			if count >= TestDataSize {
+				break
+			}
+
+			batch = append(batch, &database.RecomRepoScore{
+				RepositoryID: repoID,
+				WeightName:   weightName,
+				Score:        rand.Float64() * 100,
+			})
+			count++
+
+			if len(batch) >= 1000 {
+				_, err := db.Operator.Core.NewInsert().Model(&batch).Exec(ctx)
+				require.NoError(b, err)
+				batch = batch[:0]
+			}
+		}
+	}
+
+	if len(batch) > 0 {
+		_, err := db.Operator.Core.NewInsert().Model(&batch).Exec(ctx)
+		require.NoError(b, err)
+	}
+
+	b.Logf("Inserted %d recommendation score records, time: %v", count, time.Since(start))
+}
+
+func createCoveringIndex(b *testing.B, db *database.DB) {
+	dropAllIndexes(b, db)
+
+	ctx := b.Context()
+	start := time.Now()
+	_, err := db.Operator.Core.ExecContext(ctx, `
+		CREATE INDEX idx_recom_covering
+		ON recom_repo_scores (weight_name, repository_id, score DESC)`)
+	require.NoError(b, err)
+	b.Logf("Covering index creation time: %v", time.Since(start))
+}
+
+func createSpecializedIndex(b *testing.B, db *database.DB) {
+	dropAllIndexes(b, db)
+
+	ctx := b.Context()
+
+	start := time.Now()
+	_, err := db.Operator.Core.ExecContext(ctx, `
+		CREATE INDEX idx_recom_total_weight_score
+		ON recom_repo_scores (repository_id, score DESC)
+		WHERE weight_name = 'total'`)
+	require.NoError(b, err)
+	b.Logf("Specialized index creation time: %v", time.Since(start))
+}
+
+func dropAllIndexes(b *testing.B, db *database.DB) {
+	ctx := b.Context()
+	start := time.Now()
+
+	// Drop all custom indexes to ensure no index optimization
+	_, _ = db.Operator.Core.ExecContext(ctx, "DROP INDEX IF EXISTS idx_recom_covering")
+	_, _ = db.Operator.Core.ExecContext(ctx, "DROP INDEX IF EXISTS idx_recom_total_weight_score")
+
+	b.Logf("All indexes dropped, time: %v", time.Since(start))
+}
+
+func benchmarkSortQuery(b *testing.B, db *database.DB) {
+	b.ResetTimer()
+
+	ctx := context.Background()
+	for i := 0; i < b.N; i++ {
+		var results []database.RecomRepoScore
+		err := db.Operator.Core.NewSelect().
+			Model(&results).
+			Where("weight_name = ?", database.RecomWeightTotal).
+			Order("score DESC").
+			Limit(100).
+			Scan(ctx)
+		require.NoError(b, err)
+	}
+
+	b.ReportMetric(float64(b.Elapsed().Milliseconds()/int64(b.N)), "ms/sql_query")
+}
+
+func benchmarkCombinedQuery(b *testing.B, db *database.DB) {
+	b.ResetTimer()
+	ctx := context.Background()
+	for i := 0; i < b.N; i++ {
+		var results []struct {
+			RepositoryID int64   `bun:"repository_id"`
+			Popularity   float64 `bun:"popularity"`
+		}
+
+		query := `
+			SELECT repos.*, COALESCE(r.score, 0) AS popularity
+			FROM repositories repos
+			LEFT JOIN recom_repo_scores r ON repos.id = r.repository_id
+				AND r.weight_name = ?
+			ORDER BY popularity DESC NULLS LAST
+			LIMIT 50
+		`
+
+		err := db.Operator.Core.NewRaw(query, database.RecomWeightTotal).
+			Scan(ctx, &results)
+		require.NoError(b, err)
+	}
+
+	b.ReportMetric(float64(b.Elapsed().Milliseconds()/int64(b.N)), "ms/sql_query")
+}
+
+func benchmarkCombinedQuery_Where(b *testing.B, db *database.DB) {
+	b.ResetTimer()
+	ctx := context.Background()
+	for i := 0; i < b.N; i++ {
+		var results []struct {
+			RepositoryID int64   `bun:"repository_id"`
+			Popularity   float64 `bun:"popularity"`
+		}
+
+		query := `
+			SELECT repos.*, COALESCE(r.score, 0) AS popularity
+			FROM repositories repos
+			LEFT JOIN recom_repo_scores r ON repos.id = r.repository_id
+			WHERE r.weight_name = ?
+			ORDER BY popularity DESC NULLS LAST
+			LIMIT 50
+		`
+
+		err := db.Operator.Core.NewRaw(query, database.RecomWeightTotal).
+			Scan(ctx, &results)
+		require.NoError(b, err)
+	}
+
+	b.ReportMetric(float64(b.Elapsed().Milliseconds()/int64(b.N)), "ms/sql_query")
+}
diff --git a/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.down.sql b/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.down.sql
new file mode 100644
index 00000000..9add04ae
--- /dev/null
+++ b/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.down.sql
@@ -0,0 +1,7 @@
+SET statement_timeout = 0;
+
+--bun:split
+
+DROP INDEX IF EXISTS idx_recom_total_weight_score;
+
+--bun:split
diff --git a/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.up.sql b/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.up.sql
new file mode 100644
index 00000000..52b83c1c
--- /dev/null
+++ b/builder/store/database/migrations/20250821065021_add_idx_recom_total_weight_score.up.sql
@@ -0,0 +1,8 @@
+SET statement_timeout = 0;
+
+--bun:split
+CREATE INDEX IF NOT EXISTS idx_recom_total_weight_score
+ON recom_repo_scores (repository_id, score DESC)
+WHERE weight_name = 'total';
+
+--bun:split
diff --git a/builder/store/database/repository.go b/builder/store/database/repository.go
index 71b6530c..33d3c83d 100644
--- a/builder/store/database/repository.go
+++ b/builder/store/database/repository.go
@@ -619,9 +619,9 @@ func (s *repoStoreImpl) PublicToUser(ctx context.Context, repoType types.Reposit
 	}
 
 	if filter.Sort == "trending" {
-		q.Join("Left Join recom_repo_scores on repository.id = recom_repo_scores.repository_id AND recom_repo_scores.weight_name=?", RecomWeightTotal)
-		q.Join("Left Join recom_op_weights on repository.id = recom_op_weights.repository_id")
-		q.ColumnExpr(`COALESCE(recom_repo_scores.score, 0)+COALESCE(recom_op_weights.weight, 0) AS popularity`)
+		q.Join("LEFT JOIN recom_repo_scores ON repository.id = recom_repo_scores.repository_id")
+		q.Where("recom_repo_scores.weight_name = ?", RecomWeightTotal)
+		q.ColumnExpr(`COALESCE(recom_repo_scores.score, 0) AS popularity`)
 	}
 
 	err = q.Order(sortBy[filter.Sort]).
diff --git a/component/recom.go b/component/recom.go
index 77af9d49..d3d12f77 100644
--- a/component/recom.go
+++ b/component/recom.go
@@ -62,7 +62,9 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context, batchSize
 	}
 	lastRepoID := int64(0)
 	for {
-		repos, err := rc.repoStore.BatchGet(ctx, lastRepoID, batchSize, nil)
+		ctxTimeout, cancel := context.WithTimeout(ctx, 5*time.Second)
+		repos, err := rc.repoStore.BatchGet(ctxTimeout, lastRepoID, batchSize, nil)
+		cancel()
 		if err != nil {
 			return errors.New("error fetching repositories")
 		}
@@ -70,7 +72,9 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context, batchSize
 		for _, repo := range repos {
 			repoIDs = append(repoIDs, repo.ID)
 		}
-		scores, err := rc.recomStore.FindScoreByRepoIDs(ctx, repoIDs)
+		ctxTimeout, cancel = context.WithTimeout(ctx, 5*time.Second)
+		scores, err := rc.recomStore.FindByRepoIDs(ctxTimeout, repoIDs)
+		cancel()
 		if err != nil {
 			return errors.New("error fetching scores of repos")
 		}
@@ -94,7 +98,9 @@ func (rc *recomComponentImpl) CalculateRecomScore(ctx context.Context, batchSize
 			newScores = append(newScores, newRepoScores...)
 		}
 
-		err = rc.recomStore.UpsertScore(ctx, newScores)
+		ctxTimeout, cancel = context.WithTimeout(ctx, 5*time.Second)
+		err = rc.recomStore.UpsertScore(ctxTimeout, newScores)
+		cancel()
 		if err != nil {
 			slog.Error("failed to flush recom score", slog.Any("error", err), slog.Any("repo_ids", repoIDs))
 		} else {
diff --git a/component/recom_test.go b/component/recom_test.go
index 86dc87e9..6aa02f0d 100644
--- a/component/recom_test.go
+++ b/component/recom_test.go
@@ -42,11 +42,11 @@ func TestRecomComponent_CalculateRecomScore(t *testing.T) {
 	repo3 := database.Repository{ID: 3, Path: "foo/bar3"}
 	repo3.UpdatedAt = time.Now().Add(24 * time.Hour)
 	// loop 1
-	rc.mocks.stores.RepoMock().EXPECT().BatchGet(ctx, int64(0), batchSize, (*types.BatchGetFilter)(nil)).Return([]database.Repository{
+	rc.mocks.stores.RepoMock().EXPECT().BatchGet(mock.Anything, int64(0), batchSize, (*types.BatchGetFilter)(nil)).Return([]database.Repository{
 		repo1, repo2,
 	}, nil)
 	// loop 2
-	rc.mocks.stores.RepoMock().EXPECT().BatchGet(ctx, int64(2), batchSize, (*types.BatchGetFilter)(nil)).Return([]database.Repository{
+	rc.mocks.stores.RepoMock().EXPECT().BatchGet(mock.Anything, int64(2), batchSize, (*types.BatchGetFilter)(nil)).Return([]database.Repository{
 		repo3,
 	}, nil)
 
@@ -77,7 +77,7 @@ func TestRecomComponent_CalculateRecomScore(t *testing.T) {
 	).Return(nil, nil)
 
 	// rc.mocks.stores.RecomMock().EXPECT().UpsertScore(ctx, int64(2), 12.34).Return(nil)
-	rc.mocks.stores.RecomMock().EXPECT().UpsertScore(ctx, mock.Anything).RunAndReturn(
+	rc.mocks.stores.RecomMock().EXPECT().UpsertScore(mock.Anything, mock.Anything).RunAndReturn(
 		func(ctx context.Context, scores []*database.RecomRepoScore) error {
 			// scores to map by repo id
 			scoresMap := make(map[int64][]*database.RecomRepoScore)