
Commit 31c2cdb

incremental training bug fixes + better recoverability
1 parent 45c2ee9 commit 31c2cdb

File tree

1 file changed (+57, -28 lines)


index/scorch/scorch.go

Lines changed: 57 additions & 28 deletions
@@ -250,6 +250,9 @@ func (s *Scorch) Open() error {
 	s.asyncTasks.Add(1)
 	go s.introducerLoop()
 
+	s.asyncTasks.Add(1)
+	go s.trainerLoop()
+
 	if !s.readOnly && s.path != "" {
 		s.asyncTasks.Add(1)
 		go s.persisterLoop()
@@ -325,6 +328,7 @@ func (s *Scorch) openBolt() error {
 	s.persisterNotifier = make(chan *epochWatcher, 1)
 	s.closeCh = make(chan struct{})
 	s.forceMergeRequestCh = make(chan *mergerCtrl, 1)
+	s.train = make(chan *trainRequest)
 
 	if !s.readOnly && s.path != "" {
 		err := s.removeOldZapFiles() // Before persister or merger create any new files.
@@ -534,9 +538,21 @@ func (s *Scorch) getInternal(key []byte) ([]byte, error) {
 	return nil, nil
 }
 
+func moveFile(sourcePath, destPath string) error {
+	// rename is supposed to be atomic on the same filesystem
+	err := os.Rename(sourcePath, destPath)
+	if err != nil {
+		return fmt.Errorf("error renaming file: %v", err)
+	}
+	return nil
+}
+
 // this is not a routine that will be running throughout the lifetime of the index. It's purpose
 // is to only train the vector index before the data ingestion starts.
 func (s *Scorch) trainerLoop() {
+	defer func() {
+		s.asyncTasks.Done()
+	}()
 	// some init stuff
 	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
 	var totalSamplesProcessed int
@@ -550,41 +566,42 @@ func (s *Scorch) trainerLoop() {
 		case trainReq := <-s.train:
 			sampleSeg := trainReq.sample
 			if s.centroidIndex == nil {
-				// new centroid index
-				s.centroidIndex = &SegmentSnapshot{
-					segment: sampleSeg,
-				}
 				switch seg := sampleSeg.(type) {
 				case segment.UnpersistedSegment:
 					err := persistToDirectory(seg, nil, path)
 					if err != nil {
 						// clean up this ugly ass error handling code
 						trainReq.ackCh <- fmt.Errorf("error persisting segment: %v", err)
 						close(trainReq.ackCh)
+						return
 					}
 				default:
 					fmt.Errorf("segment is not a unpersisted segment")
 					close(s.closeCh)
+					return
 				}
 			} else {
 				// merge the new segment with the existing one, no need to persist?
 				// persist in a tmp file and then rename - is that a fair strategy?
+				fmt.Println("merging centroid index")
+				s.segmentConfig["training"] = true
 				_, _, err := s.segPlugin.MergeEx([]segment.Segment{s.centroidIndex.segment, sampleSeg},
-					[]*roaring.Bitmap{nil, nil}, "centroid_index.tmp", s.closeCh, nil, s.segmentConfig)
+					[]*roaring.Bitmap{nil, nil}, filepath.Join(s.path, "centroid_index.tmp"), s.closeCh, nil, s.segmentConfig)
 				if err != nil {
 					trainReq.ackCh <- fmt.Errorf("error merging centroid index: %v", err)
 					close(trainReq.ackCh)
 				}
+				// reset the training flag once completed
+				s.segmentConfig["training"] = false
 
 				// close the existing centroid segment - it's supposed to be gc'd at this point
 				s.centroidIndex.segment.Close()
-				err = os.Rename(filepath.Join(s.path, "centroid_index.tmp"), filepath.Join(s.path, "centroid_index"))
+				err = moveFile(filepath.Join(s.path, "centroid_index.tmp"), filepath.Join(s.path, "centroid_index"))
				if err != nil {
 					trainReq.ackCh <- fmt.Errorf("error renaming centroid index: %v", err)
 					close(trainReq.ackCh)
 				}
 			}
-
 			totalSamplesProcessed += trainReq.vecCount
 			// a bolt transaction is necessary for failover-recovery scenario and also serves as a checkpoint
 			// where we can be sure that the centroid index is available for the indexing operations downstream
@@ -596,25 +613,29 @@ func (s *Scorch) trainerLoop() {
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error starting bolt transaction: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 			defer tx.Rollback()
 
 			snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error creating snapshots bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error creating centroid bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			err = centroidBucket.Put(util.BoltPathKey, []byte(filename))
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			// total number of vectors that have been processed so far for the training
@@ -623,14 +644,25 @@
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error updating vec samples processed: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			err = tx.Commit()
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
+			centroidIndex, err := s.segPlugin.OpenEx(filepath.Join(s.path, "centroid_index"), s.segmentConfig)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error opening centroid index: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+			s.centroidIndex = &SegmentSnapshot{
+				segment: centroidIndex,
+			}
 			close(trainReq.ackCh)
 		}
 	}
@@ -640,25 +672,20 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	// regulate the Train function
 	s.FireIndexEvent()
 
-	// is the lock really needed?
-	s.rootLock.Lock()
-	defer s.rootLock.Unlock()
-	if s.centroidIndex != nil {
-		// singleton API
-		return nil
-	}
+	// // is the lock really needed?
+	// s.rootLock.Lock()
+	// defer s.rootLock.Unlock()
+
 	var trainData []index.Document
-	if s.centroidIndex == nil {
-		for key, doc := range batch.IndexOps {
-			if doc != nil {
-				// insert _id field
-				// no need to track updates/deletes over here since
-				// the API is singleton
-				doc.AddIDField()
-			}
-			if strings.HasPrefix(key, index.TrainDataPrefix) {
-				trainData = append(trainData, doc)
-			}
+	for key, doc := range batch.IndexOps {
+		if doc != nil {
+			// insert _id field
+			// no need to track updates/deletes over here since
+			// the API is singleton
+			doc.AddIDField()
+		}
+		if strings.HasPrefix(key, index.TrainDataPrefix) {
+			trainData = append(trainData, doc)
 		}
 	}
 
@@ -670,13 +697,10 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	//
 	// note: this might index text data too, how to handle this? s.segmentConfig?
 	// todo: updates/deletes -> data drift detection
-	s.segmentConfig["training"] = true
 	seg, n, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
 	if err != nil {
 		return err
 	}
-	// reset the training flag once completed
-	s.segmentConfig["training"] = false
 
 	trainReq := &trainRequest{
 		sample: seg,
@@ -687,6 +711,7 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	s.train <- trainReq
 	err = <-trainReq.ackCh
 	if err != nil {
+		fmt.Println("error training", err)
 		return err
 	}
 
@@ -697,6 +722,10 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	s.centroidIndex = &SegmentSnapshot{
 		segment: centroidIndex,
 	}
+	_, err = s.getCentroidIndex("emb")
+	if err != nil {
+		return err
+	}
 	fmt.Println("number of bytes written to centroid index", n)
 	return err
 }
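
The recoverability piece of this commit hinges on the persist-to-a-temp-file-then-rename pattern that the new moveFile helper wraps: the merged centroid index is written to centroid_index.tmp and then renamed over centroid_index, so a crash mid-write leaves either the old index or the complete new one on disk, never a partial file. Below is a minimal, self-contained sketch of that pattern; writeFileAtomic is a hypothetical helper name, not part of the commit, and it assumes the temp file and the destination live in the same directory (same filesystem), which is what makes os.Rename effectively atomic.

package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writeFileAtomic is a hypothetical helper illustrating the pattern the commit
// relies on: write the new contents to "<dest>.tmp", flush them to disk, then
// rename over the final path. A crash mid-way leaves either the old file or
// the complete new file, never a partial write.
func writeFileAtomic(destPath string, data []byte) error {
	tmpPath := destPath + ".tmp"

	f, err := os.Create(tmpPath)
	if err != nil {
		return fmt.Errorf("error creating temp file: %v", err)
	}
	if _, err := f.Write(data); err != nil {
		f.Close()
		return fmt.Errorf("error writing temp file: %v", err)
	}
	// make sure the bytes hit disk before the rename
	if err := f.Sync(); err != nil {
		f.Close()
		return fmt.Errorf("error syncing temp file: %v", err)
	}
	if err := f.Close(); err != nil {
		return fmt.Errorf("error closing temp file: %v", err)
	}

	// rename is atomic only when tmpPath and destPath are on the same
	// filesystem, which holds here because both live in the same directory.
	if err := os.Rename(tmpPath, destPath); err != nil {
		return fmt.Errorf("error renaming file: %v", err)
	}
	return nil
}

func main() {
	dir, _ := os.MkdirTemp("", "centroid")
	defer os.RemoveAll(dir)

	dest := filepath.Join(dir, "centroid_index")
	if err := writeFileAtomic(dest, []byte("trained centroids")); err != nil {
		fmt.Println("write failed:", err)
		return
	}
	fmt.Println("wrote", dest)
}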
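
The bolt transaction added to trainerLoop is the other half of the recoverability story: it checkpoints the location of the persisted centroid index in a nested bucket, so a re-opened index can find the trained index after a failover. A rough sketch of that checkpoint using go.etcd.io/bbolt follows; the bucket and key names (snapshots, centroid, path) are stand-ins for util.BoltSnapshotsBucket, util.BoltCentroidIndexKey, and util.BoltPathKey, and recordCentroidIndex is a hypothetical function, not the commit's actual code.

package main

import (
	"fmt"
	"log"

	bolt "go.etcd.io/bbolt"
)

// Hypothetical stand-ins for the util constants referenced in the diff above.
var (
	snapshotsBucket  = []byte("snapshots")
	centroidIndexKey = []byte("centroid")
	pathKey          = []byte("path")
)

// recordCentroidIndex checkpoints the on-disk filename of the trained centroid
// index inside a nested bolt bucket, mirroring the transaction in trainerLoop.
// Once the transaction commits, a recovering process can read the path back
// and reopen the index.
func recordCentroidIndex(db *bolt.DB, filename string) error {
	return db.Update(func(tx *bolt.Tx) error {
		snapshots, err := tx.CreateBucketIfNotExists(snapshotsBucket)
		if err != nil {
			return fmt.Errorf("error creating snapshots bucket: %v", err)
		}
		centroid, err := snapshots.CreateBucketIfNotExists(centroidIndexKey)
		if err != nil {
			return fmt.Errorf("error creating centroid bucket: %v", err)
		}
		return centroid.Put(pathKey, []byte(filename))
	})
}

func main() {
	db, err := bolt.Open("root.bolt", 0600, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	if err := recordCentroidIndex(db, "centroid_index"); err != nil {
		log.Fatal(err)
	}
	fmt.Println("checkpoint written")
}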
