Skip to content

Commit 3574b3f

Browse files
committed
Remove dup output
1 parent efd130d commit 3574b3f

File tree

2 files changed

+46
-15
lines changed

2 files changed

+46
-15
lines changed

cmd/gorse-benchmark/main.go

Lines changed: 43 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"strings"
1111

1212
mapset "github.com/deckarep/golang-set/v2"
13+
"github.com/gorse-io/gorse/common/parallel"
1314
"github.com/gorse-io/gorse/config"
1415
"github.com/gorse-io/gorse/dataset"
1516
"github.com/gorse-io/gorse/logics"
@@ -18,7 +19,10 @@ import (
1819
"github.com/gorse-io/gorse/storage"
1920
"github.com/gorse-io/gorse/storage/data"
2021
"github.com/samber/lo"
22+
"github.com/samber/lo/mutable"
23+
"github.com/sashabaranov/go-openai"
2124
"github.com/spf13/cobra"
25+
"go.uber.org/atomic"
2226
"golang.org/x/term"
2327
"modernc.org/sortutil"
2428
)
@@ -65,8 +69,8 @@ var llmCmd = &cobra.Command{
6569
fmt.Printf(" Negative Feedbacks: %d\n", dataset.CountNegative())
6670
// Split dataset
6771
train, test := dataset.Split(0.8, 42)
68-
EvaluateLLM(cfg, train, test, aux.GetItems())
6972
// EvaluateFM(train, test)
73+
EvaluateLLM(cfg, train, test, aux.GetItems())
7074
},
7175
}
7276

@@ -109,29 +113,40 @@ func EvaluateLLM(cfg *config.Config, train, test dataset.CTRSplit, items []data.
109113
}
110114
}
111115

112-
var sumAUC float32
113-
var validUsers float32
114-
for userId, testItems := range userTest {
116+
var sumAUC atomic.Float32
117+
var validUsers atomic.Float32
118+
parallel.Detachable(len(userTest), runtime.NumCPU(), 100, func(pCtx *parallel.Context, userIdx int) {
119+
userId := int32(userIdx)
120+
testItems := userTest[userId]
121+
if len(userTrain[userId]) > 100 || len(userTrain[userId]) == 0 {
122+
return
123+
}
115124
if _, ok := userPositive[userId]; !ok {
116-
continue
125+
return
117126
}
118127
if _, ok := userNegative[userId]; !ok {
119-
continue
128+
return
120129
}
121130
candidates := make([]*data.Item, 0, len(testItems))
122131
for _, itemId := range testItems {
123132
candidates = append(candidates, &items[itemId])
124133
}
134+
mutable.Reverse(candidates)
125135
feedback := make([]*logics.FeedbackItem, 0, len(testItems))
126136
for _, itemId := range userTrain[userId] {
127137
feedback = append(feedback, &logics.FeedbackItem{
128138
Item: items[itemId],
129139
})
130140
}
141+
pCtx.Detach()
131142
result, err := chat.Rank(&data.User{}, feedback, candidates)
132143
if err != nil {
144+
if apiError, ok := err.(*openai.APIError); ok && apiError.HTTPStatusCode == 421 {
145+
return
146+
}
133147
log.Fatalf("failed to rank items for user %d: %v", userId, err)
134148
}
149+
pCtx.Attach()
135150
var posPredictions, negPredictions []float32
136151
for i, name := range result {
137152
itemId := test.GetIndex().EncodeItem(name) - int32(test.CountUsers())
@@ -143,18 +158,20 @@ func EvaluateLLM(cfg *config.Config, train, test dataset.CTRSplit, items []data.
143158
log.Fatalf("item %s not found in test set for user %d", name, userId)
144159
}
145160
}
146-
sumAUC += AUC(posPredictions, negPredictions) * float32(len(posPredictions))
147-
validUsers += float32(len(posPredictions))
148-
fmt.Println("User", userId, "AUC:", AUC(posPredictions, negPredictions))
149-
if validUsers >= 100 {
150-
break
161+
if len(negPredictions) == 0 || len(posPredictions) == 0 {
162+
return
151163
}
152-
}
153-
if validUsers == 0 {
164+
sumAUC.Add(AUC(posPredictions, negPredictions) * float32(len(posPredictions)))
165+
validUsers.Add(float32(len(posPredictions)))
166+
fmt.Printf("User %d AUC: %f pos: %d/%d, neg: %d/%d\n", userId, AUC(posPredictions, negPredictions),
167+
len(posPredictions), userPositive[userId].Cardinality(),
168+
len(negPredictions), userNegative[userId].Cardinality())
169+
})
170+
if validUsers.Load() == 0 {
154171
return 0
155172
}
156173

157-
score := sumAUC / validUsers
174+
score := sumAUC.Load() / validUsers.Load()
158175
fmt.Println("LLM GAUC:", score)
159176
return score
160177
}
@@ -169,6 +186,15 @@ func EvaluateFM(train, test dataset.CTRSplit) float32 {
169186
SetJobs(runtime.NumCPU()).
170187
SetPatience(10))
171188

189+
userTrain := make(map[int32]int, train.CountUsers())
190+
for i := 0; i < train.Count(); i++ {
191+
indices, _, target := train.Get(i)
192+
userId := indices[0]
193+
if target > 0 {
194+
userTrain[userId]++
195+
}
196+
}
197+
172198
var posFeatures, negFeatures []lo.Tuple2[[]int32, []float32]
173199
var posUsers, negUsers []int32
174200
for i := 0; i < test.Count(); i++ {
@@ -196,6 +222,9 @@ func EvaluateFM(train, test dataset.CTRSplit) float32 {
196222
var sumAUC float32
197223
var validUsers float32
198224
for user, pos := range userPosPrediction {
225+
if userTrain[user] > 100 || userTrain[user] == 0 {
226+
continue
227+
}
199228
if neg, ok := userNegPrediction[user]; ok {
200229
sumAUC += AUC(pos, neg) * float32(len(pos))
201230
validUsers += float32(len(pos))

logics/chat.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ func (r *ChatRanker) Rank(user *data.User, feedback []*FeedbackItem, items []*da
111111
s.Add(item.ItemId)
112112
}
113113
var result []string
114+
m := mapset.NewSet[string]()
114115
for _, itemId := range parsed {
115-
if s.Contains(itemId) {
116+
if s.Contains(itemId) && !m.Contains(itemId) {
116117
result = append(result, itemId)
118+
m.Add(itemId)
117119
}
118120
}
119121
return result, nil

0 commit comments

Comments
 (0)