Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ type Dialer struct {

// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver
dnsResolver cloudsql.NetResolver
failoverPeriod time.Duration

// metadataExchangeDisabled true when the dialer should never
Expand Down Expand Up @@ -213,6 +214,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: nullLogger{},
useragents: []string{userAgent},
failoverPeriod: cloudsql.FailoverPeriod,
dnsResolver: net.DefaultResolver,
}
for _, opt := range opts {
opt(cfg)
Expand Down Expand Up @@ -321,6 +323,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
dialerID: uuid.New().String(),
iamTokenProvider: cfg.iamLoginTokenProvider,
dialFunc: cfg.dialFunc,
dnsResolver: cfg.dnsResolver,
resolver: r,
failoverPeriod: cfg.failoverPeriod,
metadataExchangeDisabled: cfg.metadataExchangeDisabled,
Expand Down Expand Up @@ -407,6 +410,28 @@ func (d *Dialer) Dial(ctx context.Context, icn string, opts ...DialOption) (conn
d.removeCached(ctx, cn, c, err)
return nil, err
}

// If the connector is configured with a custom DNS name, attempt to use
// that DNS name to connect to the instance. Fall back to the metadata IP
// address if the DNS name does not resolve to an IP address.
if cn.HasDomainName() {
addrs, err := d.dnsResolver.LookupHost(ctx, cn.DomainName())
if err != nil {
d.logger.Debugf(ctx,
"[%v] custom DNS name %q did not resolve to an IP address: %v, using %s from instance metadata",
cn.String(), cn.DomainName(), err, addr)
} else if len(addrs) == 0 {
d.logger.Debugf(ctx,
"[%v] custom DNS name %q resolved but returned no entries, using %s from instance metadata",
cn.String(), cn.DomainName(), addr)
} else {
d.logger.Debugf(ctx,
"[%v] custom DNS name %q resolved to %q, using it to connect",
cn.String(), cn.DomainName(), addrs[0])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am guessing it is always in the first item of the array?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usually it only returns 1 result. If it returns more than 1 result, that is legal, but probably a mistake.

addr = addrs[0]
}
}

addr = net.JoinHostPort(addr, serverProxyPort)
f := d.dialFunc
if cfg.dialFunc != nil {
Expand Down
167 changes: 130 additions & 37 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ func testSucessfulDialWithInstanceName(
}
}

// withMockDNSResolver replaces net.DefaultResolver with a mock resolver
func withMockDNSResolver(r cloudsql.NetResolver) Option {
return func(d *dialerConfig) {
d.dnsResolver = r
}
}

// setupConfig holds all the configuration to use when setting up a dialer.
type setupConfig struct {
testInstance mock.FakeCSQLInstance
Expand Down Expand Up @@ -1017,15 +1024,23 @@ func TestDialerInitializesLazyCache(t *testing.T) {
}
}

type fakeResolver struct {
entries map[string]instance.ConnName
type mockNetResolver struct {
txtEntries map[string]string
hostEntries map[string]string
}

func (r *mockNetResolver) LookupTXT(_ context.Context, name string) ([]string, error) {
if val, ok := r.txtEntries[name]; ok {
return []string{val}, nil
}
return nil, fmt.Errorf("no resolution for %q", name)
}

func (r *fakeResolver) Resolve(_ context.Context, name string) (instance.ConnName, error) {
if val, ok := r.entries[name]; ok {
return val, nil
func (r *mockNetResolver) LookupHost(_ context.Context, name string) ([]string, error) {
if val, ok := r.hostEntries[name]; ok {
return []string{val}, nil
}
return instance.ConnName{}, fmt.Errorf("no resolution for %q", name)
return nil, fmt.Errorf("no resolution for %q", name)
}

func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) {
Expand All @@ -1034,8 +1049,6 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) {
mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"),
mock.WithDNSMapping("db2.example.com", "INSTANCE", "CUSTOM_SAN"),
)
wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com")
wantName2, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db2.example.com")
// This will create 2 separate connectionInfoCache entries, one for
// each DNS name.
d := setupDialer(t, setupConfig{
Expand All @@ -1045,13 +1058,55 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) {
mock.CreateEphemeralSuccess(inst, 2),
},
dialerOptions: []Option{
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"db.example.com": "my-project:my-region:my-instance",
"db2.example.com": "my-project:my-region:my-instance",
},
}),
WithDNSResolver(),
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"db.example.com": wantName,
"db2.example.com": wantName2,
},
})

testSuccessfulDial(
context.Background(), t, d,
"db.example.com",
)
testSuccessfulDial(
context.Background(), t, d,
"db2.example.com",
)
}

func TestDialerSuccessfullyDialsDnsTxtRecordWithCustomARecords(t *testing.T) {
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"),
mock.WithDNSMapping("db2.example.com", "INSTANCE", "CUSTOM_SAN"),
)

// This will create 2 separate connectionInfoCache entries, one for
// each DNS name.
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},
dialerOptions: []Option{
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"db.example.com": "my-project:my-region:my-instance",
"db2.example.com": "my-project:my-region:my-instance",
},
hostEntries: map[string]string{
"db.example.com": "127.0.0.1",
"db2.example.com": "127.0.0.2",
},
}),
WithTokenSource(mock.EmptyTokenSource{}),
WithDNSResolver(),
},
})

Expand All @@ -1065,16 +1120,55 @@ func TestDialerSuccessfullyDialsDnsTxtRecord(t *testing.T) {
)
}

