@@ -11,6 +11,8 @@ package streamingest
1111import (
1212 "context"
1313 "fmt"
14+ "math"
15+ "sort"
1416 "time"
1517
1618 "github.com/cockroachdb/cockroach/pkg/base"
@@ -222,11 +224,16 @@ func (p *replicationFlowPlanner) makePlan(
222224 if err != nil {
223225 return nil , nil , err
224226 }
227+ destNodeLocalities , err := getDestNodeLocalities (ctx , dsp , sqlInstanceIDs )
228+ if err != nil {
229+ return nil , nil , err
230+ }
225231
226232 streamIngestionSpecs , streamIngestionFrontierSpec , err := constructStreamIngestionPlanSpecs (
233+ ctx ,
227234 streamingccl .StreamAddress (details .StreamAddress ),
228235 topology ,
229- sqlInstanceIDs ,
236+ destNodeLocalities ,
230237 initialScanTimestamp ,
231238 previousReplicatedTime ,
232239 checkpoint ,
@@ -262,11 +269,6 @@ func (p *replicationFlowPlanner) makePlan(
262269 execinfrapb.ProcessorCoreUnion {StreamIngestionFrontier : streamIngestionFrontierSpec },
263270 execinfrapb.PostProcessSpec {}, streamIngestionResultTypes )
264271
265- for src , dst := range streamIngestionFrontierSpec .SubscribingSQLInstances {
266- log .Infof (ctx , "physical replication src-dst pair candidate: %s:%d" ,
267- src , dst )
268- }
269-
270272 p .PlanToStreamColMap = []int {0 }
271273 sql .FinalizePlan (ctx , planCtx , p )
272274 return p , planCtx , nil
@@ -318,10 +320,128 @@ func measurePlanChange(before, after *sql.PhysicalPlan) float64 {
318320 return float64 (diff ) / float64 (oldCount )
319321}
320322
323+ type partitionWithCandidates struct {
324+ partition streamclient.PartitionInfo
325+ closestDestIDs []base.SQLInstanceID
326+ sharedPrefixLength int
327+ }
328+
329+ type candidatesByPriority []partitionWithCandidates
330+
331+ func (a candidatesByPriority ) Len () int { return len (a ) }
332+ func (a candidatesByPriority ) Swap (i , j int ) { a [i ], a [j ] = a [j ], a [i ] }
333+ func (a candidatesByPriority ) Less (i , j int ) bool {
334+ return a [i ].sharedPrefixLength > a [j ].sharedPrefixLength
335+ }
336+
337+ // nodeMatcher matches each source cluster node to a destination cluster node,
338+ // given a list of available nodes in each cluster. The matcher has a primary goal
339+ // to match src-dst nodes that are "close" to each other, i.e. have common
340+ // locality tags, and a secondary goal to distribute source node assignments
341+ // evenly across destination nodes. Here's the algorithm:
342+ //
343+ // - For each src node, find their closest dst nodes and the number of
344+ // localities that match, the LocalityMatchCount, via the sql.ClosestInstances()
345+ // function. Example: Consider Src-A [US,East] which has match candidates Dst-A
346+ // [US,West], Dst-B [US, Central]. In the example, the LocalityMatchCount is 1,
347+ // as only US matches with the src node's locality.
348+ //
349+ // - Prioritize matching src nodes with a higher locality match count, via the
350+ // findSourceNodePriority() function.
351+ //
352+ // - While we have src nodes left to match, match the highest priority src node
353+ // to the dst node candidate that has the fewest matches already, via the
354+ // findMatch() function.
355+
356+ type nodeMatcher struct {
357+ destMatchCount map [base.SQLInstanceID ]int
358+ destNodesInfo []sql.InstanceLocality
359+ destNodeToLocality map [base.SQLInstanceID ]roachpb.Locality
360+ }
361+
362+ func makeNodeMatcher (destNodesInfo []sql.InstanceLocality ) * nodeMatcher {
363+ nodeToLocality := make (map [base.SQLInstanceID ]roachpb.Locality , len (destNodesInfo ))
364+ for _ , node := range destNodesInfo {
365+ nodeToLocality [node .GetInstanceID ()] = node .GetLocality ()
366+ }
367+ return & nodeMatcher {
368+ destMatchCount : make (map [base.SQLInstanceID ]int , len (destNodesInfo )),
369+ destNodesInfo : destNodesInfo ,
370+ destNodeToLocality : nodeToLocality ,
371+ }
372+ }
373+
374+ func (nm * nodeMatcher ) destNodeIDs () []base.SQLInstanceID {
375+ allDestNodeIDs := make ([]base.SQLInstanceID , 0 , len (nm .destNodesInfo ))
376+ for _ , info := range nm .destNodesInfo {
377+ allDestNodeIDs = append (allDestNodeIDs , info .GetInstanceID ())
378+ }
379+ return allDestNodeIDs
380+ }
381+
382+ // findSourceNodePriority finds the closest dest nodes for each source node and
383+ // returns a list of (source node, dest node match candidates) pairs ordered by
384+ // matching priority. A source node is earlier (higher priority) in the list if
385+ // it shares more locality tiers with their destination node match candidates.
386+ func (nm * nodeMatcher ) findSourceNodePriority (topology streamclient.Topology ) candidatesByPriority {
387+
388+ allDestNodeIDs := nm .destNodeIDs ()
389+ candidates := make (candidatesByPriority , 0 , len (topology .Partitions ))
390+ for _ , partition := range topology .Partitions {
391+ closestDestIDs , sharedPrefixLength := sql .ClosestInstances (nm .destNodesInfo ,
392+ partition .SrcLocality )
393+ if sharedPrefixLength == 0 {
394+ closestDestIDs = allDestNodeIDs
395+ }
396+ candidate := partitionWithCandidates {
397+ partition : partition ,
398+ closestDestIDs : closestDestIDs ,
399+ sharedPrefixLength : sharedPrefixLength ,
400+ }
401+ candidates = append (candidates , candidate )
402+ }
403+ sort .Sort (candidates )
404+
405+ return candidates
406+ }
407+
408+ // findMatch returns the destination node id with the fewest src node matches from the input list.
409+ func (nm * nodeMatcher ) findMatch (destIDCandidates []base.SQLInstanceID ) base.SQLInstanceID {
410+ minCount := math .MaxInt
411+ currentMatch := base .SQLInstanceID (0 )
412+
413+ for _ , destID := range destIDCandidates {
414+ currentDestCount := nm .destMatchCount [destID ]
415+ if currentDestCount < minCount {
416+ currentMatch = destID
417+ minCount = currentDestCount
418+ }
419+ }
420+ nm .destMatchCount [currentMatch ]++
421+ return currentMatch
422+ }
423+
424+ func getDestNodeLocalities (
425+ ctx context.Context , dsp * sql.DistSQLPlanner , instanceIDs []base.SQLInstanceID ,
426+ ) ([]sql.InstanceLocality , error ) {
427+
428+ instanceInfos := make ([]sql.InstanceLocality , 0 , len (instanceIDs ))
429+ for _ , id := range instanceIDs {
430+ nodeDesc , err := dsp .GetSQLInstanceInfo (id )
431+ if err != nil {
432+ log .Eventf (ctx , "unable to get node descriptor for sql node %s" , id )
433+ return nil , err
434+ }
435+ instanceInfos = append (instanceInfos , sql .MakeInstanceLocality (id , nodeDesc .Locality ))
436+ }
437+ return instanceInfos , nil
438+ }
439+
321440func constructStreamIngestionPlanSpecs (
441+ ctx context.Context ,
322442 streamAddress streamingccl.StreamAddress ,
323443 topology streamclient.Topology ,
324- sqlInstanceIDs []base. SQLInstanceID ,
444+ destSQLInstances []sql. InstanceLocality ,
325445 initialScanTimestamp hlc.Timestamp ,
326446 previousReplicatedTimestamp hlc.Timestamp ,
327447 checkpoint jobspb.StreamIngestionCheckpoint ,
@@ -330,41 +450,55 @@ func constructStreamIngestionPlanSpecs(
330450 sourceTenantID roachpb.TenantID ,
331451 destinationTenantID roachpb.TenantID ,
332452) ([]* execinfrapb.StreamIngestionDataSpec , * execinfrapb.StreamIngestionFrontierSpec , error ) {
333- // For each stream partition in the topology, assign it to a node.
334- streamIngestionSpecs := make ([]* execinfrapb.StreamIngestionDataSpec , 0 , len (sqlInstanceIDs ))
453+
454+ streamIngestionSpecs := make ([]* execinfrapb.StreamIngestionDataSpec , 0 , len (destSQLInstances ))
455+ destSQLInstancesToIdx := make (map [base.SQLInstanceID ]int , len (destSQLInstances ))
456+ for i , id := range destSQLInstances {
457+ spec := & execinfrapb.StreamIngestionDataSpec {
458+ StreamID : uint64 (streamID ),
459+ JobID : int64 (jobID ),
460+ PreviousReplicatedTimestamp : previousReplicatedTimestamp ,
461+ InitialScanTimestamp : initialScanTimestamp ,
462+ Checkpoint : checkpoint , // TODO: Only forward relevant checkpoint info
463+ StreamAddress : string (streamAddress ),
464+ PartitionSpecs : make (map [string ]execinfrapb.StreamIngestionPartitionSpec ),
465+ TenantRekey : execinfrapb.TenantRekey {
466+ OldID : sourceTenantID ,
467+ NewID : destinationTenantID ,
468+ },
469+ }
470+ streamIngestionSpecs = append (streamIngestionSpecs , spec )
471+ destSQLInstancesToIdx [id .GetInstanceID ()] = i
472+ }
335473
336474 trackedSpans := make ([]roachpb.Span , 0 )
337475 subscribingSQLInstances := make (map [string ]uint32 )
338- for i , partition := range topology .Partitions {
339- // Round robin assign the stream partitions to nodes. Partitions 0 through
340- // len(nodes) - 1 creates the spec. Future partitions just add themselves to
341- // the partition addresses.
342- if i < len (sqlInstanceIDs ) {
343- spec := & execinfrapb.StreamIngestionDataSpec {
344- StreamID : uint64 (streamID ),
345- JobID : int64 (jobID ),
346- PreviousReplicatedTimestamp : previousReplicatedTimestamp ,
347- InitialScanTimestamp : initialScanTimestamp ,
348- Checkpoint : checkpoint , // TODO: Only forward relevant checkpoint info
349- StreamAddress : string (streamAddress ),
350- PartitionSpecs : make (map [string ]execinfrapb.StreamIngestionPartitionSpec ),
351- TenantRekey : execinfrapb.TenantRekey {
352- OldID : sourceTenantID ,
353- NewID : destinationTenantID ,
354- },
355- }
356- streamIngestionSpecs = append (streamIngestionSpecs , spec )
357- }
358- n := i % len (sqlInstanceIDs )
359476
360- subscribingSQLInstances [partition .ID ] = uint32 (sqlInstanceIDs [n ])
361- streamIngestionSpecs [n ].PartitionSpecs [partition .ID ] = execinfrapb.StreamIngestionPartitionSpec {
477+ // Update stream ingestion specs with their matched source node.
478+ matcher := makeNodeMatcher (destSQLInstances )
479+ for _ , candidate := range matcher .findSourceNodePriority (topology ) {
480+ destID := matcher .findMatch (candidate .closestDestIDs )
481+ log .Infof (ctx , "physical replication src-dst pair candidate: %s (locality %s) - %d (" +
482+ "locality %s)" ,
483+ candidate .partition .ID ,
484+ candidate .partition .SrcLocality ,
485+ destID ,
486+ matcher .destNodeToLocality [destID ])
487+ partition := candidate .partition
488+ subscribingSQLInstances [partition .ID ] = uint32 (destID )
489+
490+ specIdx , ok := destSQLInstancesToIdx [destID ]
491+ if ! ok {
492+ return nil , nil , errors .AssertionFailedf (
493+ "matched destination node id does not contain a stream ingestion spec" )
494+ }
495+ streamIngestionSpecs [specIdx ].PartitionSpecs [partition .ID ] = execinfrapb.
496+ StreamIngestionPartitionSpec {
362497 PartitionID : partition .ID ,
363498 SubscriptionToken : string (partition .SubscriptionToken ),
364499 Address : string (partition .SrcAddr ),
365500 Spans : partition .Spans ,
366501 }
367-
368502 trackedSpans = append (trackedSpans , partition .Spans ... )
369503 }
370504
0 commit comments