Skip to content

Commit 135f6c0

Browse files
authored
implement Attentional Factorization Machines (#1134)
1 parent ef99929 commit 135f6c0

File tree

17 files changed

+532
-93
lines changed

17 files changed

+532
-93
lines changed

cmd/gorse-benchmark/main.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
// Copyright 2026 gorse Project Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package main
16+
17+
import (
18+
"context"
19+
"fmt"
20+
"log"
21+
"os"
22+
"runtime"
23+
"sort"
24+
25+
"github.com/gorse-io/gorse/config"
26+
"github.com/gorse-io/gorse/dataset"
27+
"github.com/gorse-io/gorse/master"
28+
"github.com/gorse-io/gorse/model/ctr"
29+
"github.com/gorse-io/gorse/storage"
30+
"github.com/gorse-io/gorse/storage/data"
31+
"github.com/samber/lo"
32+
"github.com/spf13/cobra"
33+
"modernc.org/sortutil"
34+
)
35+
36+
var rootCmd = &cobra.Command{
37+
Use: "gorse-benchmark",
38+
Short: "Gorse Benchmarking Tool",
39+
}
40+
41+
var llmCmd = &cobra.Command{
42+
Use: "llm",
43+
Short: "Benchmark LLM models",
44+
Run: func(cmd *cobra.Command, args []string) {
45+
// Load configuration
46+
configPath, _ := cmd.Flags().GetString("config")
47+
cfg, err := config.LoadConfig(configPath)
48+
if err != nil {
49+
log.Fatalf("failed to load config: %v", err)
50+
}
51+
// Load dataset
52+
m := master.NewMaster(cfg, os.TempDir(), false)
53+
m.DataClient, err = data.Open(m.Config.Database.DataStore, m.Config.Database.DataTablePrefix,
54+
storage.WithIsolationLevel(m.Config.Database.MySQL.IsolationLevel))
55+
if err != nil {
56+
log.Fatalf("failed to open data client: %v", err)
57+
}
58+
evaluator := master.NewOnlineEvaluator(
59+
m.Config.Recommend.DataSource.PositiveFeedbackTypes,
60+
m.Config.Recommend.DataSource.ReadFeedbackTypes)
61+
dataset, _, err := m.LoadDataFromDatabase(context.Background(), m.DataClient,
62+
m.Config.Recommend.DataSource.PositiveFeedbackTypes,
63+
m.Config.Recommend.DataSource.ReadFeedbackTypes,
64+
m.Config.Recommend.DataSource.ItemTTL,
65+
m.Config.Recommend.DataSource.PositiveFeedbackTTL,
66+
evaluator,
67+
nil)
68+
if err != nil {
69+
log.Fatalf("failed to load dataset: %v", err)
70+
}
71+
fmt.Println("Dataset loaded:")
72+
fmt.Printf(" Users: %d\n", dataset.CountUsers())
73+
fmt.Printf(" Items: %d\n", dataset.CountItems())
74+
fmt.Printf(" Positive Feedbacks: %d\n", dataset.CountPositive())
75+
fmt.Printf(" Negative Feedbacks: %d\n", dataset.CountNegative())
76+
// Split dataset
77+
train, test := dataset.Split(0.2, 42)
78+
EvaluateFM(train, test)
79+
// EvaluateLLM(cfg, train, test, aux.GetItems())
80+
},
81+
}
82+
83+
func EvaluateFM(train, test dataset.CTRSplit) float32 {
84+
fmt.Println("Training FM...")
85+
ml := ctr.NewAFM(nil)
86+
ml.Fit(context.Background(), train, test,
87+
ctr.NewFitConfig().
88+
SetVerbose(10).
89+
SetJobs(runtime.NumCPU()).
90+
SetPatience(10))
91+
92+
userTrain := make(map[int32]int, train.CountUsers())
93+
for i := 0; i < train.Count(); i++ {
94+
indices, _, _, target := train.Get(i)
95+
userId := indices[0]
96+
if target > 0 {
97+
userTrain[userId]++
98+
}
99+
}
100+
101+
var posFeatures, negFeatures []lo.Tuple2[[]int32, []float32]
102+
var posEmbeddings, negEmbeddings [][][]float32
103+
var posUsers, negUsers []int32
104+
for i := 0; i < test.Count(); i++ {
105+
indices, values, embeddings, target := test.Get(i)
106+
userId := indices[0]
107+
if target > 0 {
108+
posFeatures = append(posFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values})
109+
posEmbeddings = append(posEmbeddings, embeddings)
110+
posUsers = append(posUsers, userId)
111+
} else {
112+
negFeatures = append(negFeatures, lo.Tuple2[[]int32, []float32]{A: indices, B: values})
113+
negEmbeddings = append(negEmbeddings, embeddings)
114+
negUsers = append(negUsers, userId)
115+
}
116+
}
117+
posPrediction := ml.BatchInternalPredict(posFeatures, posEmbeddings, runtime.NumCPU())
118+
negPrediction := ml.BatchInternalPredict(negFeatures, negEmbeddings, runtime.NumCPU())
119+
120+
userPosPrediction := make(map[int32][]float32)
121+
userNegPrediction := make(map[int32][]float32)
122+
for i, p := range posPrediction {
123+
userPosPrediction[posUsers[i]] = append(userPosPrediction[posUsers[i]], p)
124+
}
125+
for i, p := range negPrediction {
126+
userNegPrediction[negUsers[i]] = append(userNegPrediction[negUsers[i]], p)
127+
}
128+
var sumAUC float32
129+
var validUsers float32
130+
for user, pos := range userPosPrediction {
131+
if userTrain[user] > 100 || userTrain[user] == 0 {
132+
continue
133+
}
134+
if neg, ok := userNegPrediction[user]; ok {
135+
sumAUC += AUC(pos, neg) * float32(len(pos))
136+
validUsers += float32(len(pos))
137+
}
138+
}
139+
if validUsers == 0 {
140+
return 0
141+
}
142+
score := sumAUC / validUsers
143+
144+
fmt.Println("FM GAUC:", score)
145+
return score
146+
}
147+
148+
func AUC(posPrediction, negPrediction []float32) float32 {
149+
sort.Sort(sortutil.Float32Slice(posPrediction))
150+
sort.Sort(sortutil.Float32Slice(negPrediction))
151+
var sum float32
152+
var nPos int
153+
for pPos := range posPrediction {
154+
// find the negative sample with the greatest prediction less than current positive sample
155+
for nPos < len(negPrediction) && negPrediction[nPos] < posPrediction[pPos] {
156+
nPos++
157+
}
158+
// add the number of negative samples have less prediction than current positive sample
159+
sum += float32(nPos)
160+
}
161+
if len(posPrediction)*len(negPrediction) == 0 {
162+
return 0
163+
}
164+
return sum / float32(len(posPrediction)*len(negPrediction))
165+
}
166+
167+
func init() {
168+
rootCmd.PersistentFlags().StringP("config", "c", "", "Path to configuration file")
169+
rootCmd.AddCommand(llmCmd)
170+
}
171+
172+
func main() {
173+
if err := rootCmd.Execute(); err != nil {
174+
log.Fatal(err)
175+
}
176+
}

