Skip to content

Commit e83067e

Browse files
committed
Handle NaN id document
1 parent de0fc86 commit e83067e

File tree

2 files changed

+85
-8
lines changed

2 files changed

+85
-8
lines changed

plm/copy.go

Lines changed: 62 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"math"
77
"runtime"
8+
"strings"
89
"sync"
910
"sync/atomic"
1011
"time"
@@ -204,6 +205,12 @@ func (cm *CopyManager) copyCollection(
204205
isCapped, _ := spec.Options.Lookup("capped").BooleanOK()
205206

206207
var nextSegment nextSegmentFunc
208+
209+
readResultC := make(chan readBatchResult)
210+
211+
var batchID atomic.Uint32
212+
var nextID nextBatchIDFunc = func() uint32 { return batchID.Add(1) }
213+
207214
if isCapped { //nolint:nestif
208215
segmenter, err := NewCappedSegmenter(ctx,
209216
cm.source, namespace, cm.options.ReadBatchSizeBytes)
@@ -234,13 +241,14 @@ func (cm *CopyManager) copyCollection(
234241
}
235242

236243
nextSegment = segmenter.Next
244+
245+
go segmenter.handleNanDoc(readResultC, nextID)
237246
}
238247

239248
collectionReadCtx, stopCollectionRead := context.WithCancel(ctx)
240249

241250
// pendingSegments tracks in-progress read segments
242251
pendingSegments := &sync.WaitGroup{}
243-
readResultC := make(chan readBatchResult)
244252

245253
allBatchesSent := make(chan struct{}) // closes when all batches are sent to inserters
246254

@@ -260,8 +268,6 @@ func (cm *CopyManager) copyCollection(
260268
// spawn readSegment in loop until the collection is exhausted or canceled.
261269
go func() {
262270
var segmentID uint32
263-
var batchID atomic.Uint32
264-
var nextID nextBatchIDFunc = func() uint32 { return batchID.Add(1) }
265271

266272
readStopped := collectionReadCtx.Done()
267273

@@ -297,6 +303,7 @@ func (cm *CopyManager) copyCollection(
297303
}
298304

299305
pendingSegments.Add(1)
306+
300307
go func() {
301308
defer func() {
302309
<-cm.readLimit
@@ -560,6 +567,7 @@ type Segmenter struct {
560567
batchSize int32
561568
keyRanges []keyRange
562569
currIDRange keyRange
570+
nanDoc bson.Raw // document with NaN _id, if any
563571
}
564572

565573
type keyRange struct {
@@ -625,7 +633,7 @@ func NewSegmenter(
625633

626634
mcoll := m.Database(ns.Database).Collection(ns.Collection)
627635

628-
idKeyRange, err := getIDKeyRange(ctx, mcoll)
636+
idKeyRange, nanDoc, err := getIDKeyRange(ctx, mcoll)
629637
if err != nil {
630638
if errors.Is(err, mongo.ErrNoDocuments) {
631639
return nil, errEOC // empty collection
@@ -640,6 +648,7 @@ func NewSegmenter(
640648
segmentSize: segmentSize,
641649
batchSize: batchSize,
642650
currIDRange: idKeyRange,
651+
nanDoc: *nanDoc,
643652
}
644653

645654
return s, nil
@@ -770,29 +779,73 @@ func (seg *Segmenter) findSegmentMaxKey(
770779
return raw.Lookup("_id"), nil
771780
}
772781

782+
// handleNanDoc sends a document with NaN _id to the readResultC channel if it exists.
783+
func (seg *Segmenter) handleNanDoc(
784+
readResults chan<- readBatchResult,
785+
nextID nextBatchIDFunc,
786+
) {
787+
if len(seg.nanDoc) == 0 {
788+
return
789+
}
790+
791+
readResults <- readBatchResult{
792+
ID: nextID(),
793+
Documents: []any{seg.nanDoc},
794+
SizeBytes: len(seg.nanDoc),
795+
}
796+
}
797+
773798
// getIDKeyRange returns the minimum and maximum _id values in the collection.
774799
// It uses two FindOne operations with sort directions of 1 (ascending) and -1 (descending)
775800
// to determine the full _id range. This is used to define the collection boundaries
776801
// when the _id type is uniform across all documents.
777-
func getIDKeyRange(ctx context.Context, mcoll *mongo.Collection) (keyRange, error) {
802+
func getIDKeyRange(ctx context.Context, mcoll *mongo.Collection) (keyRange, *bson.Raw, error) {
778803
findOptions := options.FindOne().SetSort(bson.D{{"_id", 1}}).SetProjection(bson.D{{"_id", 1}})
804+
779805
minRaw, err := mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
780806
if err != nil {
781-
return keyRange{}, errors.Wrap(err, "min _id")
807+
return keyRange{}, nil, errors.Wrap(err, "min _id")
808+
}
809+
810+
nanDoc := bson.Raw{}
811+
812+
if strings.Contains(minRaw.Lookup("_id").DebugString(), "NaN") {
813+
nanDoc = minRaw
814+
815+
findOptions = options.FindOne().SetSort(bson.D{{"_id", 1}}).
816+
SetProjection(bson.D{{"_id", 1}}).SetSkip(1)
817+
818+
minRaw, err = mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
819+
if err != nil {
820+
return keyRange{}, nil, errors.Wrap(err, "min _id (next document)")
821+
}
782822
}
783823

784824
findOptions = options.FindOne().SetSort(bson.D{{"_id", -1}}).SetProjection(bson.D{{"_id", 1}})
825+
785826
maxRaw, err := mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
786827
if err != nil {
787-
return keyRange{}, errors.Wrap(err, "max _id")
828+
return keyRange{}, nil, errors.Wrap(err, "max _id")
829+
}
830+
831+
if strings.Contains(maxRaw.Lookup("_id").DebugString(), "NaN") {
832+
nanDoc = maxRaw
833+
834+
findOptions = options.FindOne().SetSort(bson.D{{"_id", -1}}).
835+
SetProjection(bson.D{{"_id", 1}}).SetSkip(1)
836+
837+
maxRaw, err = mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
838+
if err != nil {
839+
return keyRange{}, nil, errors.Wrap(err, "min _id (next document)")
840+
}
788841
}
789842

790843
ret := keyRange{
791844
Min: minRaw.Lookup("_id"),
792845
Max: maxRaw.Lookup("_id"),
793846
}
794847

795-
return ret, nil
848+
return ret, &nanDoc, nil
796849
}
797850

798851
// getIDKeyRangeByType returns a slice of keyRange grouped by the BSON type of the _id field.
@@ -812,6 +865,7 @@ func getIDKeyRangeByType(ctx context.Context, mcoll *mongo.Collection) ([]keyRan
812865
}
813866

814867
var segmentRanges []keyRange
868+
815869
err = cur.All(ctx, &segmentRanges)
816870
if err != nil {
817871
return nil, errors.Wrap(err, "all")

tests/test_collections.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from plm import PLM, Runner
99
from pymongo import MongoClient
1010
from testing import Testing
11+
from bson.decimal128 import Decimal128
1112

1213

1314
def ensure_collection(source: MongoClient, target: MongoClient, db: str, coll: str, **kwargs):
@@ -637,6 +638,28 @@ def test_plm_126_clone_with_nan_id_document(t: Testing):
637638
[{"_id": random.uniform(1e5, 1e10), "i": i} for i in range(50)]
638639
)
639640

641+
with t.run(phase=Runner.Phase.CLONE) as r:
642+
r.start()
643+
r.wait_for_clone_completed()
644+
645+
sourceDocCount = t.source["db_1"]["coll_1"].count_documents({})
646+
targetDocCount = t.target["db_1"]["coll_1"].count_documents({})
647+
assert sourceDocCount == targetDocCount
648+
649+
650+
@pytest.mark.skip(reason="Clone with NaN _id is not supported for multi-id types")
651+
def test_clone_with_nan_id_document_multi_id_types(t: Testing):
652+
t.source["db_1"]["coll_1"].insert_one({"_id": Decimal128("NaN"), "i": 200})
653+
t.source["db_1"]["coll_1"].insert_many(
654+
[{"_id": random.uniform(1e5, 1e10), "i": i} for i in range(50)]
655+
)
656+
t.source["db_1"]["coll_1"].insert_many(
657+
[{"_id": Decimal128(str(random.uniform(1e5, 1e10))), "i": i} for i in range(50)]
658+
)
659+
t.source["db_1"]["coll_1"].insert_many(
660+
[{"_id": "inel" + str(random.uniform(1e5, 1e10)), "i": i} for i in range(50)]
661+
)
662+
640663
with t.run(phase=Runner.Phase.CLONE) as r:
641664
r.start()
642665
r.wait_for_clone_completed()

0 commit comments

Comments
 (0)