Skip to content

Commit 160b811

Browse files
authored
[client] Distinguish between NXDOMAIN and NODATA in the dns forwarder (#4321)
1 parent 5e607cf commit 160b811

File tree

2 files changed

+169
-6
lines changed

2 files changed

+169
-6
lines changed

client/internal/dnsfwd/forwarder.go

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) *dns
165165
defer cancel()
166166
ips, err := f.resolver.LookupNetIP(ctx, network, domain)
167167
if err != nil {
168-
f.handleDNSError(w, query, resp, domain, err)
168+
f.handleDNSError(ctx, w, question, resp, domain, err)
169169
return nil
170170
}
171171

@@ -244,20 +244,57 @@ func (f *DNSForwarder) updateFirewall(matchingEntries []*ForwarderEntry, prefixe
244244
}
245245
}
246246

247+
// setResponseCodeForNotFound determines and sets the appropriate response code when IsNotFound is true
248+
// It distinguishes between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of requested type)
249+
//
250+
// LIMITATION: This function only checks A and AAAA record types to determine domain existence.
251+
// If a domain has only other record types (MX, TXT, CNAME, etc.) but no A/AAAA records,
252+
// it may incorrectly return NXDOMAIN instead of NODATA. This is acceptable since the forwarder
253+
// only handles A/AAAA queries and returns NOTIMP for other types.
254+
func (f *DNSForwarder) setResponseCodeForNotFound(ctx context.Context, resp *dns.Msg, domain string, originalQtype uint16) {
255+
// Try querying for a different record type to see if the domain exists
256+
// If the original query was for AAAA, try A. If it was for A, try AAAA.
257+
// This helps distinguish between NXDOMAIN and NODATA.
258+
var alternativeNetwork string
259+
switch originalQtype {
260+
case dns.TypeAAAA:
261+
alternativeNetwork = "ip4"
262+
case dns.TypeA:
263+
alternativeNetwork = "ip6"
264+
default:
265+
resp.Rcode = dns.RcodeNameError
266+
return
267+
}
268+
269+
if _, err := f.resolver.LookupNetIP(ctx, alternativeNetwork, domain); err != nil {
270+
var dnsErr *net.DNSError
271+
if errors.As(err, &dnsErr) && dnsErr.IsNotFound {
272+
// Alternative query also returned not found - domain truly doesn't exist
273+
resp.Rcode = dns.RcodeNameError
274+
return
275+
}
276+
// Some other error (timeout, server failure, etc.) - can't determine, assume domain exists
277+
resp.Rcode = dns.RcodeSuccess
278+
return
279+
}
280+
281+
// Alternative query succeeded - domain exists but has no records of this type
282+
resp.Rcode = dns.RcodeSuccess
283+
}
284+
247285
// handleDNSError processes DNS lookup errors and sends an appropriate error response
248-
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, query, resp *dns.Msg, domain string, err error) {
286+
func (f *DNSForwarder) handleDNSError(ctx context.Context, w dns.ResponseWriter, question dns.Question, resp *dns.Msg, domain string, err error) {
249287
var dnsErr *net.DNSError
250288

251289
switch {
252290
case errors.As(err, &dnsErr):
253291
resp.Rcode = dns.RcodeServerFailure
254292
if dnsErr.IsNotFound {
255-
// Pass through NXDOMAIN
256-
resp.Rcode = dns.RcodeNameError
293+
f.setResponseCodeForNotFound(ctx, resp, domain, question.Qtype)
257294
}
258295

259296
if dnsErr.Server != "" {
260-
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[query.Question[0].Qtype], domain, dnsErr.Server, err)
297+
log.Warnf("failed to resolve query for type=%s domain=%s server=%s: %v", dns.TypeToString[question.Qtype], domain, dnsErr.Server, err)
261298
} else {
262299
log.Warnf(errResolveFailed, domain, err)
263300
}

client/internal/dnsfwd/forwarder_test.go

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dnsfwd
33
import (
44
"context"
55
"fmt"
6+
"net"
67
"net/netip"
78
"strings"
89
"testing"
@@ -16,8 +17,8 @@ import (
1617
firewall "github.com/netbirdio/netbird/client/firewall/manager"
1718
"github.com/netbirdio/netbird/client/internal/dns/test"
1819
"github.com/netbirdio/netbird/client/internal/peer"
19-
"github.com/netbirdio/netbird/shared/management/domain"
2020
"github.com/netbirdio/netbird/route"
21+
"github.com/netbirdio/netbird/shared/management/domain"
2122
)
2223

2324
func Test_getMatchingEntries(t *testing.T) {
@@ -708,6 +709,131 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
708709
assert.Len(t, matches, 3, "Should match 3 patterns")
709710
}
710711

712+
// TestDNSForwarder_NodataVsNxdomain tests that the forwarder correctly distinguishes
713+
// between NXDOMAIN (domain doesn't exist) and NODATA (domain exists but no records of that type)
714+
func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
715+
mockFirewall := &MockFirewall{}
716+
mockResolver := &MockResolver{}
717+
718+
forwarder := NewDNSForwarder("127.0.0.1:0", 300, mockFirewall, &peer.Status{})
719+
forwarder.resolver = mockResolver
720+
721+
d, err := domain.FromString("example.com")
722+
require.NoError(t, err)
723+
724+
set := firewall.NewDomainSet([]domain.Domain{d})
725+
entries := []*ForwarderEntry{{Domain: d, ResID: "test-res", Set: set}}
726+
forwarder.UpdateDomains(entries)
727+
728+
tests := []struct {
729+
name string
730+
queryType uint16
731+
setupMocks func()
732+
expectedCode int
733+
expectNoAnswer bool // true if we expect NOERROR with empty answer (NODATA case)
734+
description string
735+
}{
736+
{
737+
name: "domain exists but no AAAA records (NODATA)",
738+
queryType: dns.TypeAAAA,
739+
setupMocks: func() {
740+
// First query for AAAA returns not found
741+
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
742+
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
743+
// Check query for A records succeeds (domain exists)
744+
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
745+
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
746+
},
747+
expectedCode: dns.RcodeSuccess,
748+
expectNoAnswer: true,
749+
description: "Should return NOERROR when domain exists but has no records of requested type",
750+
},
751+
{
752+
name: "domain exists but no A records (NODATA)",
753+
queryType: dns.TypeA,
754+
setupMocks: func() {
755+
// First query for A returns not found
756+
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
757+
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
758+
// Check query for AAAA records succeeds (domain exists)
759+
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
760+
Return([]netip.Addr{netip.MustParseAddr("2001:db8::1")}, nil).Once()
761+
},
762+
expectedCode: dns.RcodeSuccess,
763+
expectNoAnswer: true,
764+
description: "Should return NOERROR when domain exists but has no A records",
765+
},
766+
{
767+
name: "domain doesn't exist (NXDOMAIN)",
768+
queryType: dns.TypeA,
769+
setupMocks: func() {
770+
// First query for A returns not found
771+
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
772+
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
773+
// Check query for AAAA also returns not found (domain doesn't exist)
774+
mockResolver.On("LookupNetIP", mock.Anything, "ip6", "example.com.").
775+
Return([]netip.Addr{}, &net.DNSError{IsNotFound: true, Name: "example.com"}).Once()
776+
},
777+
expectedCode: dns.RcodeNameError,
778+
expectNoAnswer: true,
779+
description: "Should return NXDOMAIN when domain doesn't exist at all",
780+
},
781+
{
782+
name: "domain exists with records (normal success)",
783+
queryType: dns.TypeA,
784+
setupMocks: func() {
785+
mockResolver.On("LookupNetIP", mock.Anything, "ip4", "example.com.").
786+
Return([]netip.Addr{netip.MustParseAddr("1.2.3.4")}, nil).Once()
787+
// Expect firewall update for successful resolution
788+
expectedPrefix := netip.PrefixFrom(netip.MustParseAddr("1.2.3.4"), 32)
789+
mockFirewall.On("UpdateSet", set, []netip.Prefix{expectedPrefix}).Return(nil).Once()
790+
},
791+
expectedCode: dns.RcodeSuccess,
792+
expectNoAnswer: false,
793+
description: "Should return NOERROR with answer when records exist",
794+
},
795+
}
796+
797+
for _, tt := range tests {
798+
t.Run(tt.name, func(t *testing.T) {
799+
// Reset mock expectations
800+
mockResolver.ExpectedCalls = nil
801+
mockResolver.Calls = nil
802+
mockFirewall.ExpectedCalls = nil
803+
mockFirewall.Calls = nil
804+
805+
tt.setupMocks()
806+
807+
query := &dns.Msg{}
808+
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
809+
810+
var writtenResp *dns.Msg
811+
mockWriter := &test.MockResponseWriter{
812+
WriteMsgFunc: func(m *dns.Msg) error {
813+
writtenResp = m
814+
return nil
815+
},
816+
}
817+
818+
resp := forwarder.handleDNSQuery(mockWriter, query)
819+
820+
// If a response was returned, it means it should be written (happens in wrapper functions)
821+
if resp != nil && writtenResp == nil {
822+
writtenResp = resp
823+
}
824+
825+
require.NotNil(t, writtenResp, "Expected response to be written")
826+
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
827+
828+
if tt.expectNoAnswer {
829+
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
830+
}
831+
832+
mockResolver.AssertExpectations(t)
833+
})
834+
}
835+
}
836+
711837
func TestDNSForwarder_EmptyQuery(t *testing.T) {
712838
// Test handling of malformed query with no questions
713839
forwarder := NewDNSForwarder("127.0.0.1:0", 300, nil, &peer.Status{})

0 commit comments

Comments
 (0)