func TestDialerFailsDnsTxtRecordMissing(t *testing.T) {
func TestDialerFailsDnsTxtRecordWithInvalidCustomARecords(t *testing.T) {
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
mock.WithDNSMapping("db.example.com", "INSTANCE", "CUSTOM_SAN"),
)

// This will create 2 separate connectionInfoCache entries, one for
// each DNS name.
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
dialerOptions: []Option{
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"db.example.com": "my-project:my-region:my-instance",
},
hostEntries: map[string]string{
"db.example.com": "1.1.1.1",
},
}),
WithTokenSource(mock.EmptyTokenSource{}),
WithDNSResolver(),
},
})
ctx, cancelFn := context.WithTimeout(context.Background(), 1*time.Second)
defer cancelFn()
_, err := d.Dial(ctx, "db.example.com")
// Expect an error due to the timeout.
if err == nil {
t.Fatal("Dial should have failed due to bad IP address")
}
t.Log("timeout", err)

}

func TestDialerFailsDNSTxtRecordMissing(t *testing.T) {
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
)
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{},
dialerOptions: []Option{
withMockDNSResolver(&mockNetResolver{}),
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{}),
WithDNSResolver(),
},
})
_, err := d.Dial(context.Background(), "doesnt-exist.example.com")
Expand Down Expand Up @@ -1106,6 +1200,10 @@ func (r *changingResolver) Resolve(ctx context.Context, name string) (instance.C
}
}

func (r *changingResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
return net.DefaultResolver.LookupHost(ctx, host)
}

func TestDialerUpdatesAutomaticallyAfterDnsChange(t *testing.T) {
// At first, the resolver will resolve
// update.example.com to "my-instance"
Expand Down Expand Up @@ -1334,21 +1432,20 @@ func TestDialerChecksSubjectAlternativeNameAndSucceeds(t *testing.T) {
)
}

wantName, _ := instance.ParseConnNameWithDomainName(tc.icn, tc.dn)
d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"db.example.com": wantName,
"my-project:my-region:my-instance": wantName,
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"db.example.com": "my-project:my-region:my-instance",
},
}),
WithTokenSource(mock.EmptyTokenSource{}),
WithDNSResolver(),
},
})
dnOrIcn := tc.icn
Expand All @@ -1375,21 +1472,20 @@ func TestDialerChecksSubjectAlternativeNameAndFails(t *testing.T) {
)

// Resolve the dns name 'bad.example.com' to the the instance.
wantName, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "bad.example.com")

d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"bad.example.com": wantName,
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"bad.example.com": "my-project:my-region:my-instance",
},
}),
WithTokenSource(mock.EmptyTokenSource{}),
WithDNSResolver(),
},
})

Expand All @@ -1415,25 +1511,22 @@ func TestDialerChecksSubjectAlternativeNameAndFallsBackToCN(t *testing.T) {
)

// resolve db.example.com to the same instance
wantName, _ := instance.ParseConnNameWithDomainName("myProject:myRegion:myInstance", "db.example.com")

d := setupDialer(t, setupConfig{
testInstance: inst,
reqs: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},

dialerOptions: []Option{
WithTokenSource(mock.EmptyTokenSource{}),
WithResolver(&fakeResolver{
entries: map[string]instance.ConnName{
"db.example.com": wantName,
"myProject:myRegion:myInstance": wantName,
withMockDNSResolver(&mockNetResolver{
txtEntries: map[string]string{
"db.example.com": "myProject:myRegion:myInstance",
},
}),
},
})
WithTokenSource(mock.EmptyTokenSource{}),
WithDNSResolver(),
}})

tcs := []struct {
desc string
Expand Down
20 changes: 10 additions & 10 deletions internal/cloudsql/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,12 @@ package cloudsql
import (
"context"
"fmt"
"net"
"sort"

"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
)

// DNSResolver uses the default net.Resolver to find
// TXT records containing an instance name for a DNS record.
var DNSResolver = &DNSInstanceConnectionNameResolver{
dnsResolver: net.DefaultResolver,
}

// DefaultResolver simply parses instance names.
var DefaultResolver = &ConnNameResolver{}

Expand All @@ -44,19 +37,26 @@ func (r *ConnNameResolver) Resolve(_ context.Context, icn string) (instanceName
return instance.ParseConnName(icn)
}

// netResolver groups the methods on net.Resolver that are used by the DNS
// NetResolver groups the methods on net.Resolver that are used by the DNS
// resolver implementation. This allows an application to replace the default
// net.DefaultResolver with a custom implementation. For example: the
// application may need to connect to a specific DNS server using a specially
// configured instance of net.Resolver.
type netResolver interface {
type NetResolver interface {
LookupTXT(ctx context.Context, name string) ([]string, error)
LookupHost(ctx context.Context, name string) ([]string, error)
}

// NewDNSResolver returns a new DNSInstanceConnectionNameResolver with the
// provided resolver.
func NewDNSResolver(r NetResolver) *DNSInstanceConnectionNameResolver {
return &DNSInstanceConnectionNameResolver{dnsResolver: r}
}

// DNSInstanceConnectionNameResolver can resolve domain names into instance names using
// TXT records in DNS. Implements InstanceConnectionNameResolver
type DNSInstanceConnectionNameResolver struct {
dnsResolver netResolver
dnsResolver NetResolver
}

// Resolve returns the instance name, possibly using DNS. This will return an
Expand Down
4 changes: 4 additions & 0 deletions internal/cloudsql/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ func (r *fakeResolver) LookupTXT(_ context.Context, name string) (addrs []string
return nil, fmt.Errorf("no resolution for %v", name)
}

func (r *fakeResolver) LookupHost(_ context.Context, name string) (addrs []string, err error) {
return nil, fmt.Errorf("no resolution for %v", name)
}

func TestDNSInstanceNameResolver_Lookup_Success_TxtRecord(t *testing.T) {
want, _ := instance.ParseConnNameWithDomainName("my-project:my-region:my-instance", "db.example.com")

Expand Down
Loading
Loading