Skip to content

Commit f520554

Browse files
craig[bot]dtgolgeek
committed
Merge #149980
149980: roachprod: add A records during cluster create r=golgeek a=dt Release note: none. Epic: none. Co-authored-by: David Taylor <[email protected]> Co-authored-by: Ludovic Leroux <[email protected]>
2 parents ae0c26a + c9b82fc commit f520554

File tree

11 files changed

+122
-47
lines changed

11 files changed

+122
-47
lines changed

pkg/roachprod/cloud/cluster_cloud.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,16 @@ func DestroyCluster(l *logger.Logger, c *Cluster) error {
534534
// DNS entries are destroyed first to ensure that the GC job will not try
535535
// and clean-up entries prematurely.
536536
stopSpinner := ui.NewDefaultSpinner(l, "Destroying DNS entries").Start()
537+
publicRecords := make([]string, 0, len(c.VMs))
538+
for _, v := range c.VMs {
539+
publicRecords = append(publicRecords, v.PublicDNS)
540+
}
537541
dnsErr := vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error {
538-
return p.DeleteRecordsBySubdomain(context.Background(), c.Name)
542+
publicRecordsErr := p.DeletePublicRecordsByName(context.Background(), publicRecords...)
543+
srvRecordsErr := p.DeleteSRVRecordsBySubdomain(context.Background(), c.Name)
544+
return errors.CombineErrors(publicRecordsErr, srvRecordsErr)
539545
})
546+
540547
stopSpinner()
541548

542549
stopSpinner = ui.NewDefaultSpinner(l, "Destroying VMs").Start()

