Skip to content

Commit 3714133

Browse files
craig[bot]msbutlerandy-kimball
committed
143456: crosscluster/producer: allow table level auth for REPLICATIONSOURCE r=jeffswenson a=msbutler Epic: none Release note (ops change): previously, the user provided in the source URI in the LDR stream required the REPLICATIONSOURCE priv at the system level. With this change, the user only needs this priv on the source tables (i.e. a table level priv). 143489: cspann: add ChildKey de-duplicator r=drewkimball a=andy-kimball Add a new de-duplicator class for ChildKeys. Duplicate ChildKeys are going to become much more common with non-transactional fixups. Add an O(N) implementation that minimizes allocations by avoiding hashing KeyBytes values as strings. Instead, it hashes KeyBytes into a uint64 value, which is then hashed using a regular Go map. uint64 collisions are handled by rehashing. Use the new de-duper in the index. In a later PR, it will be used directly in SearchSet as well to detect duplicates as early as possible. Epic: CRDB-42943 Release note: None Co-authored-by: Michael Butler <[email protected]> Co-authored-by: Andrew Kimball <[email protected]>
3 parents 34c8cfc + 41208db + e60506c commit 3714133

File tree

21 files changed

+506
-171
lines changed

21 files changed

+506
-171
lines changed

pkg/crosscluster/logical/logical_replication_job.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ func (p *logicalReplicationPlanner) generatePlanImpl(
488488
// During an offline initial scan, we need to replicate the whole table, not
489489
// just the primary keys.
490490
UseTableSpan: payload.CreateTable && progress.ReplicatedTime.IsEmpty(),
491+
StreamID: streampb.StreamID(payload.StreamID),
491492
}
492493
for _, pair := range payload.ReplicationPairs {
493494
req.TableIDs = append(req.TableIDs, pair.SrcDescriptorID)

pkg/crosscluster/logical/logical_replication_job_test.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,11 @@ func TestCreateTables(t *testing.T) {
498498
gURL := replicationtestutils.GetExternalConnectionURI(t, srv, srv, serverutils.DBName("g"))
499499

500500
sqlG.Exec(t, "CREATE TABLE tab (pk int primary key, payload string)")
501+
sqlG.Exec(t, "CREATE TABLE foo (pk int primary key, payload string)")
501502

502503
var jobID jobspb.JobID
503504
// use create logically replicated table syntax
504-
sqlG.QueryRow(t, "CREATE LOGICALLY REPLICATED TABLE tab2 FROM TABLE tab ON $1 WITH UNIDIRECTIONAL", gURL.String()).Scan(&jobID)
505+
sqlG.QueryRow(t, "CREATE LOGICALLY REPLICATED TABLES (tab2, foo2) FROM TABLES (tab, foo) ON $1 WITH BIDIRECTIONAL ON $2", gURL.String(), gURL.String()).Scan(&jobID)
505506
WaitUntilReplicatedTime(t, srv.Clock().Now(), sqlG, jobID)
506507
// check that tab2 is empty
507508
sqlG.CheckQueryResults(t, "SELECT * FROM tab2", [][]string{})
@@ -1581,6 +1582,32 @@ func WaitUntilReplicatedTime(
15811582
})
15821583
}
15831584

