diff --git a/endpoint/endpoint.go b/endpoint/endpoint.go index 42d7f1d1e4..11870b3fc4 100644 --- a/endpoint/endpoint.go +++ b/endpoint/endpoint.go @@ -311,6 +311,22 @@ func (e *Endpoint) GetProviderSpecificProperty(key string) (string, bool) { return "", false } +// GetBoolProperty returns a boolean provider-specific property value. +func (e *Endpoint) GetBoolProviderSpecificProperty(key string) (bool, bool) { + prop, ok := e.GetProviderSpecificProperty(key) + if !ok { + return false, false + } + switch prop { + case "true": + return true, true + case "false": + return false, true + default: + return false, true + } +} + // SetProviderSpecificProperty sets the value of a ProviderSpecificProperty. func (e *Endpoint) SetProviderSpecificProperty(key string, value string) { for i, providerSpecific := range e.ProviderSpecific { diff --git a/endpoint/endpoint_test.go b/endpoint/endpoint_test.go index 381828c3b4..96d171806b 100644 --- a/endpoint/endpoint_test.go +++ b/endpoint/endpoint_test.go @@ -1072,3 +1072,97 @@ func TestNewEndpointWithTTLPreservesDotsInTXTRecords(t *testing.T) { require.NotNil(t, cnameEndpoint, "CNAME endpoint should be created") assert.Equal(t, "target.example.com", cnameEndpoint.Targets[0], "CNAME record should have trailing dot trimmed") } + +func TestGetBoolProviderSpecificProperty(t *testing.T) { + tests := []struct { + name string + endpoint Endpoint + key string + expectedValue bool + expectedExists bool + }{ + { + name: "key does not exist", + endpoint: Endpoint{}, + key: "nonexistent", + expectedValue: false, + expectedExists: false, + }, + { + name: "key exists with true value", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "enabled", Value: "true"}, + }, + }, + key: "enabled", + expectedValue: true, + expectedExists: true, + }, + { + name: "key exists with false value", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "disabled", Value: "false"}, + }, + }, + key: "disabled", + expectedValue: false, + expectedExists: true, + }, + { + name: "key exists with invalid boolean value", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "invalid", Value: "maybe"}, + }, + }, + key: "invalid", + expectedValue: false, + expectedExists: true, + }, + { + name: "key exists with empty value", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "empty", Value: ""}, + }, + }, + key: "empty", + expectedValue: false, + expectedExists: true, + }, + { + name: "key exists with numeric value", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "numeric", Value: "1"}, + }, + }, + key: "numeric", + expectedValue: false, + expectedExists: true, + }, + { + name: "multiple properties, find correct one", + endpoint: Endpoint{ + ProviderSpecific: []ProviderSpecificProperty{ + {Name: "first", Value: "invalid"}, + {Name: "second", Value: "true"}, + {Name: "third", Value: "false"}, + }, + }, + key: "second", + expectedValue: true, + expectedExists: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + value, exists := tt.endpoint.GetBoolProviderSpecificProperty(tt.key) + assert.Equal(t, tt.expectedValue, value) + assert.Equal(t, tt.expectedExists, exists) + }) + } +} diff --git a/provider/aws/aws.go b/provider/aws/aws.go index f33e0bb7e7..5d8351ae0d 100644 --- a/provider/aws/aws.go +++ b/provider/aws/aws.go @@ -922,8 +922,8 @@ func (p *AWSProvider) newChange(action route53types.ChangeAction, ep *endpoint.E change.ResourceRecordSet.Type = route53types.RRType(ep.RecordType) if targetHostedZone := isAWSAlias(ep); targetHostedZone != "" { evalTargetHealth := p.evaluateTargetHealth - if prop, ok := ep.GetProviderSpecificProperty(providerSpecificEvaluateTargetHealth); ok { - evalTargetHealth = prop == "true" + if prop, exists := ep.GetBoolProviderSpecificProperty(providerSpecificEvaluateTargetHealth); exists { + evalTargetHealth = prop } change.ResourceRecordSet.AliasTarget = &route53types.AliasTarget{ DNSName: aws.String(ep.Targets[0]), @@ -1346,8 +1346,8 @@ func useAlias(ep *endpoint.Endpoint, preferCNAME bool) bool { // isAWSAlias determines if a given endpoint is supposed to create an AWS Alias record // and (if so) returns the target hosted zone ID func isAWSAlias(ep *endpoint.Endpoint) string { - isAlias, exists := ep.GetProviderSpecificProperty(providerSpecificAlias) - if exists && isAlias == "true" && slices.Contains([]string{endpoint.RecordTypeA, endpoint.RecordTypeAAAA}, ep.RecordType) && len(ep.Targets) > 0 { + isAlias, _ := ep.GetBoolProviderSpecificProperty(providerSpecificAlias) + if isAlias && slices.Contains([]string{endpoint.RecordTypeA, endpoint.RecordTypeAAAA}, ep.RecordType) && len(ep.Targets) > 0 { // alias records can only point to canonical hosted zones (e.g. to ELBs) or other records in the same zone if hostedZoneID, ok := ep.GetProviderSpecificProperty(providerSpecificTargetHostedZone); ok { diff --git a/registry/txt/registry.go b/registry/txt/registry.go index 0c2cd7c885..404f51a986 100644 --- a/registry/txt/registry.go +++ b/registry/txt/registry.go @@ -244,7 +244,7 @@ func (im *TXTRegistry) Records(ctx context.Context) ([]*endpoint.Endpoint, error } // AWS Alias records have "new" format encoded as type "cname" - if isAlias, found := ep.GetProviderSpecificProperty("alias"); found && isAlias == "true" && ep.RecordType == endpoint.RecordTypeA { + if isAlias, found := ep.GetBoolProviderSpecificProperty("alias"); found && isAlias && ep.RecordType == endpoint.RecordTypeA { key.RecordType = endpoint.RecordTypeCNAME } @@ -299,7 +299,7 @@ func (im *TXTRegistry) generateTXTRecordWithFilter(r *endpoint.Endpoint, filter // Always create new format record recordType := r.RecordType // AWS Alias records are encoded as type "cname" - if isAlias, found := r.GetProviderSpecificProperty("alias"); found && isAlias == "true" && recordType == endpoint.RecordTypeA { + if isAlias, found := r.GetBoolProviderSpecificProperty("alias"); found && isAlias && recordType == endpoint.RecordTypeA { recordType = endpoint.RecordTypeCNAME }