pkg/roachprod/cloud/gc.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -570,7 +570,7 @@ func GCDNS(l *logger.Logger, cloud *Cloud, dryrun bool) error {
570570
sort.Strings(recordNames)
571571

572572
if err := destroyResource(dryrun, func() error {
573-
return p.DeleteRecordsByName(ctx, recordNames...)
573+
return p.DeleteSRVRecordsByName(ctx, recordNames...)
574574
}); err != nil {
575575
return err
576576
}

pkg/roachprod/install/services.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ func (c *SyncedCluster) discoverServices(
154154
mu := syncutil.Mutex{}
155155
records := make([]vm.DNSRecord, 0)
156156
err := vm.FanOutDNS(c.VMs, func(dnsProvider vm.DNSProvider, _ vm.List) error {
157-
r, lookupErr := dnsProvider.LookupSRVRecords(ctx, serviceDNSName(dnsProvider, virtualClusterName, serviceType, c.Name))
157+
r, lookupErr := dnsProvider.LookupRecords(ctx, vm.SRV, serviceDNSName(dnsProvider, virtualClusterName, serviceType, c.Name))
158158
if lookupErr != nil {
159159
return lookupErr
160160
}

pkg/roachprod/install/services_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ func TestMultipleRegistrations(t *testing.T) {
334334
verify := func(c *SyncedCluster, servicesToRegister [][]ServiceDesc) bool {
335335
for _, services := range servicesToRegister {
336336
if len(services) == 0 {
337-
err := testDNS.DeleteRecordsBySubdomain(ctx, c.Name)
337+
err := testDNS.DeleteSRVRecordsBySubdomain(ctx, c.Name)
338338
require.NoError(t, err)
339339
continue
340340
}

pkg/roachprod/roachprod.go

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,6 +1772,11 @@ func Create(
17721772
// No need for ssh for local clusters.
17731773
return LoadClusters()
17741774
}
1775+
1776+
if err := CreatePublicDNS(ctx, l, clusterName); err != nil {
1777+
l.Printf("Failed to create DNS for cluster %s: %v", clusterName, err)
1778+
}
1779+
17751780
l.Printf("Created cluster %s; setting up SSH...", clusterName)
17761781
return SetupSSH(ctx, l, clusterName, false /* sync */)
17771782
}
@@ -2597,8 +2602,34 @@ func DestroyDNS(ctx context.Context, l *logger.Logger, clusterName string) error
25972602
if err != nil {
25982603
return err
25992604
}
2605+
publicRecords := make([]string, 0, len(c.VMs))
2606+
for _, v := range c.VMs {
2607+
publicRecords = append(publicRecords, v.PublicDNS)
2608+
}
2609+
2610+
return vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error {
2611+
return errors.CombineErrors(
2612+
p.DeleteSRVRecordsBySubdomain(ctx, c.Name),
2613+
p.DeletePublicRecordsByName(ctx, publicRecords...),
2614+
)
2615+
})
2616+
}
2617+
2618+
// CreatePublicDNS creates or updates the public A records for the given cluster.
2619+
func CreatePublicDNS(ctx context.Context, l *logger.Logger, clusterName string) error {
2620+
c, err := GetClusterFromCache(l, clusterName)
2621+
if err != nil {
2622+
return err
2623+
}
2624+
26002625
return vm.FanOutDNS(c.VMs, func(p vm.DNSProvider, vms vm.List) error {
2601-
return p.DeleteRecordsBySubdomain(ctx, c.Name)
2626+
recs := make([]vm.DNSRecord, 0, len(c.VMs))
2627+
for _, v := range c.VMs {
2628+
rec := vm.CreateDNSRecord(v.PublicDNS, vm.A, v.PublicIP, 60)
2629+
rec.Public = true
2630+
recs = append(recs, rec)
2631+
}
2632+
return p.CreateRecords(ctx, recs...)
26022633
})
26032634
}
26042635

pkg/roachprod/vm/dns.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ type DNSRecord struct {
3838
Data string `json:"data"`
3939
// TTL is the time to live of the DNS record.
4040
TTL int `json:"TTL"`
41+
// Public indicates whether the DNS should be published in the public zone.
42+
Public bool
4143
}
4244

4345
// DNSProvider is an optional capability for a Provider that provides DNS
@@ -49,13 +51,15 @@ type DNSProvider interface {
4951
// subdomain. The protocol is usually "tcp" and the subdomain is usually the
5052
// cluster name. The service is a combination of the virtual cluster name and
5153
// type of service.
52-
LookupSRVRecords(ctx context.Context, name string) ([]DNSRecord, error)
54+
LookupRecords(ctx context.Context, recordType DNSType, name string) ([]DNSRecord, error)
5355
// ListRecords lists all DNS records managed for the zone.
5456
ListRecords(ctx context.Context) ([]DNSRecord, error)
55-
// DeleteRecordsBySubdomain deletes all DNS records with the given subdomain.
56-
DeleteRecordsBySubdomain(ctx context.Context, subdomain string) error
57-
// DeleteRecordsByName deletes all DNS records with the given name.
58-
DeleteRecordsByName(ctx context.Context, names ...string) error
57+
// DeleteSRVRecordsBySubdomain deletes all DNS SRV records with the given subdomain.
58+
DeleteSRVRecordsBySubdomain(ctx context.Context, subdomain string) error
59+
// DeleteRecordsByName deletes all DNS SRV records with the given name.
60+
DeleteSRVRecordsByName(ctx context.Context, names ...string) error
61+
// DeletePublicRecordsByName deletes all DNS A records named.
62+
DeletePublicRecordsByName(ctx context.Context, names ...string) error
5963
// Domain returns the domain name (zone) of the DNS provider.
6064
Domain() string
6165
}

pkg/roachprod/vm/gce/dns.go

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,12 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
131131
}
132132

133133
for name, recordGroup := range recordsByName {
134+
// We assume that all records in a group have the same name, type, and ttl.
135+
// TODO(herko): Add error checking to ensure that the above is the case.
136+
firstRecord := recordGroup[0]
137+
134138
err := n.withRecordLock(name, func() error {
135-
existingRecords, err := n.lookupSRVRecords(ctx, name)
139+
existingRecords, err := n.lookupRecords(ctx, firstRecord.Type, name)
136140
if err != nil {
137141
return err
138142
}
@@ -151,15 +155,16 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
151155
combinedRecords[record.Data] = record
152156
}
153157

154-
// We assume that all records in a group have the same name, type, and ttl.
155-
// TODO(herko): Add error checking to ensure that the above is the case.
156-
firstRecord := recordGroup[0]
157158
data := maps.Keys(combinedRecords)
158159
sort.Strings(data)
160+
zone := n.managedZone
161+
if firstRecord.Public {
162+
zone = n.publicZone
163+
}
159164
args := []string{"--project", n.dnsProject, "dns", "record-sets", command, name,
160165
"--type", string(firstRecord.Type),
161166
"--ttl", strconv.Itoa(firstRecord.TTL),
162-
"--zone", n.managedZone,
167+
"--zone", zone,
163168
"--rrdatas", strings.Join(data, ","),
164169
}
165170
cmd := exec.CommandContext(ctx, "gcloud", args...)
@@ -170,10 +175,10 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
170175
n.clearCacheEntry(name)
171176
return rperrors.TransientFailure(errors.Wrapf(err, "output: %s", out), dnsProblemLabel)
172177
}
173-
// If fastDNS is enabled, we need to wait for the records to become available
178+
// If fastDNS is enabled, we need to wait for the SRV records to become available
174179
// on the Google DNS servers.
175-
if config.FastDNS {
176-
err = n.waitForRecordsAvailable(ctx, maps.Values(combinedRecords)...)
180+
if config.FastDNS && !firstRecord.Public {
181+
err = n.waitForSRVRecordsAvailable(ctx, maps.Values(combinedRecords)...)
177182
if err != nil {
178183
return err
179184
}
@@ -190,33 +195,36 @@ func (n *dnsProvider) CreateRecords(ctx context.Context, records ...vm.DNSRecord
190195
}
191196

192197
// LookupSRVRecords implements the vm.DNSProvider interface.
193-
func (n *dnsProvider) LookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) {
198+
func (n *dnsProvider) LookupRecords(
199+
ctx context.Context, recordType vm.DNSType, name string,
200+
) ([]vm.DNSRecord, error) {
194201
var records []vm.DNSRecord
195202
var err error
196203
err = n.withRecordLock(name, func() error {
197-
if config.FastDNS {
204+
if config.FastDNS && recordType == vm.SRV {
198205
rIdx := randutil.FastUint32() % uint32(len(n.resolvers))
199206
records, err = n.fastLookupSRVRecords(ctx, n.resolvers[rIdx], name, true)
200207
return err
201208
}
202-
records, err = n.lookupSRVRecords(ctx, name)
209+
records, err = n.lookupRecords(ctx, recordType, name)
203210
return err
204211
})
205212
return records, err
206213
}
207214

208215
// ListRecords implements the vm.DNSProvider interface.
209216
func (n *dnsProvider) ListRecords(ctx context.Context) ([]vm.DNSRecord, error) {
210-
return n.listSRVRecords(ctx, "", dnsMaxResults)
217+
return n.listRecords(ctx, vm.SRV, "", dnsMaxResults)
211218
}
212219

213-
// DeleteRecordsByName implements the vm.DNSProvider interface.
214-
func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string) error {
220+
func (n *dnsProvider) deleteRecords(
221+
ctx context.Context, zone string, recordType vm.DNSType, names ...string,
222+
) error {
215223
for _, name := range names {
216224
err := n.withRecordLock(name, func() error {
217225
args := []string{"--project", n.dnsProject, "dns", "record-sets", "delete", name,
218-
"--type", string(vm.SRV),
219-
"--zone", n.managedZone,
226+
"--type", string(recordType),
227+
"--zone", zone,
220228
}
221229
cmd := exec.CommandContext(ctx, "gcloud", args...)
222230
out, err := n.execFn(cmd)
@@ -235,10 +243,20 @@ func (n *dnsProvider) DeleteRecordsByName(ctx context.Context, names ...string)
235243
return nil
236244
}
237245

246+
// DeleteSRVRecordsByName implements the vm.DNSProvider interface.
247+
func (n *dnsProvider) DeleteSRVRecordsByName(ctx context.Context, names ...string) error {
248+
return n.deleteRecords(ctx, n.managedZone, vm.SRV, names...)
249+
}
250+
251+
// DeletePublicRecordsByName implements the vm.DNSProvider interface
252+
func (n *dnsProvider) DeletePublicRecordsByName(ctx context.Context, names ...string) error {
253+
return n.deleteRecords(ctx, n.publicZone, vm.A, names...)
254+
}
255+
238256
// DeleteRecordsBySubdomain implements the vm.DNSProvider interface.
239-
func (n *dnsProvider) DeleteRecordsBySubdomain(ctx context.Context, subdomain string) error {
257+
func (n *dnsProvider) DeleteSRVRecordsBySubdomain(ctx context.Context, subdomain string) error {
240258
suffix := fmt.Sprintf("%s.%s.", subdomain, n.Domain())
241-
records, err := n.listSRVRecords(ctx, suffix, dnsMaxResults)
259+
records, err := n.listRecords(ctx, vm.SRV, suffix, dnsMaxResults)
242260
if err != nil {
243261
return err
244262
}
@@ -256,7 +274,7 @@ func (n *dnsProvider) DeleteRecordsBySubdomain(ctx context.Context, subdomain st
256274
delete(names, name)
257275
}
258276
}
259-
return n.DeleteRecordsByName(ctx, maps.Keys(names)...)
277+
return n.DeleteSRVRecordsByName(ctx, maps.Keys(names)...)
260278
}
261279

262280
// Domain implements the vm.DNSProvider interface.
@@ -272,13 +290,15 @@ func (n *dnsProvider) Domain() string {
272290
// network problems. For lookups, we prefer this to using the gcloud command as
273291
// it is faster, and preferable when service information is being queried
274292
// regularly.
275-
func (n *dnsProvider) lookupSRVRecords(ctx context.Context, name string) ([]vm.DNSRecord, error) {
293+
func (n *dnsProvider) lookupRecords(
294+
ctx context.Context, recordType vm.DNSType, name string,
295+
) ([]vm.DNSRecord, error) {
276296
// Check the cache first.
277297
if cachedRecords, ok := n.getCache(name); ok {
278298
return cachedRecords, nil
279299
}
280300
// Lookup the records, if no records are found in the cache.
281-
records, err := n.listSRVRecords(ctx, name, dnsMaxResults)
301+
records, err := n.listRecords(ctx, recordType, name, dnsMaxResults)
282302
if err != nil {
283303
return nil, err
284304
}
@@ -295,16 +315,21 @@ func (n *dnsProvider) lookupSRVRecords(ctx context.Context, name string) ([]vm.D
295315
return filteredRecords, nil
296316
}
297317

298-
// listSRVRecords returns all SRV records that match the given filter from Google Cloud DNS.
318+
// listRecords returns all records that match the given filter from Google Cloud DNS.
299319
// The data field of the records could be a comma-separated list of values if multiple
300320
// records are returned for the same name.
301-
func (n *dnsProvider) listSRVRecords(
302-
ctx context.Context, filter string, limit int,
321+
func (n *dnsProvider) listRecords(
322+
ctx context.Context, recordType vm.DNSType, filter string, limit int,
303323
) ([]vm.DNSRecord, error) {
324+
zone := n.managedZone
325+
if recordType == vm.A {
326+
zone = n.publicZone
327+
}
328+
304329
args := []string{"--project", n.dnsProject, "dns", "record-sets", "list",
305330
"--limit", strconv.Itoa(limit),
306331
"--page-size", strconv.Itoa(limit),
307-
"--zone", n.managedZone,
332+
"--zone", zone,
308333
"--format", "json",
309334
}
310335
if filter != "" {
@@ -333,11 +358,11 @@ func (n *dnsProvider) listSRVRecords(
333358
if record.Kind != "dns#resourceRecordSet" {
334359
continue
335360
}
336-
if record.RecordType != string(vm.SRV) {
361+
if record.RecordType != string(recordType) {
337362
continue
338363
}
339364
for _, data := range record.RRDatas {
340-
records = append(records, vm.CreateDNSRecord(record.Name, vm.SRV, data, record.TTL))
365+
records = append(records, vm.CreateDNSRecord(record.Name, recordType, data, record.TTL))
341366
}
342367
}
343368
return records, nil

pkg/roachprod/vm/gce/fast_dns.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ func googleDNSResolvers() []*net.Resolver {
3535

3636
// waitForRecordsAvailable waits for the DNS records to become available on all
3737
// the DNS servers through a standard net tools lookup.
38-
func (n *dnsProvider) waitForRecordsAvailable(ctx context.Context, records ...vm.DNSRecord) error {
38+
func (n *dnsProvider) waitForSRVRecordsAvailable(
39+
ctx context.Context, records ...vm.DNSRecord,
40+
) error {
3941
checkResolver := func(resolver *net.Resolver) error {
4042
type recordKey struct {
4143
name string

pkg/roachprod/vm/local/dns.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,17 @@ func (n *dnsProvider) CreateRecords(_ context.Context, records ...vm.DNSRecord)
5858
return n.saveRecords(entries)
5959
}
6060

61-
// LookupSRVRecords is part of the vm.DNSProvider interface.
62-
func (n *dnsProvider) LookupSRVRecords(_ context.Context, name string) ([]vm.DNSRecord, error) {
61+
// LookupRecords is part of the vm.DNSProvider interface.
62+
func (n *dnsProvider) LookupRecords(
63+
_ context.Context, recordType vm.DNSType, name string,
64+
) ([]vm.DNSRecord, error) {
6365
records, err := n.loadRecords()
6466
if err != nil {
6567
return nil, err
6668
}
6769
var matchingRecords []vm.DNSRecord
6870
for _, record := range records {
69-
if record.Name == name && record.Type == vm.SRV {
71+
if record.Name == name && record.Type == recordType {
7072
matchingRecords = append(matchingRecords, record)
7173
}
7274
}
@@ -82,8 +84,12 @@ func (n *dnsProvider) ListRecords(_ context.Context) ([]vm.DNSRecord, error) {
8284
return maps.Values(records), nil
8385
}
8486

87+
func (n *dnsProvider) DeletePublicRecordsByName(ctx context.Context, names ...string) error {
88+
return n.DeleteSRVRecordsByName(ctx, names...)
89+
}
90+
8591
// DeleteRecordsByName is part of the vm.DNSProvider interface.
86-
func (n *dnsProvider) DeleteRecordsByName(_ context.Context, names ...string) error {
92+
func (n *dnsProvider) DeleteSRVRecordsByName(_ context.Context, names ...string) error {
8793
unlock, err := lock.AcquireFilesystemLock(n.lockFilePath)
8894
if err != nil {
8995
return err
@@ -100,8 +106,8 @@ func (n *dnsProvider) DeleteRecordsByName(_ context.Context, names ...string) er
100106
return n.saveRecords(entries)
101107
}
102108

103-
// DeleteRecordsBySubdomain is part of the vm.DNSProvider interface.
104-
func (n *dnsProvider) DeleteRecordsBySubdomain(_ context.Context, subdomain string) error {
109+
// DeleteSRVRecordsBySubdomain is part of the vm.DNSProvider interface.
110+
func (n *dnsProvider) DeleteSRVRecordsBySubdomain(_ context.Context, subdomain string) error {
105111
unlock, err := lock.AcquireFilesystemLock(n.lockFilePath)
106112
if err != nil {
107113
return err

pkg/roachprod/vm/local/dns_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func TestLookupRecords(t *testing.T) {
4848
}...)
4949

5050
t.Run("lookup system", func(t *testing.T) {
51-
records, err := p.LookupSRVRecords(ctx, "_system-sql._tcp.local.local-zone")
51+
records, err := p.LookupRecords(ctx, vm.SRV, "_system-sql._tcp.local.local-zone")
5252
require.NoError(t, err)
5353
require.Equal(t, 3, len(records))
5454
for _, r := range records {
@@ -58,7 +58,7 @@ func TestLookupRecords(t *testing.T) {
5858
})
5959

6060
t.Run("parse SRV data", func(t *testing.T) {
61-
records, err := p.LookupSRVRecords(ctx, "_tenant-1-sql._tcp.local.local-zone")
61+
records, err := p.LookupRecords(ctx, vm.SRV, "_tenant-1-sql._tcp.local.local-zone")
6262
require.NoError(t, err)
6363
require.Equal(t, 1, len(records))
6464
data, err := records[0].ParseSRVRecord()

0 commit comments

Comments
 (0)