Skip to content

Commit 908d0cf

Browse files
authored
fix: Use a new context for the domain name check loop. (#1007)
This fixes a bug in the Go connector that was exposed by the Auth Proxy. If the Auth Proxy is configured with a domain name, and the DNS record is then changed to point to a new instance, connections to the new instance will be terminated every 30 seconds. This was caused by a bug in how the MonitoredCache set up the background Context for the goroutine that checks for changes in the DNS record. **test: Adds test to simulate the way that the Auth Proxy calls the dialer** Added test TestDialerClosesOldConnectionsOpenAfterDnsChange, which demonstrates that the dialer behaves correctly when it connects using a DNS name and the value of the DNS record then changes. The dialer should: - close existing connections; - open new connections to the updated instance; - keep new connections open after the DNS loop completes its first cycle. This also makes the mock resolver more realistic: if it is called with a context that is done, the resolver will return an error. **fix: Create a new context for the domain name check loop.** When a monitoredCache runs a goroutine to check for changes in the domain name, it needs its own context. When the proxy (and other library users) call dialer.Dial(ctx, ...), they pass in a context with a 30 second timeout. The background process that checks for domain name changes should not use this context; it should create its own context with no timeout.
1 parent 81ca4d7 commit 908d0cf

File tree

5 files changed

+364
-40
lines changed

5 files changed

+364
-40
lines changed

dialer.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ func (d *Dialer) connectionInfoCache(
754754
d.dialerID, useIAMAuthNDial,
755755
)
756756
}
757-
c = newMonitoredCache(ctx, cache, cn, d.failoverPeriod, d.resolver, d.logger)
757+
c = newMonitoredCache(cache, cn, d.failoverPeriod, d.resolver, d.logger)
758758
d.cache[k] = c
759759

760760
return c, nil

dialer_test.go

Lines changed: 139 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) {
413413
spy := &spyConnectionInfoCache{
414414
connectInfoCalls: []connectionInfoResp{tc.resp},
415415
}
416-
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
416+
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)
417417

418418
_, err = d.EngineVersion(context.Background(), tc.icn)
419419
if err == nil {
@@ -561,7 +561,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) {
561561
spy := &spyConnectionInfoCache{
562562
connectInfoCalls: []connectionInfoResp{tc.resp},
563563
}
564-
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
564+
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)
565565

566566
err = d.Warmup(context.Background(), tc.icn, tc.opts...)
567567
if err == nil {
@@ -769,7 +769,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
769769
spy := &spyConnectionInfoCache{
770770
connectInfoCalls: []connectionInfoResp{tc.resp},
771771
}
772-
d.cache[createKey(inst)] = newMonitoredCache(nil, spy, inst, 0, nil, nil)
772+
d.cache[createKey(inst)] = newMonitoredCache(spy, inst, 0, nil, nil)
773773

774774
_, err = d.Dial(context.Background(), tc.icn, tc.opts...)
775775
if err == nil {
@@ -819,7 +819,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
819819
},
820820
},
821821
}
822-
d.cache[createKey(cn)] = newMonitoredCache(nil, spy, cn, 0, nil, nil)
822+
d.cache[createKey(cn)] = newMonitoredCache(spy, cn, 0, nil, nil)
823823

824824
_, err = d.Dial(context.Background(), icn)
825825
if !errors.Is(err, sentinel) {
@@ -1063,16 +1063,22 @@ type changingResolver struct {
10631063
stage atomic.Int32
10641064
}
10651065

1066-
func (r *changingResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) {
1067-
// For TestDialerFailoverOnInstanceChange
1068-
if name == "update.example.com" {
1069-
if r.stage.Load() == 0 {
1070-
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
1066+
func (r *changingResolver) Resolve(ctx context.Context, name string) (instance.ConnName, error) {
1067+
select {
1068+
// for TestDialerClosesOldConnectionsOpenAfterDnsChange
1069+
case <-ctx.Done():
1070+
return instance.ConnName{}, fmt.Errorf("mock dns timeout error")
1071+
default:
1072+
// For TestDialerFailoverOnInstanceChange
1073+
if name == "update.example.com" {
1074+
if r.stage.Load() == 0 {
1075+
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "update.example.com")
1076+
}
1077+
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com")
10711078
}
1072-
return instance.ParseConnNameWithDomainName("my-project:my-region:my-instance2", "update.example.com")
1079+
// TestDialerFailsDnsSrvRecordMissing
1080+
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
10731081
}
1074-
// TestDialerFailsDnsSrvRecordMissing
1075-
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
10761082
}
10771083

10781084
func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
@@ -1107,7 +1113,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11071113
},
11081114
})
11091115

