Skip to content

Commit 41208db

Browse files
committed
crosscluster/producer: allow table level auth for REPLICATIONSOURCE
Release note (ops change): previously, the user provided in the source URI in the LDR stream required the REPLICATIONSOURCE priv at the system level. With this change, the user only needs this priv on the source tables (i.e. a table level priv).
1 parent 215b5de commit 41208db

File tree

5 files changed

+86
-16
lines changed

5 files changed

+86
-16
lines changed

pkg/crosscluster/logical/logical_replication_job_test.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -498,10 +498,11 @@ func TestCreateTables(t *testing.T) {
498498
gURL := replicationtestutils.GetExternalConnectionURI(t, srv, srv, serverutils.DBName("g"))
499499

500500
sqlG.Exec(t, "CREATE TABLE tab (pk int primary key, payload string)")
501+
sqlG.Exec(t, "CREATE TABLE foo (pk int primary key, payload string)")
501502

502503
var jobID jobspb.JobID
503504
// use create logically replicated table syntax
504-
sqlG.QueryRow(t, "CREATE LOGICALLY REPLICATED TABLE tab2 FROM TABLE tab ON $1 WITH UNIDIRECTIONAL", gURL.String()).Scan(&jobID)
505+
sqlG.QueryRow(t, "CREATE LOGICALLY REPLICATED TABLES (tab2, foo2) FROM TABLES (tab, foo) ON $1 WITH BIDIRECTIONAL ON $2", gURL.String(), gURL.String()).Scan(&jobID)
505506
WaitUntilReplicatedTime(t, srv.Clock().Now(), sqlG, jobID)
506507
// check that tab2 is empty
507508
sqlG.CheckQueryResults(t, "SELECT * FROM tab2", [][]string{})
@@ -1581,6 +1582,32 @@ func WaitUntilReplicatedTime(
15811582
})
15821583
}
15831584

