diff --git a/internal/designate/provider/provider.go b/internal/designate/provider/provider.go index b75105a..2fede34 100644 --- a/internal/designate/provider/provider.go +++ b/internal/designate/provider/provider.go @@ -19,6 +19,7 @@ package provider import ( "context" + "external-dns-openstack-webhook/internal/designate/client" "fmt" "strings" @@ -31,8 +32,6 @@ import ( "sigs.k8s.io/external-dns/provider" ) -import "external-dns-openstack-webhook/internal/designate/client" - const ( // ID of the RecordSet from which endpoint was created designateRecordSetID = "designate-recordset-id" @@ -72,10 +71,23 @@ func NewDesignateProvider(domainFilter endpoint.DomainFilter, dryRun bool) (prov func canonicalizeDomainNames(domains []string) []string { var cDomains []string for _, d := range domains { - if !strings.HasSuffix(d, ".") { - d += "." - cDomains = append(cDomains, strings.ToLower(d)) + cDomains = append(cDomains, canonicalizeDomainName(d)) + } + return cDomains +} + +func canonicalizeDomainNamesForMX(domains []string) []string { + var cDomains []string + + for _, d := range domains { + parts := strings.Split(d, " ") + if len(parts) == 2 { + // If the format does not conform to the expected one, play it safe and just pass the value on. + // Otherwise, canonicalize the hostname. + d = fmt.Sprintf("%s %s", parts[0], canonicalizeDomainName(parts[1])) } + + cDomains = append(cDomains, canonicalizeDomainName(d)) } return cDomains } @@ -116,7 +128,7 @@ func getHostZoneID(hostname string, managedZones map[string]string) string { resultID := "" for zoneID, zoneName := range managedZones { - if !strings.HasSuffix(hostname, "." + zoneName) && hostname != zoneName { + if !strings.HasSuffix(hostname, "."+zoneName) && hostname != zoneName { continue } ln := len(zoneName) @@ -139,7 +151,7 @@ func (p designateProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, e for zoneID := range managedZones { err = p.client.ForEachRecordSet(ctx, zoneID, func(recordSet *recordsets.RecordSet) error { - if recordSet.Type != endpoint.RecordTypeA && recordSet.Type != endpoint.RecordTypeTXT && recordSet.Type != endpoint.RecordTypeCNAME { + if !p.supportedRecordType(recordSet.Type) { return nil } @@ -160,6 +172,15 @@ func (p designateProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, e return result, nil } +func (p designateProvider) supportedRecordType(recordType string) bool { + switch recordType { + case endpoint.RecordTypeA, endpoint.RecordTypeTXT, endpoint.RecordTypeCNAME, endpoint.RecordTypeNS, endpoint.RecordTypeMX: + return true + default: + return false + } +} + // temporary structure to hold recordset parameters so that we could aggregate endpoints into recordsets type recordSet struct { dnsName string @@ -197,9 +218,12 @@ func addEndpoint(ep *endpoint.Endpoint, recordSets map[string]*recordSet, oldEnd } } targets := ep.Targets - if ep.RecordType == endpoint.RecordTypeCNAME { + if ep.RecordType == endpoint.RecordTypeCNAME || ep.RecordType == endpoint.RecordTypeNS { targets = canonicalizeDomainNames(targets) } + if ep.RecordType == endpoint.RecordTypeMX { + targets = canonicalizeDomainNamesForMX(targets) + } for _, t := range targets { rs.names[t] = !delete } diff --git a/internal/designate/provider/provider_test.go b/internal/designate/provider/provider_test.go index 50415bc..3e6c45e 100644 --- a/internal/designate/provider/provider_test.go +++ b/internal/designate/provider/provider_test.go @@ -409,7 +409,6 @@ func testDesignateCreateRecords(t *testing.T, client *fakeDesignateClient) []*re TTL: 60, Type: endpoint.RecordTypeTXT, }) - if err != nil { t.Fatal("failed to prefill records") } @@ -452,6 +451,18 @@ func testDesignateCreateRecords(t *testing.T, client *fakeDesignateClient) []*re Targets: endpoint.Targets{"sql.test.net"}, Labels: map[string]string{}, }, + { + DNSName: "ns.test.net", + RecordType: endpoint.RecordTypeNS, + Targets: endpoint.Targets{"ns1.test.net", "ns2.test.net", "ns3.test.net"}, + Labels: map[string]string{}, + }, + { + DNSName: "mx.test.net", + RecordType: endpoint.RecordTypeMX, + Targets: endpoint.Targets{"10 mx1.test.net", "100 mx2.test.net"}, + Labels: map[string]string{}, + }, } expected := []*recordsets.RecordSet{ { @@ -485,6 +496,18 @@ func testDesignateCreateRecords(t *testing.T, client *fakeDesignateClient) []*re Records: []string{"sql.test.net."}, ZoneID: "zone-2", }, + { + Name: "ns.test.net.", + Type: endpoint.RecordTypeNS, + Records: []string{"ns1.test.net.", "ns2.test.net.", "ns3.test.net."}, + ZoneID: "zone-2", + }, + { + Name: "mx.test.net.", + Type: endpoint.RecordTypeMX, + Records: []string{"10 mx1.test.net.", "100 mx2.test.net."}, + ZoneID: "zone-2", + }, } expectedCopy := make([]*recordsets.RecordSet, len(expected)) copy(expectedCopy, expected) @@ -550,6 +573,16 @@ func testDesignateUpdateRecords(t *testing.T, client *fakeDesignateClient) []*re designateOriginalRecords: "10.2.1.1\00010.2.1.2", }, }, + { + DNSName: "ns.test.net.", + RecordType: endpoint.RecordTypeNS, + Targets: endpoint.Targets{"ns1.test.net", "ns2.test.net", "ns3.test.net"}, + Labels: map[string]string{ + designateZoneID: "zone-2", + designateRecordSetID: expected[5].ID, + designateOriginalRecords: "ns1.test.net.\000ns2.test.net.\000ns3.test.net.", + }, + }, } updatesNew := []*endpoint.Endpoint{ { @@ -573,6 +606,16 @@ func testDesignateUpdateRecords(t *testing.T, client *fakeDesignateClient) []*re designateOriginalRecords: "10.2.1.1\00010.2.1.2", }, }, + { + DNSName: "ns.test.net.", + RecordType: endpoint.RecordTypeNS, + Targets: endpoint.Targets{"ns1.test.invalid", "ns2.test.invalid", "ns3.test.invalid"}, + Labels: map[string]string{ + designateZoneID: "zone-2", + designateRecordSetID: expected[5].ID, + designateOriginalRecords: "ns1.test.net.\000ns2.test.net.\000ns3.test.net.", + }, + }, } expectedCopy := make([]*recordsets.RecordSet, len(expected)) copy(expectedCopy, expected) @@ -580,6 +623,7 @@ func testDesignateUpdateRecords(t *testing.T, client *fakeDesignateClient) []*re expected[2].Records = []string{"10.3.3.1"} expected[2].TTL = 60 expected[3].Records = []string{"10.2.1.1", "10.3.3.2"} + expected[5].Records = []string{"ns1.test.invalid.", "ns2.test.invalid.", "ns3.test.invalid."} err := client.ToProvider().ApplyChanges(context.Background(), &plan.Changes{UpdateOld: updatesOld, UpdateNew: updatesNew}) if err != nil {