1585+
func GetReverseJobID(
1586+
ctx context.Context, t *testing.T, db *sqlutils.SQLRunner, parentID jobspb.JobID,
1587+
) jobspb.JobID {
1588+
// get the created time of the parent job
1589+
var created time.Time
1590+
db.QueryRow(t, "SELECT created FROM system.jobs WHERE id = $1", parentID).Scan(&created)
1591+
1592+
var jobID jobspb.JobID
1593+
testutils.SucceedsSoon(t, func() error {
1594+
err := db.DB.QueryRowContext(ctx, `
1595+
SELECT id
1596+
FROM system.jobs
1597+
WHERE job_type = 'LOGICAL REPLICATION'
1598+
AND id != $1
1599+
AND created > $2
1600+
ORDER BY created DESC
1601+
LIMIT 1`,
1602+
parentID, created).Scan(&jobID)
1603+
if err != nil {
1604+
return errors.Wrapf(err, "reverse job not found for parent %d", parentID)
1605+
}
1606+
return nil
1607+
})
1608+
return jobID
1609+
}
1610+
15841611
type mockBatchHandler struct {
15851612
err error
15861613
}
@@ -2028,10 +2055,11 @@ func TestUserPrivileges(t *testing.T) {
20282055
testuser.ExpectErr(t, "failed privilege check: table or system level REPLICATIONDEST privilege required: user testuser does not have REPLICATIONDEST privilege on relation tab", createStmt, dbBURL.String())
20292056
dbA.Exec(t, fmt.Sprintf("GRANT SYSTEM REPLICATIONDEST TO %s", username.TestUser))
20302057
testuser.QueryRow(t, createStmt, dbBURL.String()).Scan(&jobAID)
2058+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
20312059
dbA.Exec(t, fmt.Sprintf("REVOKE SYSTEM REPLICATIONDEST FROM %s", username.TestUser))
20322060
})
20332061
t.Run("replication-src", func(t *testing.T) {
2034-
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE system privilege", createStmt, dbBURL2.String())
2062+
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE privilege on relation tab", createStmt, dbBURL2.String())
20352063
sourcePriv := "REPLICATIONSOURCE"
20362064
if rng.Intn(3) == 0 {
20372065
// Test deprecated privilege name.
@@ -2042,6 +2070,15 @@ func TestUserPrivileges(t *testing.T) {
20422070
dbA.QueryRow(t, createStmt, dbBURL2.String()).Scan(&jobAID)
20432071
dbB.Exec(t, fmt.Sprintf("REVOKE SYSTEM %s FROM %s", sourcePriv, username.TestUser+"3"))
20442072
})
2073+
t.Run("table-level-replication-src", func(t *testing.T) {
2074+
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE privilege on relation tab", createStmt, dbBURL2.String())
2075+
2076+
dbB.Exec(t, fmt.Sprintf("GRANT REPLICATIONSOURCE ON TABLE tab TO %s", username.TestUser+"3"))
2077+
dbA.QueryRow(t, createStmt, dbBURL2.String()).Scan(&jobAID)
2078+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
2079+
2080+
dbB.Exec(t, fmt.Sprintf("REVOKE REPLICATIONSOURCE ON TABLE tab FROM %s", username.TestUser+"3"))
2081+
})
20452082
t.Run("table-level-replication-dest", func(t *testing.T) {
20462083

20472084
dbA.Exec(t, `CREATE TABLE tab2 (x INT PRIMARY KEY)`)
@@ -2055,7 +2092,9 @@ func TestUserPrivileges(t *testing.T) {
20552092
testuser.ExpectErr(t, "failed privilege check: table or system level REPLICATIONDEST privilege required: user testuser does not have REPLICATIONDEST privilege on relation tab2", multiTableStmt, dbBURL.String())
20562093

20572094
dbA.Exec(t, fmt.Sprintf(`GRANT REPLICATIONDEST ON TABLE tab2 TO %s`, username.TestUser))
2058-
testuser.Exec(t, multiTableStmt, dbBURL.String())
2095+
testuser.QueryRow(t, multiTableStmt, dbBURL.String()).Scan(&jobAID)
2096+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
2097+
20592098
dbA.Exec(t, fmt.Sprintf(`REVOKE REPLICATIONDEST ON TABLE tab FROM %s`, username.TestUser))
20602099
dbA.Exec(t, fmt.Sprintf(`REVOKE REPLICATIONDEST ON TABLE tab2 FROM %s`, username.TestUser))
20612100
})
@@ -2066,7 +2105,8 @@ func TestUserPrivileges(t *testing.T) {
20662105

20672106
// Grant CREATE privilege on destination database - should now succeed
20682107
dbA.Exec(t, `GRANT CREATE ON DATABASE a TO testuser`)
2069-
testuser.Exec(t, createStmt, dbBURL.String())
2108+
testuser.QueryRow(t, createStmt, dbBURL.String()).Scan(&jobAID)
2109+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
20702110

20712111
dbAURL := replicationtestutils.GetExternalConnectionURI(t, s, s, serverutils.DBName("a"), serverutils.User(username.TestUser))
20722112

@@ -2075,8 +2115,14 @@ func TestUserPrivileges(t *testing.T) {
20752115
testuser.ExpectErr(t, " uri requires REPLICATIONDEST privilege for bidirectional replication: user testuser3 does not have REPLICATIONDEST privilege on relation tab", createStmtBidi, dbBURL2.String(), dbAURL.String())
20762116

20772117
dbB.Exec(t, fmt.Sprintf("GRANT SYSTEM REPLICATIONDEST TO %s", username.TestUser+"3"))
2078-
testuser.QueryRow(t, createStmtBidi, dbBURL2.String(), dbAURL.String()).Scan(&jobAID)
20792118

2119+
var jobAID2 jobspb.JobID
2120+
testuser.QueryRow(t, createStmtBidi, dbBURL2.String(), dbAURL.String()).Scan(&jobAID2)
2121+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID2)
2122+
2123+
// Ensure the reverse job advances as well
2124+
reverseJobID := GetReverseJobID(ctx, t, dbA, jobAID2)
2125+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, reverseJobID)
20802126
})
20812127
}
20822128

