diff --git a/cmd/gorse-benchmark/main.go b/cmd/gorse-benchmark/main.go index 938646e6b..a98ae9584 100644 --- a/cmd/gorse-benchmark/main.go +++ b/cmd/gorse-benchmark/main.go @@ -17,20 +17,31 @@ package main import ( "context" "fmt" - "log" + "math" "os" "runtime" "sort" + "strconv" + "sync" + "time" + mapset "github.com/deckarep/golang-set/v2" + "github.com/gorse-io/gorse/common/log" + "github.com/gorse-io/gorse/common/parallel" "github.com/gorse-io/gorse/config" "github.com/gorse-io/gorse/dataset" + "github.com/gorse-io/gorse/logics" "github.com/gorse-io/gorse/master" + "github.com/gorse-io/gorse/model/cf" "github.com/gorse-io/gorse/model/ctr" "github.com/gorse-io/gorse/storage" "github.com/gorse-io/gorse/storage/data" + "github.com/olekukonko/tablewriter" "github.com/samber/lo" + "github.com/sashabaranov/go-openai" "github.com/spf13/cobra" - "modernc.org/sortutil" + "go.uber.org/atomic" + "go.uber.org/zap" ) var rootCmd = &cobra.Command{ @@ -46,19 +57,21 @@ var llmCmd = &cobra.Command{ configPath, _ := cmd.Flags().GetString("config") cfg, err := config.LoadConfig(configPath) if err != nil { - log.Fatalf("failed to load config: %v", err) + log.Logger().Fatal("failed to load config", zap.Error(err)) } + shots, _ := cmd.Flags().GetInt("shots") + // Load dataset m := master.NewMaster(cfg, os.TempDir(), false) m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix, storage.WithIsolationLevel(m.Config.Database.MySQL.IsolationLevel)) if err != nil { - log.Fatalf("failed to open data client: %v", err) + log.Logger().Fatal("failed to open data client", zap.Error(err)) } evaluator := master.NewOnlineEvaluator( m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes) - dataset, _, err := m.LoadDataFromDatabase(context.Background(), m.DataClient, + ctrDataset, dataset, err := m.LoadDataFromDatabase(context.Background(), m.DataClient, m.Config.Recommend.DataSource.PositiveFeedbackTypes, m.Config.Recommend.DataSource.ReadFeedbackTypes, m.Config.Recommend.DataSource.ItemTTL, @@ -66,111 +79,290 @@ var llmCmd = &cobra.Command{ evaluator, nil) if err != nil { - log.Fatalf("failed to load dataset: %v", err) + log.Logger().Fatal("failed to load dataset", zap.Error(err)) } - fmt.Println("Dataset loaded:") - fmt.Printf(" Users: %d\n", dataset.CountUsers()) - fmt.Printf(" Items: %d\n", dataset.CountItems()) - fmt.Printf(" Positive Feedbacks: %d\n", dataset.CountPositive()) - fmt.Printf(" Negative Feedbacks: %d\n", dataset.CountNegative()) + // Split dataset - train, test := dataset.Split(0.2, 42) - EvaluateFM(train, test) - // EvaluateLLM(cfg, train, test, aux.GetItems()) + var scores sync.Map + train, test := dataset.SplitLatest(shots) + test.SampleUserNegatives(dataset, 99) + + table := tablewriter.NewWriter(os.Stdout) + table.Header([]string{"", "#users", "#items", "#interactions"}) + lo.Must0(table.Bulk([][]string{ + {"total", strconv.Itoa(dataset.CountUsers()), strconv.Itoa(dataset.CountItems()), strconv.Itoa(dataset.CountFeedback())}, + {"train", strconv.Itoa(train.CountUsers()), strconv.Itoa(train.CountItems()), strconv.Itoa(train.CountFeedback())}, + {"test", strconv.Itoa(test.CountUsers()), strconv.Itoa(test.CountItems()), strconv.Itoa(test.CountFeedback())}, + })) + lo.Must0(table.Render()) + + go EvaluateCF(train, test, &scores) + go EvaluateAFM(ctrDataset, train, test, &scores) + EvaluateLLM(cfg, train, test, &scores) + data := [][]string{{"Model", "NDCG"}} + scores.Range(func(key, value any) bool { + score := 
value.(cf.Score) + data = append(data, []string{key.(string), fmt.Sprintf("%.4f", score.NDCG)}) + return true + }) + table = tablewriter.NewWriter(os.Stdout) + table.Header(data[0]) + lo.Must0(table.Bulk(data[1:])) + lo.Must0(table.Render()) }, } -func EvaluateFM(train, test dataset.CTRSplit) float32 { - fmt.Println("Training FM...") +func EvaluateCF(train, test dataset.CFSplit, scores *sync.Map) { + for name, model := range map[string]cf.MatrixFactorization{ + "ALS": cf.NewALS(nil), + "BPR": cf.NewBPR(nil), + } { + score := model.Fit(context.Background(), train, test, + cf.NewFitConfig(). + SetVerbose(10). + SetJobs(runtime.NumCPU()). + SetPatience(10)) + scores.Store(name, score) + } +} + +func EvaluateAFM(ctrDataset *ctr.Dataset, train, test dataset.CFSplit, scores *sync.Map) { + ctrTrain, ctrTest := SplitCTRDataset(ctrDataset, train, test) ml := ctr.NewAFM(nil) - ml.Fit(context.Background(), train, test, + ml.Fit(context.Background(), ctrTrain, ctrTest, ctr.NewFitConfig(). SetVerbose(10). SetJobs(runtime.NumCPU()). SetPatience(10)) - userTrain := make(map[int32]int, train.CountUsers()) - for i := 0; i < train.Count(); i++ { - indices, _, _, target := train.Get(i) - userId := indices[0] - if target > 0 { - userTrain[userId]++ + buildCTRInput := func(user, item int32) ([]int32, []float32, [][]float32) { + var ( + indices []int32 + values []float32 + embedding [][]float32 + position int32 + ) + if ctrDataset.CountUsers() > 0 { + indices = append(indices, user) + values = append(values, 1) + position += int32(ctrDataset.CountUsers()) + } + if ctrDataset.CountItems() > 0 { + indices = append(indices, position+item) + values = append(values, 1) + position += int32(ctrDataset.CountItems()) + if len(ctrDataset.ItemEmbeddings) > 0 && int(item) < len(ctrDataset.ItemEmbeddings) { + embedding = ctrDataset.ItemEmbeddings[item] + } } + if ctrDataset.CountUsers() > 0 { + if int(user) < len(ctrDataset.UserLabels) { + for _, feature := range ctrDataset.UserLabels[user] { + indices = append(indices, position+feature.A) + values = append(values, feature.B) + } + } + position += int32(ctrDataset.Index.CountUserLabels()) + } + if ctrDataset.CountItems() > 0 { + if int(item) < len(ctrDataset.ItemLabels) { + for _, feature := range ctrDataset.ItemLabels[item] { + indices = append(indices, position+feature.A) + values = append(values, feature.B) + } + } + } + return indices, values, embedding } - var posFeatures, negFeatures []lo.Tuple2[[]int32, []float32] - var posEmbeddings, negEmbeddings [][][]float32 - var posUsers, negUsers []int32 - for i := 0; i < test.Count(); i++ { - indices, values, embeddings, target := test.Get(i) - userId := indices[0] - if target > 0 { - posFeatures = append(posFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) - posEmbeddings = append(posEmbeddings, embeddings) - posUsers = append(posUsers, userId) - } else { - negFeatures = append(negFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values}) - negEmbeddings = append(negEmbeddings, embeddings) - negUsers = append(negUsers, userId) + negatives := test.SampleUserNegatives(train, 99) + userFeedback := test.GetUserFeedback() + + var sumNDCG float32 + var ndcgUsers float32 + for userIdx := 0; userIdx < test.CountUsers(); userIdx++ { + positives := userFeedback[userIdx] + if len(positives) == 0 { + continue + } + targetSet := mapset.NewSet(positives...) + candidatesSet := mapset.NewSet(positives...) 
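+ // Extend the candidate pool (seeded with the user's held-out positives) with the
+ // sampled negatives below, then rank every candidate with the trained AFM.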
+ for _, item := range negatives[userIdx] { + candidatesSet.Add(item) + } + candidates := candidatesSet.ToSlice() + if len(candidates) == 0 { + continue } + features := make([]lo.Tuple2[[]int32, []float32], len(candidates)) + embeddings := make([][][]float32, len(candidates)) + for i, item := range candidates { + indices, values, embedding := buildCTRInput(int32(userIdx), item) + features[i] = lo.Tuple2[[]int32, []float32]{A: indices, B: values} + embeddings[i] = embedding + } + predictions := ml.BatchInternalPredict(features, embeddings, runtime.NumCPU()) + + type scoredItem struct { + item int32 + score float32 + } + scored := make([]scoredItem, 0, len(candidates)) + for i, item := range candidates { + scored = append(scored, scoredItem{item: item, score: predictions[i]}) + } + sort.Slice(scored, func(i, j int) bool { + return scored[i].score > scored[j].score + }) + rankList := make([]int32, 0, len(scored)) + for _, s := range scored { + rankList = append(rankList, s.item) + } + sumNDCG += cf.NDCG(targetSet, rankList) + ndcgUsers++ + } + + ndcg := float32(0) + if ndcgUsers > 0 { + ndcg = sumNDCG / ndcgUsers } - posPrediction := ml.BatchInternalPredict(posFeatures, posEmbeddings, runtime.NumCPU()) - negPrediction := ml.BatchInternalPredict(negFeatures, negEmbeddings, runtime.NumCPU()) + scores.Store("AFM", cf.Score{NDCG: ndcg}) +} - userPosPrediction := make(map[int32][]float32) - userNegPrediction := make(map[int32][]float32) - for i, p := range posPrediction { - userPosPrediction[posUsers[i]] = append(userPosPrediction[posUsers[i]], p) +func SplitCTRDataset(ctrDataset *ctr.Dataset, train, test dataset.CFSplit) (*ctr.Dataset, *ctr.Dataset) { + makeKey := func(user, item int32) uint64 { + return (uint64(uint32(user)) << 32) | uint64(uint32(item)) } - for i, p := range negPrediction { - userNegPrediction[negUsers[i]] = append(userNegPrediction[negUsers[i]], p) + newSubset := func(capacity int) *ctr.Dataset { + return &ctr.Dataset{ + Index: ctrDataset.Index, + UserLabels: ctrDataset.UserLabels, + ItemLabels: ctrDataset.ItemLabels, + ItemEmbeddings: ctrDataset.ItemEmbeddings, + ItemEmbeddingIndex: ctrDataset.ItemEmbeddingIndex, + ItemEmbeddingDimension: ctrDataset.ItemEmbeddingDimension, + Users: make([]int32, 0, capacity), + Items: make([]int32, 0, capacity), + Target: make([]float32, 0, capacity), + } } - var sumAUC float32 - var validUsers float32 - for user, pos := range userPosPrediction { - if userTrain[user] > 100 || userTrain[user] == 0 { - continue + appendSample := func(dataSet *ctr.Dataset, user, item int32, target float32) { + dataSet.Users = append(dataSet.Users, user) + dataSet.Items = append(dataSet.Items, item) + dataSet.Target = append(dataSet.Target, target) + if target > 0 { + dataSet.PositiveCount++ + } else { + dataSet.NegativeCount++ } - if neg, ok := userNegPrediction[user]; ok { - sumAUC += AUC(pos, neg) * float32(len(pos)) - validUsers += float32(len(pos)) + } + + trainSet := newSubset(ctrDataset.Count()) + testSet := newSubset(test.CountFeedback() + test.CountUsers()*100) + + testPositive := mapset.NewSet[uint64]() + for userIdx, items := range test.GetUserFeedback() { + for _, itemIdx := range items { + testPositive.Add(makeKey(int32(userIdx), itemIdx)) } } - if validUsers == 0 { - return 0 + negatives := test.SampleUserNegatives(train, 99) + testNegative := mapset.NewSet[uint64]() + for userIdx, items := range negatives { + for _, itemIdx := range items { + testNegative.Add(makeKey(int32(userIdx), itemIdx)) + } } - score := sumAUC / validUsers - - fmt.Println("FM 
GAUC:", score) - return score + addedNegative := make(map[int32]bool) + for i := 0; i < ctrDataset.Count(); i++ { + user := ctrDataset.Users[i] + item := ctrDataset.Items[i] + target := ctrDataset.Target[i] + key := makeKey(user, item) + if target > 0 && testPositive.Contains(key) { + appendSample(testSet, user, item, target) + } else if target <= 0 && !addedNegative[user] { + appendSample(testSet, user, item, target) + addedNegative[user] = true + } else if !testPositive.Contains(key) && !testNegative.Contains(key) { + appendSample(trainSet, user, item, target) + } + } + return trainSet, testSet } -func AUC(posPrediction, negPrediction []float32) float32 { - sort.Sort(sortutil.Float32Slice(posPrediction)) - sort.Sort(sortutil.Float32Slice(negPrediction)) - var sum float32 - var nPos int - for pPos := range posPrediction { - // find the negative sample with the greatest prediction less than current positive sample - for nPos < len(negPrediction) && negPrediction[nPos] < posPrediction[pPos] { - nPos++ - } - // add the number of negative samples have less prediction than current positive sample - sum += float32(nPos) - } - if len(posPrediction)*len(negPrediction) == 0 { - return 0 +func EvaluateLLM(cfg *config.Config, train, test dataset.CFSplit, scores *sync.Map) { + chat, err := logics.NewChatRanker(cfg.OpenAI, cfg.Recommend.Ranker.Prompt) + if err != nil { + log.Logger().Fatal("failed to create chat ranker", zap.Error(err)) } - return sum / float32(len(posPrediction)*len(negPrediction)) + + var sum atomic.Float32 + var count atomic.Float32 + negatives := test.SampleUserNegatives(train, 99) + lo.Must0(parallel.Detachable(context.Background(), test.CountUsers(), runtime.NumCPU(), 10, func(pCtx *parallel.Context, userIdx int) { + targetSet := mapset.NewSet(test.GetUserFeedback()[userIdx]...) 
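+ // Candidates for the LLM ranker are the user's sampled negatives plus the held-out
+ // positives; the few-shot context is built from the user's training feedback below.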
+ negativeSample := negatives[userIdx] + candidates := make([]*data.Item, 0, targetSet.Cardinality()+len(negativeSample)) + for _, itemIdx := range negativeSample { + candidates = append(candidates, &test.GetItems()[itemIdx]) + } + if len(test.GetUserFeedback()[userIdx]) == 0 { + return + } + for _, itemIdx := range test.GetUserFeedback()[userIdx] { + candidates = append(candidates, &test.GetItems()[itemIdx]) + } + feedback := make([]*logics.FeedbackItem, 0, len(train.GetUserFeedback()[int32(userIdx)])) + for _, itemIdx := range train.GetUserFeedback()[int32(userIdx)] { + feedback = append(feedback, &logics.FeedbackItem{ + Item: train.GetItems()[itemIdx], + }) + } + pCtx.Detach() + start := time.Now() + result, err := chat.Rank(context.Background(), &data.User{}, feedback, candidates) + if err != nil { + if apiError, ok := err.(*openai.APIError); ok && apiError.HTTPStatusCode == 421 { + return + } + log.Logger().Fatal("failed to rank items for user", zap.Int("user", userIdx), zap.Error(err)) + } + duration := time.Since(start) + pCtx.Attach() + var score float32 + if len(result) > 0 { + score = cf.NDCG(targetSet, lo.Map(result, func(itemId string, _ int) int32 { + return train.GetItemDict().Id(itemId) + })) + } else { + score = 0 + } + sum.Add(score) + count.Add(1) + log.Logger().Info("LLM ranking result", + zap.Int("user", userIdx), + zap.Int("feedback", len(feedback)), + zap.Int("candidates", len(candidates)), + zap.Int("results", len(result)), + zap.Float32("user_NDCG", score), + zap.Float32("avg_NDCG", sum.Load()/count.Load()), + zap.Duration("duration", duration), + ) + })) + + score := sum.Load() / count.Load() + scores.Store(cfg.OpenAI.ChatCompletionModel, cf.Score{NDCG: score}) } func init() { rootCmd.PersistentFlags().StringP("config", "c", "", "Path to configuration file") rootCmd.AddCommand(llmCmd) + llmCmd.PersistentFlags().IntP("shots", "s", math.MaxInt, "Number of shots for each user") } func main() { if err := rootCmd.Execute(); err != nil { - log.Fatal(err) + log.Logger().Fatal("failed to execute command", zap.Error(err)) } } diff --git a/common/parallel/parallel_test.go b/common/parallel/parallel_test.go index 11eaecf0c..b337baf17 100644 --- a/common/parallel/parallel_test.go +++ b/common/parallel/parallel_test.go @@ -163,7 +163,7 @@ func TestParallelCancel(t *testing.T) { cancel() } count.Add(1) - time.Sleep(100 * time.Millisecond) + time.Sleep(time.Second) return nil }) diff --git a/dataset/dataset.go b/dataset/dataset.go index f92045147..3ac277e3b 100644 --- a/dataset/dataset.go +++ b/dataset/dataset.go @@ -18,6 +18,7 @@ import ( "bufio" "fmt" "os" + "sort" "strconv" "strings" "time" @@ -42,6 +43,8 @@ type CFSplit interface { CountItems() int // CountFeedback returns the number of (positive) feedback. CountFeedback() int + // GetItems returns the items. + GetItems() []data.Item // GetUserDict returns the frequency dictionary of users. GetUserDict() *FreqDict // GetItemDict returns the frequency dictionary of items. 
@@ -79,6 +82,7 @@ type Dataset struct { itemLabels *Labels userFeedback [][]int32 itemFeedback [][]int32 + timestamps [][]time.Time negatives [][]int32 userDict *FreqDict itemDict *FreqDict @@ -95,6 +99,7 @@ func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset { itemLabels: NewLabels(), userFeedback: make([][]int32, userCount), itemFeedback: make([][]int32, itemCount), + timestamps: make([][]time.Time, userCount), userDict: NewFreqDict(), itemDict: NewFreqDict(), categories: make(map[string]int), @@ -203,6 +208,9 @@ func (d *Dataset) AddUser(user data.User) { if len(d.userFeedback) < len(d.users) { d.userFeedback = append(d.userFeedback, nil) } + if len(d.timestamps) < len(d.users) { + d.timestamps = append(d.timestamps, nil) + } } func (d *Dataset) AddItem(item data.Item) { @@ -223,11 +231,12 @@ func (d *Dataset) AddItem(item data.Item) { } } -func (d *Dataset) AddFeedback(userId, itemId string) { +func (d *Dataset) AddFeedback(userId, itemId string, timestamp time.Time) { userIndex := d.userDict.Add(userId) itemIndex := d.itemDict.Add(itemId) d.userFeedback[userIndex] = append(d.userFeedback[userIndex], itemIndex) d.itemFeedback[itemIndex] = append(d.itemFeedback[itemIndex], userIndex) + d.timestamps[userIndex] = append(d.timestamps[userIndex], timestamp) d.numFeedback++ } @@ -253,6 +262,7 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { trainSet.items, testSet.items = d.items, d.items trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers()) trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems()) + trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers()) trainSet.userDict, testSet.userDict = d.userDict, d.userDict trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict rng := util.NewRandomGenerator(seed) @@ -262,11 +272,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { k := rng.Intn(len(d.userFeedback[userIndex])) testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k]) testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex) + testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k]) testSet.numFeedback++ for i, itemIndex := range d.userFeedback[userIndex] { if i != k { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) + trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i]) trainSet.numFeedback++ } } @@ -279,11 +291,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { k := rng.Intn(len(d.userFeedback[userIndex])) testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k]) testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex) + testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k]) testSet.numFeedback++ for i, itemIndex := range d.userFeedback[userIndex] { if i != k { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], 
userIndex) + trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i]) trainSet.numFeedback++ } } @@ -292,9 +306,10 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { testUserSet := mapset.NewSet(testUsers...) for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { if !testUserSet.Contains(userIndex) { - for _, itemIndex := range d.userFeedback[userIndex] { + for idx, itemIndex := range d.userFeedback[userIndex] { trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) + trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][idx]) trainSet.numFeedback++ } } @@ -303,6 +318,39 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) { return trainSet, testSet } +// SplitLatest splits dataset by moving the most recent feedback of all users into the test set to avoid leakage. +func (d *Dataset) SplitLatest(shots int) (CFSplit, CFSplit) { + trainSet, testSet := new(Dataset), new(Dataset) + trainSet.users, testSet.users = d.users, d.users + trainSet.items, testSet.items = d.items, d.items + trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers()) + trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems()) + trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers()) + trainSet.userDict, testSet.userDict = d.userDict, d.userDict + trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict + for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ { + if len(d.userFeedback[userIndex]) == 0 { + continue + } + idxs := lo.Range(len(d.userFeedback[userIndex])) + sort.Slice(idxs, func(i, j int) bool { + return d.timestamps[userIndex][idxs[i]].After(d.timestamps[userIndex][idxs[j]]) + }) + testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][idxs[0]]) + testSet.itemFeedback[d.userFeedback[userIndex][idxs[0]]] = append(testSet.itemFeedback[d.userFeedback[userIndex][idxs[0]]], userIndex) + testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][idxs[0]]) + testSet.numFeedback++ + for i := 1; i < len(d.userFeedback[userIndex]) && i <= shots; i++ { + itemIndex := d.userFeedback[userIndex][idxs[i]] + trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex) + trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex) + trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][idxs[i]]) + trainSet.numFeedback++ + } + } + return trainSet, testSet +} + type Labels struct { fields *strutil.Pool values *FreqDict @@ -366,6 +414,7 @@ func LoadDataFromBuiltIn(dataSetName string) (*Dataset, *Dataset, error) { test.userDict, test.itemDict = train.userDict, train.itemDict test.userFeedback = make([][]int32, len(train.userFeedback)) test.itemFeedback = make([][]int32, len(train.itemFeedback)) + test.timestamps = make([][]time.Time, len(train.userFeedback)) test.negatives = make([][]int32, len(train.userFeedback)) err = loadTest(test, testFilePath) if err != nil { @@ -404,7 +453,7 @@ func loadTrain(path string) (*Dataset, error) { dataset.AddItem(data.Item{ItemId: util.FormatInt(i)}) } // add feedback 
- dataset.AddFeedback(fields[0], fields[1]) + dataset.AddFeedback(fields[0], fields[1], time.Time{}) } return dataset, scanner.Err() } @@ -429,7 +478,7 @@ func loadTest(dataset *Dataset, path string) error { positive = positive[1 : len(positive)-1] fields = strings.Split(positive, ",") // add feedback - dataset.AddFeedback(fields[0], fields[1]) + dataset.AddFeedback(fields[0], fields[1], time.Time{}) // add negatives userId, err := strconv.Atoi(fields[0]) if err != nil { diff --git a/dataset/dataset_test.go b/dataset/dataset_test.go index 834e6cc67..799e29a84 100644 --- a/dataset/dataset_test.go +++ b/dataset/dataset_test.go @@ -16,6 +16,7 @@ package dataset import ( "fmt" + "math" "strconv" "testing" "time" @@ -159,7 +160,7 @@ func TestDataset_AddFeedback(t *testing.T) { } for i := 0; i < 10; i++ { for j := i; j < 10; j++ { - dataSet.AddFeedback(strconv.Itoa(i), strconv.Itoa(j)) + dataSet.AddFeedback(strconv.Itoa(i), strconv.Itoa(j), time.Unix(int64(i*10+j), 0)) } } userIDF := dataSet.GetUserIDF() @@ -167,6 +168,7 @@ func TestDataset_AddFeedback(t *testing.T) { for i := 0; i < 10; i++ { assert.Len(t, dataSet.GetUserFeedback()[i], 10-i) assert.Len(t, dataSet.GetItemFeedback()[i], i+1) + assert.Len(t, dataSet.timestamps[i], 10-i) assert.InDelta(t, math32.Log(float32(10)/float32(10-i)), userIDF[i], 1e-2) assert.InDelta(t, math32.Log(float32(10)/float32(i+1)), itemIDF[i], 1e-2) } @@ -184,7 +186,7 @@ func TestDataset_Split(t *testing.T) { } for i := 0; i < numUsers; i++ { for j := i + 1; j < numItems; j++ { - dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j)) + dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Time{}) } } assert.Equal(t, 9, dataset.CountFeedback()) @@ -206,6 +208,37 @@ func TestDataset_Split(t *testing.T) { assert.Equal(t, 2, test2.CountFeedback()) } +func TestDataset_SplitLatest(t *testing.T) { + const numUsers, numItems = 3, 5 + // create dataset + dataset := NewDataset(time.Now(), numUsers, numItems) + for i := 0; i < numUsers; i++ { + dataset.AddUser(data.User{UserId: fmt.Sprintf("user%v", i)}) + } + for i := 0; i < numItems; i++ { + dataset.AddItem(data.Item{ItemId: fmt.Sprintf("item%v", i)}) + } + for i := 0; i < numUsers; i++ { + for j := i + 1; j < numItems; j++ { + dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Unix(int64(j), 0)) + } + } + assert.Equal(t, 9, dataset.CountFeedback()) + // split + train, test := dataset.SplitLatest(math.MaxInt) + assert.Equal(t, numUsers, train.CountUsers()) + assert.Equal(t, numItems, train.CountItems()) + assert.Equal(t, numUsers, test.CountUsers()) + assert.Equal(t, numItems, test.CountItems()) + assert.Equal(t, 6, train.CountFeedback()) + assert.Equal(t, 3, test.CountFeedback()) + for i := 0; i < numUsers; i++ { + assert.Len(t, train.GetUserFeedback()[i], numItems-i-2) + assert.Len(t, test.GetUserFeedback()[i], 1) + assert.Equal(t, 4, int(test.GetUserFeedback()[i][0])) + } +} + func TestDataset_LoadMovieLens1M(t *testing.T) { train, test, err := LoadDataFromBuiltIn("ml-1m") assert.NoError(t, err) diff --git a/go.mod b/go.mod index 324417644..dd182c05c 100644 --- a/go.mod +++ b/go.mod @@ -40,6 +40,7 @@ require ( github.com/minio/minio-go/v7 v7.0.97 github.com/modern-go/reflect2 v1.0.2 github.com/nikolalohinski/gonja/v2 v2.5.0 + github.com/olekukonko/tablewriter v1.1.3 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 github.com/rakyll/statik v0.1.8 @@ -107,6 +108,9 @@ require ( github.com/buger/jsonparser v1.1.1 // indirect 
github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clipperhouse/displaywidth v0.6.2 // indirect + github.com/clipperhouse/stringish v0.1.1 // indirect + github.com/clipperhouse/uax29/v2 v2.3.0 // indirect github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect @@ -114,6 +118,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/envoyproxy/go-control-plane/envoy v1.35.0 // indirect github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.10 // indirect @@ -153,7 +158,9 @@ require ( github.com/klauspost/crc32 v1.3.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mailru/easyjson v0.7.7 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect github.com/minio/blake2b-simd v0.0.0-20160723061019-3f5f724cb5b1 // indirect github.com/minio/crc64nvme v1.1.1 // indirect github.com/minio/md5-simd v1.1.2 // indirect @@ -161,6 +168,9 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/montanaflynn/stats v0.7.1 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 // indirect + github.com/olekukonko/errors v1.1.0 // indirect + github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 // indirect github.com/openzipkin/zipkin-go v0.4.1 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect diff --git a/go.sum b/go.sum index 76ae20838..4ff3aa178 100644 --- a/go.sum +++ b/go.sum @@ -132,6 +132,12 @@ github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWR github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/clipperhouse/displaywidth v0.6.2 h1:ZDpTkFfpHOKte4RG5O/BOyf3ysnvFswpyYrV7z2uAKo= +github.com/clipperhouse/displaywidth v0.6.2/go.mod h1:R+kHuzaYWFkTm7xoMmK1lFydbci4X2CicfbGstSGg0o= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.3.0 h1:SNdx9DVUqMoBuBoW3iLOj4FQv3dN5mDtuqwuhIGpJy4= +github.com/clipperhouse/uax29/v2 v2.3.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= github.com/cloudflare/cfssl v0.0.0-20190808011637-b1ec8c586c2a/go.mod h1:yMWuSON2oQp+43nFtAV/uvKQIFpSPerB57DCt9t8sSA= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/xds/go v0.0.0-20251022180443-0feb69152e9f h1:Y8xYupdHxryycyPlc9Y+bSQAYZnetRJ70VMVKm5CKI0= @@ -192,6 +198,8 @@ github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec= 
github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4= github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= @@ -552,6 +560,8 @@ github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaO github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4= github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= @@ -562,8 +572,8 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.10/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= -github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= -github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.0/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= @@ -606,6 +616,14 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLA github.com/nikolalohinski/gonja/v2 v2.5.0 h1:O59grn57yCFEeTdHGzYPzg2gGeh4MgroC2ArQJ9pry0= github.com/nikolalohinski/gonja/v2 v2.5.0/go.mod h1:UIzXPVuOsr5h7dZ5DUbqk3/Z7oFA/NLGQGMjqT4L2aU= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6 h1:zrbMGy9YXpIeTnGj4EljqMiZsIcE09mmF8XsD5AYOJc= +github.com/olekukonko/cat v0.0.0-20250911104152-50322a0618f6/go.mod h1:rEKTHC9roVVicUIfZK7DYrdIoM0EOr8mK1Hj5s3JjH0= +github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= +github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= +github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0 h1:jrYnow5+hy3WRDCBypUFvVKNSPPCdqgSXIE9eJDD8LM= +github.com/olekukonko/ll v0.1.4-0.20260115111900-9e59c2286df0/go.mod h1:b52bVQRRPObe+yyBl0TxNfhesL0nedD4Cht0/zx55Ew= +github.com/olekukonko/tablewriter v1.1.3 
h1:VSHhghXxrP0JHl+0NnKid7WoEmd9/urKRJLysb70nnA= +github.com/olekukonko/tablewriter v1.1.3/go.mod h1:9VU0knjhmMkXjnMKrZ3+L2JhhtsQ/L38BbL3CRNE8tM= github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= github.com/onsi/ginkgo/v2 v2.23.4 h1:ktYTpKJAVZnDT4VjxSbiBenUjmlL/5QkBEocaWXiQus= github.com/onsi/ginkgo/v2 v2.23.4/go.mod h1:Bt66ApGPBFzHyR+JO10Zbt0Gsp4uWxu5mIOTusL46e8= diff --git a/logics/chat.go b/logics/chat.go index ecb3f22ef..95c089afa 100644 --- a/logics/chat.go +++ b/logics/chat.go @@ -82,6 +82,9 @@ func (r *ChatRanker) Rank(ctx context.Context, user *data.User, feedback []*Feed Role: openai.ChatMessageRoleUser, Content: buf.String(), }}, + ChatTemplateKwargs: map[string]any{ + "enable_thinking": false, + }, }) if err == nil { return resp, nil diff --git a/master/tasks.go b/master/tasks.go index 81808750b..4d922f0f2 100644 --- a/master/tasks.go +++ b/master/tasks.go @@ -479,7 +479,7 @@ func (m *Master) LoadDataFromDatabase( break } } - dataSet.AddFeedback(f.UserId, f.ItemId) + dataSet.AddFeedback(f.UserId, f.ItemId, f.Timestamp) } span.Add(len(feedback)) } diff --git a/model/cf/evaluator_test.go b/model/cf/evaluator_test.go index 510a8a1f8..45641a883 100644 --- a/model/cf/evaluator_test.go +++ b/model/cf/evaluator_test.go @@ -144,7 +144,7 @@ func TestEvaluate(t *testing.T) { } for i := 0; i < 16; i++ { test.AddItem(data.Item{ItemId: strconv.Itoa(i)}) - test.AddFeedback(strconv.Itoa(i/4), strconv.Itoa(i)) + test.AddFeedback(strconv.Itoa(i/4), strconv.Itoa(i), time.Time{}) } assert.Equal(t, 16, test.CountFeedback()) assert.Equal(t, 4, test.CountUsers())
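For reference, a minimal sketch of the leave-latest-out evaluation loop that the benchmark above applies to each model. It assumes the CFSplit, SampleUserNegatives, and cf.NDCG signatures introduced in this patch and would live alongside cmd/gorse-benchmark/main.go; scoreItem is a hypothetical stand-in for whichever scorer (ALS, BPR, AFM, or the LLM ranker) is being measured, not an API from this change.

package main

import (
	"sort"

	mapset "github.com/deckarep/golang-set/v2"

	"github.com/gorse-io/gorse/dataset"
	"github.com/gorse-io/gorse/model/cf"
)

// evaluateNDCG ranks each user's held-out positives against 99 sampled negatives
// and averages NDCG over users, mirroring EvaluateAFM and EvaluateLLM above.
func evaluateNDCG(train, test dataset.CFSplit, scoreItem func(user, item int32) float32) float32 {
	negatives := test.SampleUserNegatives(train, 99)
	userFeedback := test.GetUserFeedback()
	var sum, users float32
	for userIdx := 0; userIdx < test.CountUsers(); userIdx++ {
		positives := userFeedback[userIdx]
		if len(positives) == 0 {
			continue
		}
		// Candidate pool: held-out positives plus the user's sampled negatives.
		candidateSet := mapset.NewSet(positives...)
		for _, item := range negatives[userIdx] {
			candidateSet.Add(item)
		}
		candidates := candidateSet.ToSlice()
		// Score every candidate once, then sort by descending score.
		scores := make(map[int32]float32, len(candidates))
		for _, item := range candidates {
			scores[item] = scoreItem(int32(userIdx), item)
		}
		sort.Slice(candidates, func(i, j int) bool {
			return scores[candidates[i]] > scores[candidates[j]]
		})
		sum += cf.NDCG(mapset.NewSet(positives...), candidates)
		users++
	}
	if users == 0 {
		return 0
	}
	return sum / users
}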