@@ -413,7 +413,7 @@ func TestEngineVersionRemovesInvalidInstancesFromCache(t *testing.T) {
413413 spy := & spyConnectionInfoCache {
414414 connectInfoCalls : []connectionInfoResp {tc .resp },
415415 }
416- d .cache [createKey (inst )] = newMonitoredCache (nil , spy , inst , 0 , nil , nil )
416+ d .cache [createKey (inst )] = newMonitoredCache (spy , inst , 0 , nil , nil )
417417
418418 _ , err = d .EngineVersion (context .Background (), tc .icn )
419419 if err == nil {
@@ -561,7 +561,7 @@ func TestWarmupRemovesInvalidInstancesFromCache(t *testing.T) {
561561 spy := & spyConnectionInfoCache {
562562 connectInfoCalls : []connectionInfoResp {tc .resp },
563563 }
564- d .cache [createKey (inst )] = newMonitoredCache (nil , spy , inst , 0 , nil , nil )
564+ d .cache [createKey (inst )] = newMonitoredCache (spy , inst , 0 , nil , nil )
565565
566566 err = d .Warmup (context .Background (), tc .icn , tc .opts ... )
567567 if err == nil {
@@ -769,7 +769,7 @@ func TestDialerRemovesInvalidInstancesFromCache(t *testing.T) {
769769 spy := & spyConnectionInfoCache {
770770 connectInfoCalls : []connectionInfoResp {tc .resp },
771771 }
772- d .cache [createKey (inst )] = newMonitoredCache (nil , spy , inst , 0 , nil , nil )
772+ d .cache [createKey (inst )] = newMonitoredCache (spy , inst , 0 , nil , nil )
773773
774774 _ , err = d .Dial (context .Background (), tc .icn , tc .opts ... )
775775 if err == nil {
@@ -819,7 +819,7 @@ func TestDialRefreshesExpiredCertificates(t *testing.T) {
819819 },
820820 },
821821 }
822- d .cache [createKey (cn )] = newMonitoredCache (nil , spy , cn , 0 , nil , nil )
822+ d .cache [createKey (cn )] = newMonitoredCache (spy , cn , 0 , nil , nil )
823823
824824 _ , err = d .Dial (context .Background (), icn )
825825 if ! errors .Is (err , sentinel ) {
@@ -1063,16 +1063,22 @@ type changingResolver struct {
10631063 stage atomic.Int32
10641064}
10651065
1066- func (r * changingResolver ) Resolve (_ context.Context , name string ) (instance.ConnName , error ) {
1067- // For TestDialerFailoverOnInstanceChange
1068- if name == "update.example.com" {
1069- if r .stage .Load () == 0 {
1070- return instance .ParseConnNameWithDomainName ("my-project:my-region:my-instance" , "update.example.com" )
1066+ func (r * changingResolver ) Resolve (ctx context.Context , name string ) (instance.ConnName , error ) {
1067+ select {
1068+ // for TestDialerClosesOldConnectionsOpenAfterDnsChange
1069+ case <- ctx .Done ():
1070+ return instance.ConnName {}, fmt .Errorf ("mock dns timeout error" )
1071+ default :
1072+ // For TestDialerFailoverOnInstanceChange
1073+ if name == "update.example.com" {
1074+ if r .stage .Load () == 0 {
1075+ return instance .ParseConnNameWithDomainName ("my-project:my-region:my-instance" , "update.example.com" )
1076+ }
1077+ return instance .ParseConnNameWithDomainName ("my-project:my-region:my-instance2" , "update.example.com" )
10711078 }
1072- return instance .ParseConnNameWithDomainName ("my-project:my-region:my-instance2" , "update.example.com" )
1079+ // TestDialerFailsDnsSrvRecordMissing
1080+ return instance.ConnName {}, fmt .Errorf ("no resolution for %q" , name )
10731081 }
1074- // TestDialerFailsDnsSrvRecordMissing
1075- return instance.ConnName {}, fmt .Errorf ("no resolution for %q" , name )
10761082}
10771083
10781084func TestDialerUpdatesAutomaticallyAfterDnsChange (t * testing.T ) {
@@ -1107,7 +1113,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11071113 },
11081114 })
11091115
1110- // Start the proxy for instance 1
1116+ // Execute the proxy for instance 1
11111117 stop1 := mock .StartServerProxy (t , inst )
11121118 t .Cleanup (func () {
11131119 stop1 ()
@@ -1127,7 +1133,7 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11271133 t .Fatal ("Expected monitoredCache to be closed after domain name changed. monitoredCache was not closed." )
11281134 }
11291135
1130- // Start the proxy for instance 2
1136+ // Execute the proxy for instance 2
11311137 stop2 := mock .StartServerProxy (t , inst2 )
11321138 t .Cleanup (func () {
11331139 stop2 ()
@@ -1140,6 +1146,124 @@ func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
11401146
11411147}
11421148
1149+ func TestDialerClosesOldConnectionsOpenAfterDnsChange (t * testing.T ) {
1150+ // At first, the resolver will resolve
1151+ // update.example.com to "my-instance"
1152+ // Then, the resolver will resolve the same domain name to
1153+ // "my-instance2".
1154+ // This shows that on every call to Dial(), the dialer will resolve the
1155+ // SRV record and connect to the correct instance.
1156+ inst := mock .NewFakeCSQLInstance (
1157+ "my-project" , "my-region" , "my-instance" ,
1158+ mock .WithDNS ("update.example.com" ),
1159+ )
1160+ inst2 := mock .NewFakeCSQLInstance (
1161+ "my-project" , "my-region" , "my-instance2" ,
1162+ mock .WithDNS ("update.example.com" ),
1163+ )
1164+ r := & changingResolver {}
1165+
1166+ d := setupDialer (t , setupConfig {
1167+ skipServer : true ,
1168+ reqs : []* mock.Request {
1169+ mock .InstanceGetSuccess (inst , 1 ),
1170+ mock .CreateEphemeralSuccess (inst , 1 ),
1171+ mock .InstanceGetSuccess (inst2 , 1 ),
1172+ mock .CreateEphemeralSuccess (inst2 , 1 ),
1173+ },
1174+ dialerOptions : []Option {
1175+ WithFailoverPeriod (10 * time .Millisecond ),
1176+ WithResolver (r ),
1177+ WithTokenSource (mock.EmptyTokenSource {}),
1178+ WithContextDebugLogger (& testLog {t : t }),
1179+ },
1180+ })
1181+
1182+ srv := mock .NewFailoverTestServer (t )
1183+ t .Cleanup (func () {
1184+ srv .Close ()
1185+ })
1186+
1187+ // Execute the mock server on 3307 for instance 1
1188+ srv .Start (& inst )
1189+ ctx , cancelFn := context .WithCancel (context .Background ())
1190+ defer cancelFn ()
1191+
1192+ // Dial using a context with a timeout, similar to how the auth proxy
1193+ // uses the dialer
1194+ dialCtx , cancelFn := context .WithTimeout (context .Background (), 5 * time .Second )
1195+ defer cancelFn ()
1196+ conn , err := d .Dial (dialCtx , "update.example.com" )
1197+ if err != nil {
1198+ t .Fatalf ("Dial failed: %v" , err )
1199+ }
1200+ c1 := mock .NewDbClient (t , conn , "c1" )
1201+ go c1 .Execute (ctx )
1202+ defer c1 .Close ()
1203+
1204+ // Wait 15 seconds to give the fake client time to connect and
1205+ // read from the socket
1206+ time .Sleep (15 * time .Second )
1207+
1208+ // Stop the instance1 mock, then start the instance2 mock server.
1209+ // The dialer will need to refresh before it can connect to instance2.
1210+ t .Logf ("Switching to instance2 server" )
1211+ srv .Stop ()
1212+ srv .Start (& inst2 )
1213+
1214+ // Update the DNS resolver. This should signal to the dialer that
1215+ // it should disconnect and refresh certificates.
1216+ t .Logf ("Updating DNS record" )
1217+ r .stage .Store (1 )
1218+ time .Sleep (1 * time .Second )
1219+
1220+ // Dial the domain name again. This should now connect to instance2.
1221+ dialCtx , cancelFn = context .WithTimeout (context .Background (), 5 * time .Second )
1222+ defer cancelFn ()
1223+ conn , err = d .Dial (dialCtx , "update.example.com" )
1224+ if err != nil {
1225+ t .Fatalf ("Dial failed: %v" , err )
1226+ }
1227+ c2 := mock .NewDbClient (t , conn , "c2" )
1228+ go c2 .Execute (ctx )
1229+ defer c2 .Close ()
1230+
1231+ // Wait for the dial context to expire. This ensures that the
1232+ // DNS record loop in the dialer is not dependent on the dial context.
1233+ time .Sleep (15 * time .Second )
1234+
1235+ // Check that the client connections are in the correct state:
1236+ // c1 should be closed because it connected before the domain name changed.
1237+ // c2 should be open because it connected after the domain name changed.
1238+
1239+ // Assert that c1 is closed because the dns record changed
1240+ if ! c1 .Closed () {
1241+ t .Errorf ("want c1 closed, was open" )
1242+ }
1243+ // Assert that c1 received some messages, only from inst1
1244+ if len (c1 .Recv ()) == 0 {
1245+ t .Errorf ("c1 wants >0 messages received, got 0" )
1246+ }
1247+ for _ , m := range c1 .Recv () {
1248+ if m != "my-instance" {
1249+ t .Errorf ("c1 wants messages from my-instance, got %q" , m )
1250+ }
1251+ }
1252+ // Assert that c2 is open. No domain name changes have occurred.
1253+ if c2 .Closed () {
1254+ t .Errorf ("want c2 open, was closed" )
1255+ }
1256+ // Assert that c2 received some messages, only from inst2
1257+ if len (c2 .Recv ()) == 0 {
1258+ t .Errorf ("c2 wants >0 messages received, got 0" )
1259+ }
1260+ for _ , m := range c2 .Recv () {
1261+ if m != "my-instance2" {
1262+ t .Errorf ("c2 wants messages from my-instance2, got %q" , m )
1263+ }
1264+ }
1265+ }
1266+
11431267func TestDialerChecksSubjectAlternativeNameAndSucceeds (t * testing.T ) {
11441268
11451269 tcs := []struct {
@@ -1389,7 +1513,7 @@ func TestDialerRefreshesAfterRotateCACerts(t *testing.T) {
13891513 mock .RotateCA (inst )
13901514 }
13911515
1392- // Start the server with new certificates
1516+ // Execute the server with new certificates
13931517 cancel2 := mock .StartServerProxy (t , inst )
13941518 defer cancel2 ()
13951519
0 commit comments