Skip to content

Commit b827ee4

Browse files
authored
PLM-126: Handle collections with NaN ID document (#112)
1 parent b11eb5c commit b827ee4

File tree

2 files changed

+130
-29
lines changed

2 files changed

+130
-29
lines changed

plm/copy.go

Lines changed: 92 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"fmt"
66
"math"
77
"runtime"
8+
"strings"
89
"sync"
910
"sync/atomic"
1011
"time"
@@ -204,6 +205,12 @@ func (cm *CopyManager) copyCollection(
204205
isCapped, _ := spec.Options.Lookup("capped").BooleanOK()
205206

206207
var nextSegment nextSegmentFunc
208+
209+
readResultC := make(chan readBatchResult)
210+
211+
var batchID atomic.Uint32
212+
var nextID nextBatchIDFunc = func() uint32 { return batchID.Add(1) }
213+
207214
if isCapped { //nolint:nestif
208215
segmenter, err := NewCappedSegmenter(ctx,
209216
cm.source, namespace, cm.options.ReadBatchSizeBytes)
@@ -234,13 +241,14 @@ func (cm *CopyManager) copyCollection(
234241
}
235242

236243
nextSegment = segmenter.Next
244+
245+
go segmenter.handleNanIDDoc(readResultC, nextID)
237246
}
238247

239248
collectionReadCtx, stopCollectionRead := context.WithCancel(ctx)
240249

241250
// pendingSegments tracks in-progress read segments
242251
pendingSegments := &sync.WaitGroup{}
243-
readResultC := make(chan readBatchResult)
244252

245253
allBatchesSent := make(chan struct{}) // closes when all batches are sent to inserters
246254

@@ -260,8 +268,6 @@ func (cm *CopyManager) copyCollection(
260268
// spawn readSegment in loop until the collection is exhausted or canceled.
261269
go func() {
262270
var segmentID uint32
263-
var batchID atomic.Uint32
264-
var nextID nextBatchIDFunc = func() uint32 { return batchID.Add(1) }
265271

266272
readStopped := collectionReadCtx.Done()
267273

@@ -297,6 +303,7 @@ func (cm *CopyManager) copyCollection(
297303
}
298304

299305
pendingSegments.Add(1)
306+
300307
go func() {
301308
defer func() {
302309
<-cm.readLimit
@@ -560,6 +567,7 @@ type Segmenter struct {
560567
batchSize int32
561568
keyRanges []keyRange
562569
currIDRange keyRange
570+
nanDoc bson.Raw // document with NaN _id, if any
563571
}
564572

565573
type keyRange struct {
@@ -625,7 +633,7 @@ func NewSegmenter(
625633

626634
mcoll := m.Database(ns.Database).Collection(ns.Collection)
627635

628-
idKeyRange, err := getIDKeyRange(ctx, mcoll)
636+
idKeyRange, nanDoc, err := getIDKeyRange(ctx, mcoll)
629637
if err != nil {
630638
if errors.Is(err, mongo.ErrNoDocuments) {
631639
return nil, errEOC // empty collection
@@ -640,29 +648,31 @@ func NewSegmenter(
640648
segmentSize: segmentSize,
641649
batchSize: batchSize,
642650
currIDRange: idKeyRange,
651+
nanDoc: *nanDoc,
643652
}
644653

645654
return s, nil
646655
}
647656

648-
keyRangeByType, err := getIDKeyRangeByType(ctx, mcoll)
657+
multiTypeIDkeyRanges, err := getMultiTypeIDKeyRanges(ctx, mcoll)
649658
if err != nil {
650659
return nil, errors.Wrap(err, "get ID key range by type")
651660
}
652661

653-
if len(keyRangeByType) == 0 {
662+
if len(multiTypeIDkeyRanges) == 0 {
654663
return nil, errEOC // empty collection
655664
}
656665

657-
currIDRange := keyRangeByType[0]
658-
keyRanges := keyRangeByType[1:]
666+
currIDRange := multiTypeIDkeyRanges[0]
667+
remainingKeyRanges := multiTypeIDkeyRanges[1:]
659668

660669
s := &Segmenter{
661670
mcoll: mcoll,
662671
segmentSize: segmentSize,
663672
batchSize: batchSize,
664-
keyRanges: keyRanges,
673+
keyRanges: remainingKeyRanges,
665674
currIDRange: currIDRange,
675+
nanDoc: *nanDoc,
666676
}
667677

668678
return s, nil
@@ -770,54 +780,107 @@ func (seg *Segmenter) findSegmentMaxKey(
770780
return raw.Lookup("_id"), nil
771781
}
772782

783+
// handleNanIDDoc sends a document with NaN _id to the readResultC channel if it exists.
784+
func (seg *Segmenter) handleNanIDDoc(
785+
readResults chan<- readBatchResult,
786+
nextID nextBatchIDFunc,
787+
) {
788+
if len(seg.nanDoc) == 0 {
789+
return
790+
}
791+
792+
readResults <- readBatchResult{
793+
ID: nextID(),
794+
Documents: []any{seg.nanDoc},
795+
SizeBytes: len(seg.nanDoc),
796+
}
797+
}
798+
773799
// getIDKeyRange returns the minimum and maximum _id values in the collection.
774800
// It uses two FindOne operations with sort directions of 1 (ascending) and -1 (descending)
775801
// to determine the full _id range. This is used to define the collection boundaries
776802
// when the _id type is uniform across all documents.
777-
func getIDKeyRange(ctx context.Context, mcoll *mongo.Collection) (keyRange, error) {
778-
findOptions := options.FindOne().SetSort(bson.D{{"_id", 1}}).SetProjection(bson.D{{"_id", 1}})
779-
minRaw, err := mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
803+
func getIDKeyRange(ctx context.Context, mcoll *mongo.Collection) (keyRange, *bson.Raw, error) {
804+
minIDOptions := options.FindOne().SetSort(bson.D{{"_id", 1}}).SetProjection(bson.D{{"_id", 1}})
805+
806+
minRaw, err := mcoll.FindOne(ctx, bson.D{}, minIDOptions).Raw()
780807
if err != nil {
781-
return keyRange{}, errors.Wrap(err, "min _id")
808+
return keyRange{}, nil, errors.Wrap(err, "min _id")
809+
}
810+
811+
nanDoc := bson.Raw{}
812+
813+
if strings.Contains(minRaw.Lookup("_id").DebugString(), "NaN") {
814+
nanDoc = minRaw
815+
816+
minRaw, err = mcoll.FindOne(ctx, bson.D{}, minIDOptions.SetSkip(1)).Raw()
817+
if err != nil {
818+
return keyRange{}, nil, errors.Wrap(err, "min _id (skip NaN)")
819+
}
782820
}
783821

784-
findOptions = options.FindOne().SetSort(bson.D{{"_id", -1}}).SetProjection(bson.D{{"_id", 1}})
785-
maxRaw, err := mcoll.FindOne(ctx, bson.D{}, findOptions).Raw()
822+
maxIDOptions := options.FindOne().SetSort(bson.D{{"_id", -1}}).SetProjection(bson.D{{"_id", 1}})
823+
824+
maxRaw, err := mcoll.FindOne(ctx, bson.D{}, maxIDOptions).Raw()
786825
if err != nil {
787-
return keyRange{}, errors.Wrap(err, "max _id")
826+
return keyRange{}, nil, errors.Wrap(err, "max _id")
827+
}
828+
829+
if strings.Contains(maxRaw.Lookup("_id").DebugString(), "NaN") {
830+
nanDoc = maxRaw
831+
832+
maxRaw, err = mcoll.FindOne(ctx, bson.D{}, maxIDOptions.SetSkip(1)).Raw()
833+
if err != nil {
834+
return keyRange{}, nil, errors.Wrap(err, "max _id (skip NaN)")
835+
}
788836
}
789837

790838
ret := keyRange{
791839
Min: minRaw.Lookup("_id"),
792840
Max: maxRaw.Lookup("_id"),
793841
}
794842

795-
return ret, nil
843+
return ret, &nanDoc, nil
796844
}
797845

798-
// getIDKeyRangeByType returns a slice of keyRange grouped by the BSON type of the _id field.
846+
// getMultiTypeIDKeyRanges returns a slice of keyRange grouped by the BSON type of the _id field.
799847
// It performs an aggregation that groups documents by _id type, computing the min and max _id
800848
// for each group. This allows the Segmenter to handle collections with heterogeneous _id types
801849
// by processing each type range sequentially.
802-
func getIDKeyRangeByType(ctx context.Context, mcoll *mongo.Collection) ([]keyRange, error) {
803-
cur, err := mcoll.Aggregate(ctx, mongo.Pipeline{
804-
bson.D{{"$group", bson.D{
805-
{"_id", bson.D{{"type", bson.D{{"$type", "$_id"}}}}},
806-
{"minKey", bson.D{{"$min", "$_id"}}},
807-
{"maxKey", bson.D{{"$max", "$_id"}}},
808-
}}},
809-
})
850+
func getMultiTypeIDKeyRanges(ctx context.Context, mcoll *mongo.Collection) ([]keyRange, error) {
851+
cur, err := mcoll.Aggregate(ctx,
852+
mongo.Pipeline{
853+
// Match only numeric types that are not NaN
854+
bson.D{{"$match", bson.D{
855+
{"$expr", bson.D{
856+
// Only allow if _id is not NaN
857+
{"$ne", bson.A{"$_id", bson.D{{"$literal", math.NaN()}}}},
858+
}},
859+
}}},
860+
// Group by type and find min/max
861+
bson.D{{"$group", bson.D{
862+
{"_id", bson.D{{"type", bson.D{{"$type", "$_id"}}}}},
863+
{"minKey", bson.D{{"$min", "$_id"}}},
864+
{"maxKey", bson.D{{"$max", "$_id"}}},
865+
}}},
866+
})
810867
if err != nil {
811868
return nil, errors.Wrap(err, "query")
812869
}
813870

814-
var segmentRanges []keyRange
815-
err = cur.All(ctx, &segmentRanges)
871+
var keyRanges []keyRange
872+
873+
err = cur.All(ctx, &keyRanges)
816874
if err != nil {
817875
return nil, errors.Wrap(err, "all")
818876
}
819877

820-
return segmentRanges, nil
878+
for i := range keyRanges {
879+
log.Ctx(ctx).Debugf("Keyrange %d: type: %s, range [%v <=> %v]", i+1,
880+
keyRanges[i].Min.Type.String(), keyRanges[i].Min, keyRanges[i].Max)
881+
}
882+
883+
return keyRanges, nil
821884
}
822885

823886
// CappedSegmenter provides sequential cursor access for capped collections.

tests/test_collections.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from plm import PLM, Runner
99
from pymongo import MongoClient
1010
from testing import Testing
11+
from bson.decimal128 import Decimal128
1112

1213

1314
def ensure_collection(source: MongoClient, target: MongoClient, db: str, coll: str, **kwargs):
@@ -589,6 +590,7 @@ def test_plm_109_rename_complex(t: Testing, phase: Runner.Phase):
589590

590591
t.compare_all()
591592

593+
592594
@pytest.mark.timeout(30)
593595
def test_plm_110_rename_during_clone_and_repl(t: Testing):
594596
payload = random.randbytes(1000)
@@ -628,3 +630,39 @@ def test_plm_110_rename_during_clone_and_repl(t: Testing):
628630
t.source[db][coll].insert_many({"payload": payload} for _ in range(500))
629631

630632
t.compare_all()
633+
634+
635+
def test_clone_with_nan_id_document(t: Testing):
    """Clone a collection whose lowest _id is a NaN double.

    NaN cannot act as a key-range boundary, so the cloner must detect the
    NaN-_id document and copy it separately from the keyed segments.
    """
    t.source["db_1"]["coll_1"].insert_one({"_id": float("nan"), "i": 100})
    t.source["db_1"]["coll_1"].insert_many(
        [{"_id": random.uniform(1e5, 1e10), "i": i} for i in range(50)]
    )

    with t.run(phase=Runner.Phase.MANUAL) as r:
        r.start()
        r.wait_for_clone_completed()

    # Compare counts only — presumably because NaN != NaN makes a value-level
    # comparison of the NaN-_id document unreliable; confirm against
    # t.compare_all() semantics.
    source_doc_count = t.source["db_1"]["coll_1"].count_documents({})
    target_doc_count = t.target["db_1"]["coll_1"].count_documents({})
    assert source_doc_count == target_doc_count
648+
649+
650+
def test_clone_with_nan_id_document_multi_id_types(t: Testing):
    """Clone a collection mixing a Decimal128 NaN _id with several _id types.

    Exercises the multi-type key-range path: double, Decimal128, and string
    _ids force per-type segmentation, while the Decimal128 NaN document must
    be copied outside of any key range.
    """
    t.source["db_1"]["coll_1"].insert_one({"_id": Decimal128("NaN"), "i": 200})
    t.source["db_1"]["coll_1"].insert_many(
        [{"_id": random.uniform(1e5, 1e10), "i": i} for i in range(50)]
    )
    t.source["db_1"]["coll_1"].insert_many(
        [{"_id": Decimal128(str(random.uniform(1e5, 1e10))), "i": i} for i in range(50)]
    )
    t.source["db_1"]["coll_1"].insert_many(
        [{"_id": str(random.uniform(1e5, 1e10)), "i": i} for i in range(50)]
    )

    with t.run(phase=Runner.Phase.MANUAL) as r:
        r.start()
        r.wait_for_clone_completed()

    # Count-based check only — presumably because NaN != NaN makes a
    # value-level comparison of the NaN-_id document unreliable.
    source_doc_count = t.source["db_1"]["coll_1"].count_documents({})
    target_doc_count = t.target["db_1"]["coll_1"].count_documents({})
    assert source_doc_count == target_doc_count

0 commit comments

Comments
 (0)