1585+
func GetReverseJobID(
1586+
ctx context.Context, t *testing.T, db *sqlutils.SQLRunner, parentID jobspb.JobID,
1587+
) jobspb.JobID {
1588+
// get the created time of the parent job
1589+
var created time.Time
1590+
db.QueryRow(t, "SELECT created FROM system.jobs WHERE id = $1", parentID).Scan(&created)
1591+
1592+
var jobID jobspb.JobID
1593+
testutils.SucceedsSoon(t, func() error {
1594+
err := db.DB.QueryRowContext(ctx, `
1595+
SELECT id
1596+
FROM system.jobs
1597+
WHERE job_type = 'LOGICAL REPLICATION'
1598+
AND id != $1
1599+
AND created > $2
1600+
ORDER BY created DESC
1601+
LIMIT 1`,
1602+
parentID, created).Scan(&jobID)
1603+
if err != nil {
1604+
return errors.Wrapf(err, "reverse job not found for parent %d", parentID)
1605+
}
1606+
return nil
1607+
})
1608+
return jobID
1609+
}
1610+
15841611
type mockBatchHandler struct {
15851612
err error
15861613
}
@@ -2028,10 +2055,11 @@ func TestUserPrivileges(t *testing.T) {
20282055
testuser.ExpectErr(t, "failed privilege check: table or system level REPLICATIONDEST privilege required: user testuser does not have REPLICATIONDEST privilege on relation tab", createStmt, dbBURL.String())
20292056
dbA.Exec(t, fmt.Sprintf("GRANT SYSTEM REPLICATIONDEST TO %s", username.TestUser))
20302057
testuser.QueryRow(t, createStmt, dbBURL.String()).Scan(&jobAID)
2058+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
20312059
dbA.Exec(t, fmt.Sprintf("REVOKE SYSTEM REPLICATIONDEST FROM %s", username.TestUser))
20322060
})
20332061
t.Run("replication-src", func(t *testing.T) {
2034-
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE system privilege", createStmt, dbBURL2.String())
2062+
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE privilege on relation tab", createStmt, dbBURL2.String())
20352063
sourcePriv := "REPLICATIONSOURCE"
20362064
if rng.Intn(3) == 0 {
20372065
// Test deprecated privilege name.
@@ -2042,6 +2070,15 @@ func TestUserPrivileges(t *testing.T) {
20422070
dbA.QueryRow(t, createStmt, dbBURL2.String()).Scan(&jobAID)
20432071
dbB.Exec(t, fmt.Sprintf("REVOKE SYSTEM %s FROM %s", sourcePriv, username.TestUser+"3"))
20442072
})
2073+
t.Run("table-level-replication-src", func(t *testing.T) {
2074+
dbA.ExpectErr(t, "user testuser3 does not have REPLICATIONSOURCE privilege on relation tab", createStmt, dbBURL2.String())
2075+
2076+
dbB.Exec(t, fmt.Sprintf("GRANT REPLICATIONSOURCE ON TABLE tab TO %s", username.TestUser+"3"))
2077+
dbA.QueryRow(t, createStmt, dbBURL2.String()).Scan(&jobAID)
2078+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
2079+
2080+
dbB.Exec(t, fmt.Sprintf("REVOKE REPLICATIONSOURCE ON TABLE tab FROM %s", username.TestUser+"3"))
2081+
})
20452082
t.Run("table-level-replication-dest", func(t *testing.T) {
20462083

20472084
dbA.Exec(t, `CREATE TABLE tab2 (x INT PRIMARY KEY)`)
@@ -2055,7 +2092,9 @@ func TestUserPrivileges(t *testing.T) {
20552092
testuser.ExpectErr(t, "failed privilege check: table or system level REPLICATIONDEST privilege required: user testuser does not have REPLICATIONDEST privilege on relation tab2", multiTableStmt, dbBURL.String())
20562093

20572094
dbA.Exec(t, fmt.Sprintf(`GRANT REPLICATIONDEST ON TABLE tab2 TO %s`, username.TestUser))
2058-
testuser.Exec(t, multiTableStmt, dbBURL.String())
2095+
testuser.QueryRow(t, multiTableStmt, dbBURL.String()).Scan(&jobAID)
2096+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
2097+
20592098
dbA.Exec(t, fmt.Sprintf(`REVOKE REPLICATIONDEST ON TABLE tab FROM %s`, username.TestUser))
20602099
dbA.Exec(t, fmt.Sprintf(`REVOKE REPLICATIONDEST ON TABLE tab2 FROM %s`, username.TestUser))
20612100
})
@@ -2066,7 +2105,8 @@ func TestUserPrivileges(t *testing.T) {
20662105

20672106
// Grant CREATE privilege on destination database - should now succeed
20682107
dbA.Exec(t, `GRANT CREATE ON DATABASE a TO testuser`)
2069-
testuser.Exec(t, createStmt, dbBURL.String())
2108+
testuser.QueryRow(t, createStmt, dbBURL.String()).Scan(&jobAID)
2109+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID)
20702110

20712111
dbAURL := replicationtestutils.GetExternalConnectionURI(t, s, s, serverutils.DBName("a"), serverutils.User(username.TestUser))
20722112

@@ -2075,8 +2115,14 @@ func TestUserPrivileges(t *testing.T) {
20752115
testuser.ExpectErr(t, " uri requires REPLICATIONDEST privilege for bidirectional replication: user testuser3 does not have REPLICATIONDEST privilege on relation tab", createStmtBidi, dbBURL2.String(), dbAURL.String())
20762116

20772117
dbB.Exec(t, fmt.Sprintf("GRANT SYSTEM REPLICATIONDEST TO %s", username.TestUser+"3"))
2078-
testuser.QueryRow(t, createStmtBidi, dbBURL2.String(), dbAURL.String()).Scan(&jobAID)
20792118

2119+
var jobAID2 jobspb.JobID
2120+
testuser.QueryRow(t, createStmtBidi, dbBURL2.String(), dbAURL.String()).Scan(&jobAID2)
2121+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, jobAID2)
2122+
2123+
// Ensure the reverse job advances as well
2124+
reverseJobID := GetReverseJobID(ctx, t, dbA, jobAID2)
2125+
WaitUntilReplicatedTime(t, s.Clock().Now(), dbA, reverseJobID)
20802126
})
20812127
}
20822128

