Skip to content

Commit 772a623

Browse files
committed
Add timestamps to feedback in Dataset and update related methods and tests
1 parent 762092d commit 772a623

File tree

4 files changed

+87
-9
lines changed

4 files changed

+87
-9
lines changed

dataset/dataset.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ type Dataset struct {
7979
itemLabels *Labels
8080
userFeedback [][]int32
8181
itemFeedback [][]int32
82+
timestamps [][]time.Time
8283
negatives [][]int32
8384
userDict *FreqDict
8485
itemDict *FreqDict
@@ -95,6 +96,7 @@ func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset {
9596
itemLabels: NewLabels(),
9697
userFeedback: make([][]int32, userCount),
9798
itemFeedback: make([][]int32, itemCount),
99+
timestamps: make([][]time.Time, userCount),
98100
userDict: NewFreqDict(),
99101
itemDict: NewFreqDict(),
100102
categories: make(map[string]int),
@@ -203,6 +205,9 @@ func (d *Dataset) AddUser(user data.User) {
203205
if len(d.userFeedback) < len(d.users) {
204206
d.userFeedback = append(d.userFeedback, nil)
205207
}
208+
if len(d.timestamps) < len(d.users) {
209+
d.timestamps = append(d.timestamps, nil)
210+
}
206211
}
207212

208213
func (d *Dataset) AddItem(item data.Item) {
@@ -223,11 +228,12 @@ func (d *Dataset) AddItem(item data.Item) {
223228
}
224229
}
225230

226-
func (d *Dataset) AddFeedback(userId, itemId string) {
231+
// AddFeedback records one feedback of userId on itemId together with the time
// it happened.
//
// The user and item IDs are translated to indices through userDict.Add and
// itemDict.Add (presumably these return the existing index or register a new
// one; verify against FreqDict). The item index is appended to the user's
// feedback list, the user index to the item's feedback list, and the timestamp
// to the user's timestamp list, so timestamps[u][k] always describes
// userFeedback[u][k]. numFeedback is incremented by one.
//
// NOTE(review): the slices are indexed directly, so this assumes the user and
// item were already registered (e.g. via AddUser/AddItem) and the per-user and
// per-item slices are long enough; confirm with callers.
func (d *Dataset) AddFeedback(userId, itemId string, timestamp time.Time) {
227232
userIndex := d.userDict.Add(userId)
228233
itemIndex := d.itemDict.Add(itemId)
229234
d.userFeedback[userIndex] = append(d.userFeedback[userIndex], itemIndex)
230235
d.itemFeedback[itemIndex] = append(d.itemFeedback[itemIndex], userIndex)
236+
// Appended in the same call as userFeedback so both per-user slices stay parallel.
d.timestamps[userIndex] = append(d.timestamps[userIndex], timestamp)
231237
d.numFeedback++
232238
}
233239

@@ -253,6 +259,7 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
253259
trainSet.items, testSet.items = d.items, d.items
254260
trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers())
255261
trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems())
262+
trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers())
256263
trainSet.userDict, testSet.userDict = d.userDict, d.userDict
257264
trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict
258265
rng := util.NewRandomGenerator(seed)
@@ -262,11 +269,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
262269
k := rng.Intn(len(d.userFeedback[userIndex]))
263270
testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k])
264271
testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex)
272+
testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k])
265273
testSet.numFeedback++
266274
for i, itemIndex := range d.userFeedback[userIndex] {
267275
if i != k {
268276
trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex)
269277
trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex)
278+
trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i])
270279
trainSet.numFeedback++
271280
}
272281
}
@@ -277,13 +286,15 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
277286
for _, userIndex := range testUsers {
278287
if len(d.userFeedback[userIndex]) > 0 {
279288
k := rng.Intn(len(d.userFeedback[userIndex]))
280-
testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k])
289+
testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][k])
281290
testSet.itemFeedback[d.userFeedback[userIndex][k]] = append(testSet.itemFeedback[d.userFeedback[userIndex][k]], userIndex)
291+
testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][k])
282292
testSet.numFeedback++
283293
for i, itemIndex := range d.userFeedback[userIndex] {
284294
if i != k {
285295
trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex)
286296
trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex)
297+
trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i])
287298
trainSet.numFeedback++
288299
}
289300
}
@@ -292,9 +303,10 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
292303
testUserSet := mapset.NewSet(testUsers...)
293304
for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ {
294305
if !testUserSet.Contains(userIndex) {
295-
for _, itemIndex := range d.userFeedback[userIndex] {
306+
for idx, itemIndex := range d.userFeedback[userIndex] {
296307
trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex)
297308
trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex)
309+
trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][idx])
298310
trainSet.numFeedback++
299311
}
300312
}
@@ -303,6 +315,39 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
303315
return trainSet, testSet
304316
}
305317

