@@ -250,6 +250,9 @@ func (s *Scorch) Open() error {
 	s.asyncTasks.Add(1)
 	go s.introducerLoop()
 
+	s.asyncTasks.Add(1)
+	go s.trainerLoop()
+
 	if !s.readOnly && s.path != "" {
 		s.asyncTasks.Add(1)
 		go s.persisterLoop()
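
The trainer goroutine is registered with the same lifecycle convention as the other background loops: each `asyncTasks.Add(1)` here is balanced by a `Done()` inside the goroutine (trainerLoop defers it further down in this diff), so shutdown can wait on the whole group. A minimal sketch of that contract, assuming shutdown signals `closeCh` and then waits on `asyncTasks` as the existing loops expect:

```go
// Sketch only, not part of this change: the lifecycle contract behind the
// Add(1) calls above. Each loop must call asyncTasks.Done() exactly once so
// that a shutdown sequence like this can drain every background goroutine.
func shutdownSketch(s *Scorch) {
	close(s.closeCh)    // ask introducer/persister/merger/trainer loops to exit
	s.asyncTasks.Wait() // returns once every loop's deferred Done() has run
}
```
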
@@ -325,6 +328,7 @@ func (s *Scorch) openBolt() error {
 	s.persisterNotifier = make(chan *epochWatcher, 1)
 	s.closeCh = make(chan struct{})
 	s.forceMergeRequestCh = make(chan *mergerCtrl, 1)
+	s.train = make(chan *trainRequest)
 
 	if !s.readOnly && s.path != "" {
 		err := s.removeOldZapFiles() // Before persister or merger create any new files.
@@ -534,9 +538,21 @@ func (s *Scorch) getInternal(key []byte) ([]byte, error) {
 	return nil, nil
 }
 
+func moveFile(sourcePath, destPath string) error {
+	// rename is supposed to be atomic on the same filesystem
+	err := os.Rename(sourcePath, destPath)
+	if err != nil {
+		return fmt.Errorf("error renaming file: %v", err)
+	}
+	return nil
+}
+
 // this is not a routine that will be running throughout the lifetime of the index. It's purpose
 // is to only train the vector index before the data ingestion starts.
 func (s *Scorch) trainerLoop() {
+	defer func() {
+		s.asyncTasks.Done()
+	}()
 	// some init stuff
 	s.segmentConfig["getCentroidIndexCallback"] = s.getCentroidIndex
 	var totalSamplesProcessed int
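
The loop modified below consumes `*trainRequest` values from `s.train`. The struct itself is not shown in this diff; judging from the fields used (`sample`, `vecCount`, `ackCh`), it presumably looks roughly like the sketch below, where the field types are inferred from usage and should be treated as assumptions:

```go
// Assumed shape of trainRequest, reconstructed from how Train and trainerLoop
// use it in this diff; not the actual definition.
type trainRequest struct {
	sample   segment.Segment // sample segment built from the training batch
	vecCount int             // number of vectors carried by this sample
	ackCh    chan error      // receives an error, or is closed, once the sample is handled
}
```
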
@@ -550,41 +566,42 @@ func (s *Scorch) trainerLoop() {
 		case trainReq := <-s.train:
 			sampleSeg := trainReq.sample
 			if s.centroidIndex == nil {
-				// new centroid index
-				s.centroidIndex = &SegmentSnapshot{
-					segment: sampleSeg,
-				}
 				switch seg := sampleSeg.(type) {
 				case segment.UnpersistedSegment:
 					err := persistToDirectory(seg, nil, path)
 					if err != nil {
 						// clean up this ugly ass error handling code
 						trainReq.ackCh <- fmt.Errorf("error persisting segment: %v", err)
 						close(trainReq.ackCh)
+						return
 					}
 				default:
 					fmt.Errorf("segment is not a unpersisted segment")
 					close(s.closeCh)
+					return
 				}
 			} else {
 				// merge the new segment with the existing one, no need to persist?
 				// persist in a tmp file and then rename - is that a fair strategy?
+				fmt.Println("merging centroid index")
+				s.segmentConfig["training"] = true
 				_, _, err := s.segPlugin.MergeEx([]segment.Segment{s.centroidIndex.segment, sampleSeg},
-					[]*roaring.Bitmap{nil, nil}, "centroid_index.tmp", s.closeCh, nil, s.segmentConfig)
+					[]*roaring.Bitmap{nil, nil}, filepath.Join(s.path, "centroid_index.tmp"), s.closeCh, nil, s.segmentConfig)
 				if err != nil {
 					trainReq.ackCh <- fmt.Errorf("error merging centroid index: %v", err)
 					close(trainReq.ackCh)
 				}
+				// reset the training flag once completed
+				s.segmentConfig["training"] = false
 
 				// close the existing centroid segment - it's supposed to be gc'd at this point
 				s.centroidIndex.segment.Close()
-				err = os.Rename(filepath.Join(s.path, "centroid_index.tmp"), filepath.Join(s.path, "centroid_index"))
+				err = moveFile(filepath.Join(s.path, "centroid_index.tmp"), filepath.Join(s.path, "centroid_index"))
 				if err != nil {
 					trainReq.ackCh <- fmt.Errorf("error renaming centroid index: %v", err)
 					close(trainReq.ackCh)
 				}
 			}
-
 			totalSamplesProcessed += trainReq.vecCount
 			// a bolt transaction is necessary for failover-recovery scenario and also serves as a checkpoint
 			// where we can be sure that the centroid index is available for the indexing operations downstream
@@ -596,25 +613,29 @@ func (s *Scorch) trainerLoop() {
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error starting bolt transaction: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 			defer tx.Rollback()
 
 			snapshotsBucket, err := tx.CreateBucketIfNotExists(util.BoltSnapshotsBucket)
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error creating snapshots bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			centroidBucket, err := snapshotsBucket.CreateBucketIfNotExists(util.BoltCentroidIndexKey)
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error creating centroid bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			err = centroidBucket.Put(util.BoltPathKey, []byte(filename))
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error updating centroid bucket: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			// total number of vectors that have been processed so far for the training
@@ -623,14 +644,25 @@ func (s *Scorch) trainerLoop() {
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error updating vec samples processed: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
 			err = tx.Commit()
 			if err != nil {
 				trainReq.ackCh <- fmt.Errorf("error committing bolt transaction: %v", err)
 				close(trainReq.ackCh)
+				return
 			}
 
+			centroidIndex, err := s.segPlugin.OpenEx(filepath.Join(s.path, "centroid_index"), s.segmentConfig)
+			if err != nil {
+				trainReq.ackCh <- fmt.Errorf("error opening centroid index: %v", err)
+				close(trainReq.ackCh)
+				return
+			}
+			s.centroidIndex = &SegmentSnapshot{
+				segment: centroidIndex,
+			}
 			close(trainReq.ackCh)
 		}
 	}
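
As the comment ahead of the bolt transaction says, the checkpoint records where the centroid index lives on disk and how many sample vectors have been processed, so a restart can recover the training state. A rough read-back sketch, assuming the same bucket and key constants; `BoltSamplesKey` is a hypothetical name for the sample-count key, which is not visible in this hunk, and its encoding is unknown:

```go
// Hypothetical recovery helper for the checkpoint written in trainerLoop.
// util.BoltSamplesKey is an assumed name; the other constants appear in this diff.
func loadCentroidCheckpoint(tx *bolt.Tx) (path string, rawSamples []byte) {
	snapshots := tx.Bucket(util.BoltSnapshotsBucket)
	if snapshots == nil {
		return "", nil
	}
	centroid := snapshots.Bucket(util.BoltCentroidIndexKey)
	if centroid == nil {
		return "", nil
	}
	path = string(centroid.Get(util.BoltPathKey))
	rawSamples = centroid.Get(util.BoltSamplesKey) // assumed key, encoding unknown
	return path, rawSamples
}
```
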
@@ -640,25 +672,20 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	// regulate the Train function
 	s.FireIndexEvent()
 
-	// is the lock really needed?
-	s.rootLock.Lock()
-	defer s.rootLock.Unlock()
-	if s.centroidIndex != nil {
-		// singleton API
-		return nil
-	}
+	// // is the lock really needed?
+	// s.rootLock.Lock()
+	// defer s.rootLock.Unlock()
+
 	var trainData []index.Document
-	if s.centroidIndex == nil {
-		for key, doc := range batch.IndexOps {
-			if doc != nil {
-				// insert _id field
-				// no need to track updates/deletes over here since
-				// the API is singleton
-				doc.AddIDField()
-			}
-			if strings.HasPrefix(key, index.TrainDataPrefix) {
-				trainData = append(trainData, doc)
-			}
+	for key, doc := range batch.IndexOps {
+		if doc != nil {
+			// insert _id field
+			// no need to track updates/deletes over here since
+			// the API is singleton
+			doc.AddIDField()
+		}
+		if strings.HasPrefix(key, index.TrainDataPrefix) {
+			trainData = append(trainData, doc)
 		}
 	}
 
@@ -670,13 +697,10 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	//
 	// note: this might index text data too, how to handle this? s.segmentConfig?
 	// todo: updates/deletes -> data drift detection
-	s.segmentConfig["training"] = true
 	seg, n, err := s.segPlugin.NewEx(trainData, s.segmentConfig)
 	if err != nil {
 		return err
 	}
-	// reset the training flag once completed
-	s.segmentConfig["training"] = false
 
 	trainReq := &trainRequest{
 		sample: seg,
@@ -687,6 +711,7 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	s.train <- trainReq
 	err = <-trainReq.ackCh
 	if err != nil {
+		fmt.Println("error training", err)
 		return err
 	}
 
@@ -697,6 +722,10 @@ func (s *Scorch) Train(batch *index.Batch) error {
 	s.centroidIndex = &SegmentSnapshot{
 		segment: centroidIndex,
 	}
+	_, err = s.getCentroidIndex("emb")
+	if err != nil {
+		return err
+	}
 	fmt.Println("number of bytes written to centroid index", n)
 	return err
 }
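
For context on the caller side: Train expects an ordinary `index.Batch` whose `IndexOps` keys carry `index.TrainDataPrefix`, since the filtering loop above only picks up those documents, and the call blocks until trainerLoop has checkpointed the sample. A hedged usage sketch; how the sample documents themselves are built is left out because it depends on the caller's document implementation:

```go
// Hypothetical caller-side sketch of the pre-ingestion training step.
// sampleDocs maps raw ids to documents carrying the training vectors; the
// batch is built through the exported IndexOps map that Train iterates.
func trainBeforeIngestion(s *Scorch, sampleDocs map[string]index.Document) error {
	batch := &index.Batch{IndexOps: make(map[string]index.Document)}
	for id, doc := range sampleDocs {
		// keys must carry TrainDataPrefix for Train to treat them as samples
		batch.IndexOps[index.TrainDataPrefix+id] = doc
	}
	// blocks until trainerLoop has persisted/merged the sample and recorded
	// the bolt checkpoint; afterwards s.getCentroidIndex can serve the index.
	return s.Train(batch)
}
```
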