pkg/crosscluster/producer/replication_manager.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,30 @@ func (r *replicationStreamManagerImpl) AuthorizeViaJob(
450450
return nil
451451
}
452452

453-
func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(ctx context.Context) error {
454-
err := r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
455-
syntheticprivilege.GlobalPrivilegeObject,
456-
privilege.REPLICATIONSOURCE)
453+
// AuthorizeViaReplicationPriv ensures the user has the REPLICATIONSOUCE privilege. If tableNames is passed, then table level auth is tried.
454+
func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(
455+
ctx context.Context, tableNames ...string,
456+
) (err error) {
457+
458+
authorize := func() error {
459+
// First try fast path for system level priv.
460+
err = r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
461+
syntheticprivilege.GlobalPrivilegeObject,
462+
privilege.REPLICATIONSOURCE)
463+
if err == nil {
464+
return nil
465+
} else if pgerror.GetPGCode(err) != pgcode.InsufficientPrivilege {
466+
return err
467+
}
468+
if len(tableNames) != 0 {
469+
err = replicationutils.AuthorizeTableLevelPriv(ctx, r.resolver, r.evalCtx.SessionAccessor, privilege.REPLICATIONSOURCE, tableNames)
470+
if err == nil {
471+
return nil
472+
} else if pgerror.GetPGCode(err) != pgcode.InsufficientPrivilege {
473+
return err
474+
}
475+
}
457476

458-
if pgerror.GetPGCode(err) == pgcode.InsufficientPrivilege {
459477
// Fallback to legacy REPLICATION priv.
460478
if fallbackErr := r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
461479
syntheticprivilege.GlobalPrivilegeObject,
@@ -465,10 +483,12 @@ func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(ctx context.C
465483
// the deprecated REPLICATION priv.
466484
return err
467485
}
468-
} else if err != nil {
469-
return err
486+
return nil
470487
}
471488

489+
if err = authorize(); err != nil {
490+
return err
491+
}
472492
r.authorized = true
473493
return nil
474494
}