318+
// SplitLatest splits the dataset into a training set and a test set by time.
// For every user that has at least one feedback, the single feedback with the
// latest timestamp is moved into the test set and all other feedback of that
// user goes into the training set; the stated purpose is to avoid leakage.
// Users without feedback contribute to neither set.
//
// Both returned datasets share users, items, userDict and itemDict with d,
// while their userFeedback, itemFeedback and timestamps slices are freshly
// allocated. The first return value is the training set, the second is the
// test set.
319+
func (d *Dataset) SplitLatest() (CFSplit, CFSplit) {
320+
trainSet, testSet := new(Dataset), new(Dataset)
321+
trainSet.users, testSet.users = d.users, d.users
322+
trainSet.items, testSet.items = d.items, d.items
323+
trainSet.userFeedback, testSet.userFeedback = make([][]int32, d.CountUsers()), make([][]int32, d.CountUsers())
324+
trainSet.itemFeedback, testSet.itemFeedback = make([][]int32, d.CountItems()), make([][]int32, d.CountItems())
325+
trainSet.timestamps, testSet.timestamps = make([][]time.Time, d.CountUsers()), make([][]time.Time, d.CountUsers())
326+
trainSet.userDict, testSet.userDict = d.userDict, d.userDict
327+
trainSet.itemDict, testSet.itemDict = d.itemDict, d.itemDict
328+
for userIndex := int32(0); userIndex < int32(d.CountUsers()); userIndex++ {
329+
// A user without feedback has nothing to split.
if len(d.userFeedback[userIndex]) == 0 {
330+
continue
331+
}
332+
// Find the position of the newest timestamp. The same position is used
// below to index d.userFeedback[userIndex], which relies on both slices
// being parallel (AddFeedback appends to both in one call).
// NOTE(review): the emptiness check above looks at userFeedback, not at
// timestamps; confirm every construction path keeps the two in sync.
_, latestIdx := lo.MaxIndexBy(d.timestamps[userIndex], func(a, b time.Time) bool {
333+
return a.After(b)
334+
})
335+
// The newest feedback becomes the user's only test sample.
testSet.timestamps[userIndex] = append(testSet.timestamps[userIndex], d.timestamps[userIndex][latestIdx])
336+
testSet.itemFeedback[d.userFeedback[userIndex][latestIdx]] = append(testSet.itemFeedback[d.userFeedback[userIndex][latestIdx]], userIndex)
337+
testSet.userFeedback[userIndex] = append(testSet.userFeedback[userIndex], d.userFeedback[userIndex][latestIdx])
338+
testSet.numFeedback++
339+
// Every other feedback of this user is copied to the training set.
for i, itemIndex := range d.userFeedback[userIndex] {
340+
if i != latestIdx {
341+
trainSet.userFeedback[userIndex] = append(trainSet.userFeedback[userIndex], itemIndex)
342+
trainSet.itemFeedback[itemIndex] = append(trainSet.itemFeedback[itemIndex], userIndex)
343+
trainSet.timestamps[userIndex] = append(trainSet.timestamps[userIndex], d.timestamps[userIndex][i])
344+
trainSet.numFeedback++
345+
}
346+
}
347+
}
348+
return trainSet, testSet
349+
}
350+
306351
type Labels struct {
307352
fields *strutil.Pool
308353
values *FreqDict
@@ -366,6 +411,7 @@ func LoadDataFromBuiltIn(dataSetName string) (*Dataset, *Dataset, error) {
366411
test.userDict, test.itemDict = train.userDict, train.itemDict
367412
test.userFeedback = make([][]int32, len(train.userFeedback))
368413
test.itemFeedback = make([][]int32, len(train.itemFeedback))
414+
test.timestamps = make([][]time.Time, len(train.userFeedback))
369415
test.negatives = make([][]int32, len(train.userFeedback))
370416
err = loadTest(test, testFilePath)
371417
if err != nil {
@@ -404,7 +450,7 @@ func loadTrain(path string) (*Dataset, error) {
404450
dataset.AddItem(data.Item{ItemId: util.FormatInt(i)})
405451
}
406452
// add feedback
407-
dataset.AddFeedback(fields[0], fields[1])
453+
dataset.AddFeedback(fields[0], fields[1], time.Time{})
408454
}
409455
return dataset, scanner.Err()
410456
}
@@ -429,7 +475,7 @@ func loadTest(dataset *Dataset, path string) error {
429475
positive = positive[1 : len(positive)-1]
430476
fields = strings.Split(positive, ",")
431477
// add feedback
432-
dataset.AddFeedback(fields[0], fields[1])
478+
dataset.AddFeedback(fields[0], fields[1], time.Time{})
433479
// add negatives
434480
userId, err := strconv.Atoi(fields[0])
435481
if err != nil {

dataset/dataset_test.go

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,14 +159,15 @@ func TestDataset_AddFeedback(t *testing.T) {
159159
}
160160
for i := 0; i < 10; i++ {
161161
for j := i; j < 10; j++ {
162-
dataSet.AddFeedback(strconv.Itoa(i), strconv.Itoa(j))
162+
dataSet.AddFeedback(strconv.Itoa(i), strconv.Itoa(j), time.Unix(int64(i*10+j), 0))
163163
}
164164
}
165165
userIDF := dataSet.GetUserIDF()
166166
itemIDF := dataSet.GetItemIDF()
167167
for i := 0; i < 10; i++ {
168168
assert.Len(t, dataSet.GetUserFeedback()[i], 10-i)
169169
assert.Len(t, dataSet.GetItemFeedback()[i], i+1)
170+
assert.Len(t, dataSet.timestamps[i], 10-i)
170171
assert.InDelta(t, math32.Log(float32(10)/float32(10-i)), userIDF[i], 1e-2)
171172
assert.InDelta(t, math32.Log(float32(10)/float32(i+1)), itemIDF[i], 1e-2)
172173
}
@@ -184,7 +185,7 @@ func TestDataset_Split(t *testing.T) {
184185
}
185186
for i := 0; i < numUsers; i++ {
186187
for j := i + 1; j < numItems; j++ {
187-
dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j))
188+
dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Time{})
188189
}
189190
}
190191
assert.Equal(t, 9, dataset.CountFeedback())
@@ -206,6 +207,37 @@ func TestDataset_Split(t *testing.T) {
206207
assert.Equal(t, 2, test2.CountFeedback())
207208
}
208209