codecov.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@ coverage:
66

77
ignore:
88
- "protocol/*.pb.go"
9+
- "cmd/**"

common/nn/layers.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,38 @@ func (s *Sequential) SetJobs(jobs int) {
157157
}
158158
}
159159

160+
type Attention struct {
161+
W Layer
162+
H *Tensor
163+
jobs int
164+
}
165+
166+
func NewAttention(dimensions, k int) *Attention {
167+
return &Attention{
168+
W: NewLinear(dimensions, k),
169+
H: Normal(0, 0.01, k, dimensions),
170+
}
171+
}
172+
173+
func (a *Attention) Parameters() []*Tensor {
174+
var params []*Tensor
175+
params = append(params, a.H)
176+
params = append(params, a.W.Parameters()...)
177+
return params
178+
}
179+
180+
func (a *Attention) Forward(x *Tensor) *Tensor {
181+
return Mul(
182+
Softmax(MatMul(ReLu(a.W.Forward(x)), a.H, false, false, a.jobs), 1),
183+
x,
184+
)
185+
}
186+
187+
func (a *Attention) SetJobs(jobs int) {
188+
a.W.SetJobs(jobs)
189+
a.jobs = max(1, jobs)
190+
}
191+
160192
func Save(o any, w io.Writer) error {
161193
var save func(o any, key []string) error
162194
save = func(o any, key []string) error {

dataset/dataset.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ type CTRSplit interface {
6666
CountNegative() int
6767
GetIndex() UnifiedIndex
6868
GetTarget(i int) float32
69-
Get(i int) ([]int32, []float32, float32)
69+
Get(i int) ([]int32, []float32, [][]float32, float32)
70+
GetItemEmbeddingDim() []int
71+
GetItemEmbeddingIndex() *Index
7072
}
7173

7274
type Dataset struct {

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ require (
202202
golang.org/x/crypto v0.45.0 // indirect
203203
golang.org/x/net v0.47.0 // indirect
204204
golang.org/x/sync v0.18.0 // indirect
205-
golang.org/x/term v0.37.0 // indirect
205+
golang.org/x/term v0.38.0 // indirect
206206
golang.org/x/text v0.31.0 // indirect
207207
golang.org/x/time v0.14.0 // indirect
208208
gonum.org/v1/gonum v0.16.0 // indirect

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,8 +1054,8 @@ golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
10541054
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
10551055
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
10561056
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
1057-
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
1058-
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
1057+
golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
1058+
golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
10591059
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
10601060
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
10611061
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

logics/chat.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,11 @@ func (r *ChatRanker) Rank(ctx context.Context, user *data.User, feedback []*Feed
111111
s.Add(item.ItemId)
112112
}
113113
var result []string
114+
m := mapset.NewSet[string]()
114115
for _, itemId := range parsed {
115-
if s.Contains(itemId) {
116+
if s.Contains(itemId) && !m.Contains(itemId) {
116117
result = append(result, itemId)
118+
m.Add(itemId)
117119
}
118120
}
119121
return result, nil

master/rpc_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ func newMockMasterRPC(t *testing.T) *mockMasterRPC {
6767
assert.NoError(t, err)
6868
// create click model
6969
train, test := newClickDataset()
70-
fm := ctr.NewFMV2(model.Params{model.NEpochs: 0})
70+
fm := ctr.NewAFM(model.Params{model.NEpochs: 0})
7171
fm.Fit(context.Background(), train, test, &ctr.FitConfig{})
7272
// create ranking model
7373
trainSet, testSet := newRankingDataset()

master/tasks.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1019,7 +1019,7 @@ func (m *Master) trainClickThroughRatePrediction(parent context.Context, trainSe
10191019
zap.Float32("Recall", m.clickThroughRateTarget.Score.Recall),
10201020
zap.Any("params", clickThroughRateParams))
10211021
}
1022-
clickModel := ctr.NewFMV2(clickThroughRateParams)
1022+
clickModel := ctr.NewAFM(clickThroughRateParams)
10231023
m.clickThroughRateModelMutex.Unlock()
10241024

10251025
startFitTime := time.Now()
@@ -1244,7 +1244,7 @@ func (m *Master) optimizeClickThroughRatePrediction(parent context.Context, trai
12441244

12451245
search := ctr.NewModelSearch(map[string]ctr.ModelCreator{
12461246
"FM": func() ctr.FactorizationMachines {
1247-
return ctr.NewFMV2(nil)
1247+
return ctr.NewAFM(nil)
12481248
},
12491249
}, trainSet, testSet,
12501250
ctr.NewFitConfig().

model/ctr/data.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ type Dataset struct {
147147
Users []int32
148148
Items []int32
149149
Target []float32
150-
ItemEmbeddings [][][]float32
150+
ItemEmbeddings [][][]float32 // Index by row id, embedding id, embedding dimension
151151
ItemEmbeddingDimension []int
152152
ItemEmbeddingIndex *dataset.Index
153153
PositiveCount int
@@ -207,11 +207,12 @@ func (dataset *Dataset) GetTarget(i int) float32 {
207207
}
208208

209209
// Get returns the i-th sample.
210-
func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
210+
func (dataset *Dataset) Get(i int) ([]int32, []float32, [][]float32, float32) {
211211
var (
212-
indices []int32
213-
values []float32
214-
position int32
212+
indices []int32
213+
values []float32
214+
embedding [][]float32
215+
position int32
215216
)
216217
// append user id
217218
if len(dataset.Users) > 0 {
@@ -224,6 +225,9 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
224225
indices = append(indices, position+dataset.Items[i])
225226
values = append(values, 1)
226227
position += int32(dataset.CountItems())
228+
if len(dataset.ItemEmbeddings) > 0 {
229+
embedding = dataset.ItemEmbeddings[dataset.Items[i]]
230+
}
227231
}
228232
// append user indices
229233
if len(dataset.Users) > 0 {
@@ -248,7 +252,7 @@ func (dataset *Dataset) Get(i int) ([]int32, []float32, float32) {
248252
indices = append(indices, contextIndices...)
249253
values = append(values, contextValues...)
250254
}
251-
return indices, values, dataset.Target[i]
255+
return indices, values, embedding, dataset.Target[i]
252256
}
253257

254258
// LoadLibFMFile loads libFM format file.
@@ -325,14 +329,20 @@ func LoadDataFromBuiltIn(name string) (train, test *Dataset, err error) {
325329
func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
326330
// create train/test dataset
327331
trainSet := &Dataset{
328-
Index: dataset.Index,
329-
UserLabels: dataset.UserLabels,
330-
ItemLabels: dataset.ItemLabels,
332+
Index: dataset.Index,
333+
UserLabels: dataset.UserLabels,
334+
ItemLabels: dataset.ItemLabels,
335+
ItemEmbeddings: dataset.ItemEmbeddings,
336+
ItemEmbeddingIndex: dataset.ItemEmbeddingIndex,
337+
ItemEmbeddingDimension: dataset.ItemEmbeddingDimension,
331338
}
332339
testSet := &Dataset{
333-
Index: dataset.Index,
334-
UserLabels: dataset.UserLabels,
335-
ItemLabels: dataset.ItemLabels,
340+
Index: dataset.Index,
341+
UserLabels: dataset.UserLabels,
342+
ItemLabels: dataset.ItemLabels,
343+
ItemEmbeddings: dataset.ItemEmbeddings,
344+
ItemEmbeddingIndex: dataset.ItemEmbeddingIndex,
345+
ItemEmbeddingDimension: dataset.ItemEmbeddingDimension,
336346
}
337347
// split by random
338348
numTestSize := int(float32(dataset.Count()) * ratio)
@@ -369,3 +379,11 @@ func (dataset *Dataset) Split(ratio float32, seed int64) (*Dataset, *Dataset) {
369379
}
370380
return trainSet, testSet
371381
}
382+
383+
func (dataset *Dataset) GetItemEmbeddingDim() []int {
384+
return dataset.ItemEmbeddingDimension
385+
}
386+
387+
func (dataset *Dataset) GetItemEmbeddingIndex() *dataset.Index {
388+
return dataset.ItemEmbeddingIndex
389+
}

0 commit comments

Comments
 (0)