1110-
// Start the proxy for instance 1
1116+
// Execute the proxy for instance 1
11111117
stop1 := mock.StartServerProxy(t, inst)
11121118
t.Cleanup(func() {
11131119
stop1()
@@ -1127,7 +1133,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11271133
t.Fatal("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed.")
11281134
}
11291135

1130-
// Start the proxy for instance 2
1136+
// Execute the proxy for instance 2
11311137
stop2 := mock.StartServerProxy(t, inst2)
11321138
t.Cleanup(func() {
11331139
stop2()
@@ -1140,6 +1146,124 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11401146

11411147
}
11421148

1149+
func TestDialerClosesOldConnectionsOpenAfterDnsChange(t *testing.T) {
1150+
// At first, the resolver will resolve
1151+
// update.example.com to "my-instance"
1152+
// Then, the resolver will resolve the same domain name to
1153+
// "my-instance2".
1154+
// This shows that on every call to Dial(), the dialer will resolve the
1155+
// SRV record and connect to the correct instance.
1156+
inst := mock.NewFakeCSQLInstance(
1157+
"my-project", "my-region", "my-instance",
1158+
mock.WithDNS("update.example.com"),
1159+
)
1160+
inst2 := mock.NewFakeCSQLInstance(
1161+
"my-project", "my-region", "my-instance2",
1162+
mock.WithDNS("update.example.com"),
1163+
)
1164+
r := &changingResolver{}
1165+
1166+
d := setupDialer(t, setupConfig{
1167+
skipServer: true,
1168+
reqs: []*mock.Request{
1169+
mock.InstanceGetSuccess(inst, 1),
1170+
mock.CreateEphemeralSuccess(inst, 1),
1171+
mock.InstanceGetSuccess(inst2, 1),
1172+
mock.CreateEphemeralSuccess(inst2, 1),
1173+
},
1174+
dialerOptions: []Option{
1175+
WithFailoverPeriod(10 * time.Millisecond),
1176+
WithResolver(r),
1177+
WithTokenSource(mock.EmptyTokenSource{}),
1178+
WithContextDebugLogger(&testLog{t: t}),
1179+
},
1180+
})
1181+
1182+
srv := mock.NewFailoverTestServer(t)
1183+
t.Cleanup(func() {
1184+
srv.Close()
1185+
})
1186+
1187+
// Execute the mock server on 3307 for instance 1
1188+
srv.Start(&inst)
1189+
ctx, cancelFn := context.WithCancel(context.Background())
1190+
defer cancelFn()
1191+
1192+
// Dial using a context with a timeout, similar to how the auth proxy
1193+
// uses the dialer
1194+
dialCtx, cancelFn := context.WithTimeout(context.Background(), 5*time.Second)
1195+
defer cancelFn()
1196+
conn, err := d.Dial(dialCtx, "update.example.com")
1197+
if err != nil {
1198+
t.Fatalf("Dial failed: %v", err)
1199+
}
1200+
c1 := mock.NewDbClient(t, conn, "c1")
1201+
go c1.Execute(ctx)
1202+
defer c1.Close()
1203+
1204+
// Wait 15 seconds to give the fake client time to connect and
1205+
// read from the socket
1206+
time.Sleep(15 * time.Second)
1207+
1208+
// Stop the instance1 mock, then start the instance2 mock server.
1209+
// The dialer will need to refresh before it can connect to instance2.
1210+
t.Logf("Switching to instance2 server")
1211+
srv.Stop()
1212+
srv.Start(&inst2)
1213+
1214+
// Update the DNS resolver. This should signal to the dialer that
1215+
// it should disconnect and refresh certificates.
1216+
t.Logf("Updating DNS record")
1217+
r.stage.Store(1)
1218+
time.Sleep(1 * time.Second)
1219+
1220+
// Dial the domain name again. This should now connect to instance2.
1221+
dialCtx, cancelFn = context.WithTimeout(context.Background(), 5*time.Second)
1222+
defer cancelFn()
1223+
conn, err = d.Dial(dialCtx, "update.example.com")
1224+
if err != nil {
1225+
t.Fatalf("Dial failed: %v", err)
1226+
}
1227+
c2 := mock.NewDbClient(t, conn, "c2")
1228+
go c2.Execute(ctx)
1229+
defer c2.Close()
1230+
1231+
// Wait for the dial context to expire. This ensures that the
1232+
// DNS record loop in the dialer is not dependent on the dial context.
1233+
time.Sleep(15 * time.Second)
1234+
1235+
// Check that the client connections are in the correct state:
1236+
// c1 should be closed because it connected before the domain name changed.
1237+
// c2 should be open because it connected after the domain name changed.
1238+
1239+
// Assert that c1 is closed because the dns record changed
1240+
if !c1.Closed() {
1241+
t.Errorf("want c1 closed, was open")
1242+
}
1243+
// Assert that c1 received some messages, only from inst1
1244+
if len(c1.Recv()) == 0 {
1245+
t.Errorf("c1 wants >0 messages received, got 0")
1246+
}
1247+
for _, m := range c1.Recv() {
1248+
if m != "my-instance" {
1249+
t.Errorf("c1 wants messages from my-instance, got %q", m)
1250+
}
1251+
}
1252+
// Assert that c2 is open. No domain name changes have occurred.
1253+
if c2.Closed() {
1254+
t.Errorf("want c2 open, was closed")
1255+
}
1256+
// Assert that c2 received some messages, only from inst2
1257+
if len(c2.Recv()) == 0 {
1258+
t.Errorf("c2 wants >0 messages received, got 0")
1259+
}
1260+
for _, m := range c2.Recv() {
1261+
if m != "my-instance2" {
1262+
t.Errorf("c2 wants messages from my-instance2, got %q", m)
1263+
}
1264+
}
1265+
}
1266+
11431267
func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) {
11441268

11451269
tcs := []struct {
@@ -1389,7 +1513,7 @@ func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
13891513
mock.RotateCA(inst)
13901514
}
13911515

1392-
// Start the server with new certificates
1516+
// Execute the server with new certificates
13931517
cancel2 := mock.StartServerProxy(t, inst)
13941518
defer cancel2()
13951519

0 commit comments

Comments
 (0)