210+
// TestDataset_SplitLatest checks that SplitLatest moves exactly one feedback
// per user (the one with the newest timestamp) into the test set and keeps
// the rest in the training set.
func TestDataset_SplitLatest(t *testing.T) {
211+
const numUsers, numItems = 3, 5
212+
// create dataset
213+
dataset := NewDataset(time.Now(), numUsers, numItems)
214+
for i := 0; i < numUsers; i++ {
215+
dataset.AddUser(data.User{UserId: fmt.Sprintf("user%v", i)})
216+
}
217+
for i := 0; i < numItems; i++ {
218+
dataset.AddItem(data.Item{ItemId: fmt.Sprintf("item%v", i)})
219+
}
220+
// User i gives feedback on items i+1 .. numItems-1 (4 + 3 + 2 = 9 in total).
// The timestamp is j seconds after the Unix epoch, so the item with the
// highest number is always the newest feedback of a user.
for i := 0; i < numUsers; i++ {
221+
for j := i + 1; j < numItems; j++ {
222+
dataset.AddFeedback(fmt.Sprintf("user%v", i), fmt.Sprintf("item%v", j), time.Unix(int64(j), 0))
223+
}
224+
}
225+
assert.Equal(t, 9, dataset.CountFeedback())
226+
// split
227+
train, test := dataset.SplitLatest()
228+
assert.Equal(t, numUsers, train.CountUsers())
229+
assert.Equal(t, numItems, train.CountItems())
230+
assert.Equal(t, numUsers, test.CountUsers())
231+
assert.Equal(t, numItems, test.CountItems())
232+
// One feedback per user moves to the test set; the remaining six stay in train.
assert.Equal(t, 6, train.CountFeedback())
233+
assert.Equal(t, 3, test.CountFeedback())
234+
for i := 0; i < numUsers; i++ {
235+
// User i started with numItems-i-1 feedback and lost one to the test set.
assert.Len(t, train.GetUserFeedback()[i], numItems-i-2)
236+
assert.Len(t, test.GetUserFeedback()[i], 1)
237+
// The newest feedback of every user is on "item4"; presumably its index
// is 4 because items were added in order (verify against itemDict).
assert.Equal(t, 4, int(test.GetUserFeedback()[i][0]))
238+
}
239+
}
240+
209241
func TestDataset_LoadMovieLens1M(t *testing.T) {
210242
train, test, err := LoadDataFromBuiltIn("ml-1m")
211243
assert.NoError(t, err)

master/tasks.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -479,7 +479,7 @@ func (m *Master) LoadDataFromDatabase(
479479
break
480480
}
481481
}
482-
dataSet.AddFeedback(f.UserId, f.ItemId)
482+
dataSet.AddFeedback(f.UserId, f.ItemId, f.Timestamp)
483483
}
484484
span.Add(len(feedback))
485485
}

model/cf/evaluator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func TestEvaluate(t *testing.T) {
144144
}
145145
for i := 0; i < 16; i++ {
146146
test.AddItem(data.Item{ItemId: strconv.Itoa(i)})
147-
test.AddFeedback(strconv.Itoa(i/4), strconv.Itoa(i))
147+
test.AddFeedback(strconv.Itoa(i/4), strconv.Itoa(i), time.Time{})
148148
}
149149
assert.Equal(t, 16, test.CountFeedback())
150150
assert.Equal(t, 4, test.CountUsers())

0 commit comments

Comments
 (0)