@@ -18,6 +18,7 @@ import (
1818 "bufio"
1919 "fmt"
2020 "os"
21+ "sort"
2122 "strconv"
2223 "strings"
2324 "time"
@@ -42,6 +43,8 @@ type CFSplit interface {
4243 CountItems () int
4344 // CountFeedback returns the number of (positive) feedback.
4445 CountFeedback () int
46+ // GetItems returns the items.
47+ GetItems () []data.Item
4548 // GetUserDict returns the frequency dictionary of users.
4649 GetUserDict () * FreqDict
4750 // GetItemDict returns the frequency dictionary of items.
@@ -79,6 +82,7 @@ type Dataset struct {
7982 itemLabels * Labels
8083 userFeedback [][]int32
8184 itemFeedback [][]int32
85+ timestamps [][]time.Time
8286 negatives [][]int32
8387 userDict * FreqDict
8488 itemDict * FreqDict
@@ -95,6 +99,7 @@ func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset {
9599 itemLabels : NewLabels (),
96100 userFeedback : make ([][]int32 , userCount ),
97101 itemFeedback : make ([][]int32 , itemCount ),
102+ timestamps : make ([][]time.Time , userCount ),
98103 userDict : NewFreqDict (),
99104 itemDict : NewFreqDict (),
100105 categories : make (map [string ]int ),
@@ -203,6 +208,9 @@ func (d *Dataset) AddUser(user data.User) {
203208 if len (d .userFeedback ) < len (d .users ) {
204209 d .userFeedback = append (d .userFeedback , nil )
205210 }
211+ if len (d .timestamps ) < len (d .users ) {
212+ d .timestamps = append (d .timestamps , nil )
213+ }
206214}
207215
208216func (d * Dataset ) AddItem (item data.Item ) {
@@ -223,11 +231,12 @@ func (d *Dataset) AddItem(item data.Item) {
223231 }
224232}
225233
226- func (d * Dataset ) AddFeedback (userId , itemId string ) {
234+ func (d * Dataset ) AddFeedback (userId , itemId string , timestamp time. Time ) {
227235 userIndex := d .userDict .Add (userId )
228236 itemIndex := d .itemDict .Add (itemId )
229237 d .userFeedback [userIndex ] = append (d .userFeedback [userIndex ], itemIndex )
230238 d .itemFeedback [itemIndex ] = append (d .itemFeedback [itemIndex ], userIndex )
239+ d .timestamps [userIndex ] = append (d .timestamps [userIndex ], timestamp )
231240 d .numFeedback ++
232241}
233242
@@ -253,6 +262,7 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
253262 trainSet .items , testSet .items = d .items , d .items
254263 trainSet .userFeedback , testSet .userFeedback = make ([][]int32 , d .CountUsers ()), make ([][]int32 , d .CountUsers ())
255264 trainSet .itemFeedback , testSet .itemFeedback = make ([][]int32 , d .CountItems ()), make ([][]int32 , d .CountItems ())
265+ trainSet .timestamps , testSet .timestamps = make ([][]time.Time , d .CountUsers ()), make ([][]time.Time , d .CountUsers ())
256266 trainSet .userDict , testSet .userDict = d .userDict , d .userDict
257267 trainSet .itemDict , testSet .itemDict = d .itemDict , d .itemDict
258268 rng := util .NewRandomGenerator (seed )
@@ -262,11 +272,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
262272 k := rng .Intn (len (d .userFeedback [userIndex ]))
263273 testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][k ])
264274 testSet .itemFeedback [d.userFeedback [userIndex ][k ]] = append (testSet .itemFeedback [d.userFeedback [userIndex ][k ]], userIndex )
275+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d.timestamps [userIndex ][k ])
265276 testSet .numFeedback ++
266277 for i , itemIndex := range d .userFeedback [userIndex ] {
267278 if i != k {
268279 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
269280 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
281+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][i ])
270282 trainSet .numFeedback ++
271283 }
272284 }
@@ -279,11 +291,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
279291 k := rng .Intn (len (d .userFeedback [userIndex ]))
280292 testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][k ])
281293 testSet .itemFeedback [d.userFeedback [userIndex ][k ]] = append (testSet .itemFeedback [d.userFeedback [userIndex ][k ]], userIndex )
294+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d.timestamps [userIndex ][k ])
282295 testSet .numFeedback ++
283296 for i , itemIndex := range d .userFeedback [userIndex ] {
284297 if i != k {
285298 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
286299 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
300+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][i ])
287301 trainSet .numFeedback ++
288302 }
289303 }
@@ -292,9 +306,10 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
292306 testUserSet := mapset .NewSet (testUsers ... )
293307 for userIndex := int32 (0 ); userIndex < int32 (d .CountUsers ()); userIndex ++ {
294308 if ! testUserSet .Contains (userIndex ) {
295- for _ , itemIndex := range d .userFeedback [userIndex ] {
309+ for idx , itemIndex := range d .userFeedback [userIndex ] {
296310 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
297311 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
312+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][idx ])
298313 trainSet .numFeedback ++
299314 }
300315 }
@@ -303,6 +318,39 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
303318 return trainSet , testSet
304319}
305320
321+ // SplitLatest splits dataset by moving the most recent feedback of all users into the test set to avoid leakage.
322+ func (d * Dataset ) SplitLatest (shots int ) (CFSplit , CFSplit ) {
323+ trainSet , testSet := new (Dataset ), new (Dataset )
324+ trainSet .users , testSet .users = d .users , d .users
325+ trainSet .items , testSet .items = d .items , d .items
326+ trainSet .userFeedback , testSet .userFeedback = make ([][]int32 , d .CountUsers ()), make ([][]int32 , d .CountUsers ())
327+ trainSet .itemFeedback , testSet .itemFeedback = make ([][]int32 , d .CountItems ()), make ([][]int32 , d .CountItems ())
328+ trainSet .timestamps , testSet .timestamps = make ([][]time.Time , d .CountUsers ()), make ([][]time.Time , d .CountUsers ())
329+ trainSet .userDict , testSet .userDict = d .userDict , d .userDict
330+ trainSet .itemDict , testSet .itemDict = d .itemDict , d .itemDict
331+ for userIndex := int32 (0 ); userIndex < int32 (d .CountUsers ()); userIndex ++ {
332+ if len (d .userFeedback [userIndex ]) == 0 {
333+ continue
334+ }
335+ idxs := lo .Range (len (d .userFeedback [userIndex ]))
336+ sort .Slice (idxs , func (i , j int ) bool {
337+ return d.timestamps [userIndex ][idxs [i ]].After (d.timestamps [userIndex ][idxs [j ]])
338+ })
339+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d .timestamps [userIndex ][idxs [0 ]])
340+ testSet .itemFeedback [d .userFeedback [userIndex ][idxs [0 ]]] = append (testSet .itemFeedback [d .userFeedback [userIndex ][idxs [0 ]]], userIndex )
341+ testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d .userFeedback [userIndex ][idxs [0 ]])
342+ testSet .numFeedback ++
343+ for i := 1 ; i < len (d .userFeedback [userIndex ]) && i <= shots ; i ++ {
344+ itemIndex := d.userFeedback [userIndex ][idxs [i ]]
345+ trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
346+ trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
347+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][idxs [i ]])
348+ trainSet .numFeedback ++
349+ }
350+ }
351+ return trainSet , testSet
352+ }
353+
306354type Labels struct {
307355 fields * strutil.Pool
308356 values * FreqDict
@@ -366,6 +414,7 @@ func LoadDataFromBuiltIn(dataSetName string) (*Dataset, *Dataset, error) {
366414 test .userDict , test .itemDict = train .userDict , train .itemDict
367415 test .userFeedback = make ([][]int32 , len (train .userFeedback ))
368416 test .itemFeedback = make ([][]int32 , len (train .itemFeedback ))
417+ test .timestamps = make ([][]time.Time , len (train .userFeedback ))
369418 test .negatives = make ([][]int32 , len (train .userFeedback ))
370419 err = loadTest (test , testFilePath )
371420 if err != nil {
@@ -404,7 +453,7 @@ func loadTrain(path string) (*Dataset, error) {
404453 dataset .AddItem (data.Item {ItemId : util .FormatInt (i )})
405454 }
406455 // add feedback
407- dataset .AddFeedback (fields [0 ], fields [1 ])
456+ dataset .AddFeedback (fields [0 ], fields [1 ], time. Time {} )
408457 }
409458 return dataset , scanner .Err ()
410459}
@@ -429,7 +478,7 @@ func loadTest(dataset *Dataset, path string) error {
429478 positive = positive [1 : len (positive )- 1 ]
430479 fields = strings .Split (positive , "," )
431480 // add feedback
432- dataset .AddFeedback (fields [0 ], fields [1 ])
481+ dataset .AddFeedback (fields [0 ], fields [1 ], time. Time {} )
433482 // add negatives
434483 userId , err := strconv .Atoi (fields [0 ])
435484 if err != nil {
0 commit comments