@@ -79,6 +79,7 @@ type Dataset struct {
7979 itemLabels * Labels
8080 userFeedback [][]int32
8181 itemFeedback [][]int32
82+ timestamps [][]time.Time
8283 negatives [][]int32
8384 userDict * FreqDict
8485 itemDict * FreqDict
@@ -95,6 +96,7 @@ func NewDataset(timestamp time.Time, userCount, itemCount int) *Dataset {
9596 itemLabels : NewLabels (),
9697 userFeedback : make ([][]int32 , userCount ),
9798 itemFeedback : make ([][]int32 , itemCount ),
99+ timestamps : make ([][]time.Time , userCount ),
98100 userDict : NewFreqDict (),
99101 itemDict : NewFreqDict (),
100102 categories : make (map [string ]int ),
@@ -203,6 +205,9 @@ func (d *Dataset) AddUser(user data.User) {
203205 if len (d .userFeedback ) < len (d .users ) {
204206 d .userFeedback = append (d .userFeedback , nil )
205207 }
208+ if len (d .timestamps ) < len (d .users ) {
209+ d .timestamps = append (d .timestamps , nil )
210+ }
206211}
207212
208213func (d * Dataset ) AddItem (item data.Item ) {
@@ -223,11 +228,12 @@ func (d *Dataset) AddItem(item data.Item) {
223228 }
224229}
225230
226- func (d * Dataset ) AddFeedback (userId , itemId string ) {
231+ func (d * Dataset ) AddFeedback (userId , itemId string , timestamp time. Time ) {
227232 userIndex := d .userDict .Add (userId )
228233 itemIndex := d .itemDict .Add (itemId )
229234 d .userFeedback [userIndex ] = append (d .userFeedback [userIndex ], itemIndex )
230235 d .itemFeedback [itemIndex ] = append (d .itemFeedback [itemIndex ], userIndex )
236+ d .timestamps [userIndex ] = append (d .timestamps [userIndex ], timestamp )
231237 d .numFeedback ++
232238}
233239
@@ -253,6 +259,7 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
253259 trainSet .items , testSet .items = d .items , d .items
254260 trainSet .userFeedback , testSet .userFeedback = make ([][]int32 , d .CountUsers ()), make ([][]int32 , d .CountUsers ())
255261 trainSet .itemFeedback , testSet .itemFeedback = make ([][]int32 , d .CountItems ()), make ([][]int32 , d .CountItems ())
262+ trainSet .timestamps , testSet .timestamps = make ([][]time.Time , d .CountUsers ()), make ([][]time.Time , d .CountUsers ())
256263 trainSet .userDict , testSet .userDict = d .userDict , d .userDict
257264 trainSet .itemDict , testSet .itemDict = d .itemDict , d .itemDict
258265 rng := util .NewRandomGenerator (seed )
@@ -262,11 +269,13 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
262269 k := rng .Intn (len (d .userFeedback [userIndex ]))
263270 testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][k ])
264271 testSet .itemFeedback [d.userFeedback [userIndex ][k ]] = append (testSet .itemFeedback [d.userFeedback [userIndex ][k ]], userIndex )
272+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d.timestamps [userIndex ][k ])
265273 testSet .numFeedback ++
266274 for i , itemIndex := range d .userFeedback [userIndex ] {
267275 if i != k {
268276 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
269277 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
278+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][i ])
270279 trainSet .numFeedback ++
271280 }
272281 }
@@ -277,13 +286,15 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
277286 for _ , userIndex := range testUsers {
278287 if len (d .userFeedback [userIndex ]) > 0 {
279288 k := rng .Intn (len (d .userFeedback [userIndex ]))
280- testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][k ])
289+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d.timestamps [userIndex ][k ])
281290 testSet .itemFeedback [d.userFeedback [userIndex ][k ]] = append (testSet .itemFeedback [d.userFeedback [userIndex ][k ]], userIndex )
291+ testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][k ])
282292 testSet .numFeedback ++
283293 for i , itemIndex := range d .userFeedback [userIndex ] {
284294 if i != k {
285295 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
286296 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
297+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][i ])
287298 trainSet .numFeedback ++
288299 }
289300 }
@@ -292,9 +303,10 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
292303 testUserSet := mapset .NewSet (testUsers ... )
293304 for userIndex := int32 (0 ); userIndex < int32 (d .CountUsers ()); userIndex ++ {
294305 if ! testUserSet .Contains (userIndex ) {
295- for _ , itemIndex := range d .userFeedback [userIndex ] {
306+ for idx , itemIndex := range d .userFeedback [userIndex ] {
296307 trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
297308 trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
309+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][idx ])
298310 trainSet .numFeedback ++
299311 }
300312 }
@@ -303,6 +315,39 @@ func (d *Dataset) SplitCF(numTestUsers int, seed int64) (CFSplit, CFSplit) {
303315 return trainSet , testSet
304316}
305317
318+ // SplitLatest splits dataset by moving the most recent feedback of all users into the test set to avoid leakage.
319+ func (d * Dataset ) SplitLatest () (CFSplit , CFSplit ) {
320+ trainSet , testSet := new (Dataset ), new (Dataset )
321+ trainSet .users , testSet .users = d .users , d .users
322+ trainSet .items , testSet .items = d .items , d .items
323+ trainSet .userFeedback , testSet .userFeedback = make ([][]int32 , d .CountUsers ()), make ([][]int32 , d .CountUsers ())
324+ trainSet .itemFeedback , testSet .itemFeedback = make ([][]int32 , d .CountItems ()), make ([][]int32 , d .CountItems ())
325+ trainSet .timestamps , testSet .timestamps = make ([][]time.Time , d .CountUsers ()), make ([][]time.Time , d .CountUsers ())
326+ trainSet .userDict , testSet .userDict = d .userDict , d .userDict
327+ trainSet .itemDict , testSet .itemDict = d .itemDict , d .itemDict
328+ for userIndex := int32 (0 ); userIndex < int32 (d .CountUsers ()); userIndex ++ {
329+ if len (d .userFeedback [userIndex ]) == 0 {
330+ continue
331+ }
332+ _ , latestIdx := lo .MaxIndexBy (d .timestamps [userIndex ], func (a , b time.Time ) bool {
333+ return a .After (b )
334+ })
335+ testSet .timestamps [userIndex ] = append (testSet .timestamps [userIndex ], d.timestamps [userIndex ][latestIdx ])
336+ testSet .itemFeedback [d.userFeedback [userIndex ][latestIdx ]] = append (testSet .itemFeedback [d.userFeedback [userIndex ][latestIdx ]], userIndex )
337+ testSet .userFeedback [userIndex ] = append (testSet .userFeedback [userIndex ], d.userFeedback [userIndex ][latestIdx ])
338+ testSet .numFeedback ++
339+ for i , itemIndex := range d .userFeedback [userIndex ] {
340+ if i != latestIdx {
341+ trainSet .userFeedback [userIndex ] = append (trainSet .userFeedback [userIndex ], itemIndex )
342+ trainSet .itemFeedback [itemIndex ] = append (trainSet .itemFeedback [itemIndex ], userIndex )
343+ trainSet .timestamps [userIndex ] = append (trainSet .timestamps [userIndex ], d.timestamps [userIndex ][i ])
344+ trainSet .numFeedback ++
345+ }
346+ }
347+ }
348+ return trainSet , testSet
349+ }
350+
306351type Labels struct {
307352 fields * strutil.Pool
308353 values * FreqDict
@@ -366,6 +411,7 @@ func LoadDataFromBuiltIn(dataSetName string) (*Dataset, *Dataset, error) {
366411 test .userDict , test .itemDict = train .userDict , train .itemDict
367412 test .userFeedback = make ([][]int32 , len (train .userFeedback ))
368413 test .itemFeedback = make ([][]int32 , len (train .itemFeedback ))
414+ test .timestamps = make ([][]time.Time , len (train .userFeedback ))
369415 test .negatives = make ([][]int32 , len (train .userFeedback ))
370416 err = loadTest (test , testFilePath )
371417 if err != nil {
@@ -404,7 +450,7 @@ func loadTrain(path string) (*Dataset, error) {
404450 dataset .AddItem (data.Item {ItemId : util .FormatInt (i )})
405451 }
406452 // add feedback
407- dataset .AddFeedback (fields [0 ], fields [1 ])
453+ dataset .AddFeedback (fields [0 ], fields [1 ], time. Time {} )
408454 }
409455 return dataset , scanner .Err ()
410456}
@@ -429,7 +475,7 @@ func loadTest(dataset *Dataset, path string) error {
429475 positive = positive [1 : len (positive )- 1 ]
430476 fields = strings .Split (positive , "," )
431477 // add feedback
432- dataset .AddFeedback (fields [0 ], fields [1 ])
478+ dataset .AddFeedback (fields [0 ], fields [1 ], time. Time {} )
433479 // add negatives
434480 userId , err := strconv .Atoi (fields [0 ])
435481 if err != nil {
0 commit comments