pkg/crosscluster/replicationutils/utils.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ func AuthorizeTableLevelPriv(
347347
return err
348348
}
349349
lookupFlags := tree.ObjectLookupFlags{
350+
// TODO(msbutler): for reasons beyond my paygrade, to grab offline
351+
// descriptors, we need to also pass RequireMutable.
352+
RequireMutable: true,
353+
IncludeOffline: true,
350354
Required: true,
351355
DesiredObjectKind: tree.TableObject,
352356
DesiredTableDescKind: tree.ResolveRequireTableDesc,

pkg/repstream/streampb/stream.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ message LogicalReplicationPlanRequest {
170170
// PlanAsOf is the time as of which the plan should be produced.
171171
util.hlc.Timestamp plan_as_of = 2 [(gogoproto.nullable) = false];
172172
bool use_table_span = 3;
173+
int64 stream_id = 4 [(gogoproto.customname) = "StreamID", (gogoproto.casttype) = "StreamID"];
173174
}
174175

175176
// SourcePartition contains per partition information for a replication plan.

pkg/sql/sem/builtins/replication_builtins.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,14 @@ var replicationBuiltins = map[string]builtinDefinition{
511511
if err := protoutil.Unmarshal(reqBytes, &req); err != nil {
512512
return nil, err
513513
}
514-
if err := mgr.AuthorizeViaReplicationPriv(ctx); err != nil {
514+
if req.StreamID != 0 {
515+
if err := mgr.AuthorizeViaJob(ctx, req.StreamID); err != nil {
516+
return nil, err
517+
}
518+
// Auth via replication priv exists to ensure a user that planned their
519+
// job pre 25.2, which will not send a stream id, can still plan their
520+
// distsql flow.
521+
} else if err := mgr.AuthorizeViaReplicationPriv(ctx); err != nil {
515522
return nil, err
516523
}
517524

@@ -546,14 +553,14 @@ var replicationBuiltins = map[string]builtinDefinition{
546553
if err != nil {
547554
return nil, err
548555
}
549-
if err := mgr.AuthorizeViaReplicationPriv(ctx); err != nil {
550-
return nil, err
551-
}
552556
reqBytes := []byte(tree.MustBeDBytes(args[0]))
553557
req := streampb.ReplicationProducerRequest{}
554558
if err := protoutil.Unmarshal(reqBytes, &req); err != nil {
555559
return nil, err
556560
}
561+
if err := mgr.AuthorizeViaReplicationPriv(ctx, req.TableNames...); err != nil {
562+
return nil, errors.Wrapf(err, "failed to auth")
563+
}
557564

558565
spec, err := mgr.StartReplicationStreamForTables(ctx, req)
559566
if err != nil {

pkg/sql/sem/eval/context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ type ReplicationStreamManager interface {
934934
) (streampb.ReplicationProducerSpec, error)
935935

936936
AuthorizeViaJob(ctx context.Context, streamID streampb.StreamID) error
937-
AuthorizeViaReplicationPriv(ctx context.Context) error
937+
AuthorizeViaReplicationPriv(ctx context.Context, tableNames ...string) error
938938
}
939939

940940
// StreamIngestManager represents a collection of APIs that streaming replication supports

pkg/sql/vecindex/cspann/BUILD.bazel

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ filegroup(
1111
go_library(
1212
name = "cspann",
1313
srcs = [
14+
"childkey_dedup.go",
1415
"cspannpb.go",
1516
"fixup_processor.go",
1617
"fixup_split.go",
@@ -49,6 +50,7 @@ go_library(
4950
go_test(
5051
name = "cspann_test",
5152
srcs = [
53+
"childkey_dedup_test.go",
5254
"cspannpb_test.go",
5355
"fixup_processor_test.go",
5456
"fixup_worker_test.go",
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright 2025 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the CockroachDB Software License
4+
// included in the /LICENSE file.
5+
6+
package cspann
7+
8+
import (
9+
"bytes"
10+
"hash/maphash"
11+
12+
"github.com/cockroachdb/errors"
13+
)
14+
15+
// hashKeyFunc is a function type for hashing KeyBytes.
16+
type hashKeyFunc func(KeyBytes) uint64
17+
18+
// childKeyDeDup provides de-duplication for ChildKey values. It supports both
19+
// PartitionKey and KeyBytes child keys efficiently without making unnecessary
20+
// allocations.
21+
type childKeyDeDup struct {
22+
// initialCapacity is used to initialize the size of the data structures used
23+
// by the de-duplicator.
24+
initialCapacity int
25+
26+
// partitionKeys is used for PartitionKey deduplication.
27+
partitionKeys map[PartitionKey]struct{}
28+
29+
// keyBytesMap maps from a KeyBytes hash to the actual KeyBytes for direct
30+
// lookup and deduplication.
31+
keyBytesMap map[uint64]KeyBytes
32+
33+
// seed pseudo-randomizes the hash function used by the de-duplicator.
34+
seed maphash.Seed
35+
36+
// hashKeyBytes is the function used to hash KeyBytes. This is primarily used
37+
// for testing to override the default hash function.
38+
hashKeyBytes hashKeyFunc
39+
}
40+
41+
// Init initializes the de-duplicator.
42+
func (dd *childKeyDeDup) Init(capacity int) {
43+
dd.initialCapacity = capacity
44+
dd.seed = maphash.MakeSeed()
45+
dd.hashKeyBytes = dd.defaultHashKeyBytes
46+
dd.Clear()
47+
}
48+
49+
// TryAdd attempts to add a child key to the deduplication set. It returns true
50+
// if the key was added (wasn't a duplicate), or false if the key already exists
51+
// (is a duplicate).
52+
func (dd *childKeyDeDup) TryAdd(childKey ChildKey) bool {
53+
// Handle PartitionKey case - simple map lookup.
54+
if childKey.PartitionKey != 0 {
55+
// Lazily initialize the partitionKeys map.
56+
if dd.partitionKeys == nil {
57+
dd.partitionKeys = make(map[PartitionKey]struct{}, dd.initialCapacity)
58+
}
59+
60+
if _, exists := dd.partitionKeys[childKey.PartitionKey]; exists {
61+
return false
62+
}
63+
dd.partitionKeys[childKey.PartitionKey] = struct{}{}
64+
return true
65+
}
66+
67+
// Handle KeyBytes case. Lazily initialize the KeyBytes structures.
68+
if dd.keyBytesMap == nil {
69+
dd.keyBytesMap = make(map[uint64]KeyBytes, dd.initialCapacity)
70+
}
71+
72+
// Calculate original hash for the key bytes.
73+
hash := dd.hashKeyBytes(childKey.KeyBytes)
74+
75+
// Check for the key, possibly having to look at multiple rehashed slots.
76+
iterations := 0
77+
for {
78+
existingKey, exists := dd.keyBytesMap[hash]
79+
if !exists {
80+
// No collision, we can use this hash.
81+
break
82+
}
83+
84+
// Check if this is the same key.
85+
if bytes.Equal(existingKey, childKey.KeyBytes) {
86+
return false
87+
}
88+
89+
// Hash collision, rehash to find a new slot.
90+
hash = dd.rehash(hash)
91+
92+
iterations++
93+
if iterations >= 100000 {
94+
// We must have hit a cycle, which should never happen.
95+
panic(errors.AssertionFailedf("rehash function cycled"))
96+
}
97+
}
98+
99+
// Add the key to the map.
100+
dd.keyBytesMap[hash] = childKey.KeyBytes
101+
return true
102+
}
103+
104+
// Clear removes all entries from the deduplication set.
105+
func (dd *childKeyDeDup) Clear() {
106+
// Reset all the data structures.
107+
clear(dd.partitionKeys)
108+
clear(dd.keyBytesMap)
109+
}
110+
111+
// defaultHashKeyBytes is the default implementation of hashKeyBytes.
112+
func (dd *childKeyDeDup) defaultHashKeyBytes(key KeyBytes) uint64 {
113+
return maphash.Bytes(dd.seed, key)
114+
}
115+
116+
// rehash creates a new hash from an existing hash to resolve collisions.
117+
func (dd *childKeyDeDup) rehash(hash uint64) uint64 {
118+
// These constants are large 64-bit primes.
119+
hash ^= 0xc3a5c85c97cb3127
120+
hash ^= hash >> 33
121+
hash *= 0xff51afd7ed558ccd
122+
hash ^= hash >> 33
123+
hash *= 0xc4ceb9fe1a85ec53
124+
hash ^= hash >> 33
125+
return hash
126+
}

0 commit comments

Comments
 (0)