pkg/crosscluster/producer/replication_manager.go

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -450,12 +450,30 @@ func (r *replicationStreamManagerImpl) AuthorizeViaJob(
450450
return nil
451451
}
452452

453-
func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(ctx context.Context) error {
454-
err := r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
455-
syntheticprivilege.GlobalPrivilegeObject,
456-
privilege.REPLICATIONSOURCE)
453+
// AuthorizeViaReplicationPriv ensures the user has the REPLICATIONSOUCE privilege. If tableNames is passed, then table level auth is tried.
454+
func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(
455+
ctx context.Context, tableNames ...string,
456+
) (err error) {
457+
458+
authorize := func() error {
459+
// First try fast path for system level priv.
460+
err = r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
461+
syntheticprivilege.GlobalPrivilegeObject,
462+
privilege.REPLICATIONSOURCE)
463+
if err == nil {
464+
return nil
465+
} else if pgerror.GetPGCode(err) != pgcode.InsufficientPrivilege {
466+
return err
467+
}
468+
if len(tableNames) != 0 {
469+
err = replicationutils.AuthorizeTableLevelPriv(ctx, r.resolver, r.evalCtx.SessionAccessor, privilege.REPLICATIONSOURCE, tableNames)
470+
if err == nil {
471+
return nil
472+
} else if pgerror.GetPGCode(err) != pgcode.InsufficientPrivilege {
473+
return err
474+
}
475+
}
457476

458-
if pgerror.GetPGCode(err) == pgcode.InsufficientPrivilege {
459477
// Fallback to legacy REPLICATION priv.
460478
if fallbackErr := r.evalCtx.SessionAccessor.CheckPrivilege(ctx,
461479
syntheticprivilege.GlobalPrivilegeObject,
@@ -465,10 +483,12 @@ func (r *replicationStreamManagerImpl) AuthorizeViaReplicationPriv(ctx context.C
465483
// the deprecated REPLICATION priv.
466484
return err
467485
}
468-
} else if err != nil {
469-
return err
486+
return nil
470487
}
471488

489+
if err = authorize(); err != nil {
490+
return err
491+
}
472492
r.authorized = true
473493
return nil
474494
}

pkg/crosscluster/replicationutils/utils.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,10 @@ func AuthorizeTableLevelPriv(
347347
return err
348348
}
349349
lookupFlags := tree.ObjectLookupFlags{
350+
// TODO(msbutler): for reasons beyond my paygrade, to grab offline
351+
// descriptors, we need to also pass RequireMutable.
352+
RequireMutable: true,
353+
IncludeOffline: true,
350354
Required: true,
351355
DesiredObjectKind: tree.TableObject,
352356
DesiredTableDescKind: tree.ResolveRequireTableDesc,

pkg/sql/sem/builtins/replication_builtins.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -553,14 +553,14 @@ var replicationBuiltins = map[string]builtinDefinition{
553553
if err != nil {
554554
return nil, err
555555
}
556-
if err := mgr.AuthorizeViaReplicationPriv(ctx); err != nil {
557-
return nil, err
558-
}
559556
reqBytes := []byte(tree.MustBeDBytes(args[0]))
560557
req := streampb.ReplicationProducerRequest{}
561558
if err := protoutil.Unmarshal(reqBytes, &req); err != nil {
562559
return nil, err
563560
}
561+
if err := mgr.AuthorizeViaReplicationPriv(ctx, req.TableNames...); err != nil {
562+
return nil, errors.Wrapf(err, "failed to auth")
563+
}
564564

565565
spec, err := mgr.StartReplicationStreamForTables(ctx, req)
566566
if err != nil {

pkg/sql/sem/eval/context.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ type ReplicationStreamManager interface {
934934
) (streampb.ReplicationProducerSpec, error)
935935

936936
AuthorizeViaJob(ctx context.Context, streamID streampb.StreamID) error
937-
AuthorizeViaReplicationPriv(ctx context.Context) error
937+
AuthorizeViaReplicationPriv(ctx context.Context, tableNames ...string) error
938938
}
939939

940940
// StreamIngestManager represents a collection of APIs that streaming replication supports

0 commit comments

Comments
 (0)