Skip to content

Commit b922c1c

Browse files
committed
GODRIVER-3284 Allow valid SRV hostnames with less than 3 parts.
1 parent b823409 commit b922c1c

File tree

2 files changed

+123
-4
lines changed

2 files changed

+123
-4
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// Copyright (C) MongoDB, Inc. 2024-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package connstring
8+
9+
import (
10+
"net"
11+
"testing"
12+
13+
"go.mongodb.org/mongo-driver/v2/internal/assert"
14+
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/dns"
15+
)
16+
17+
func TestInitialDNSSeedlistDiscoveryProse(t *testing.T) {
18+
newTestParser := func(record string) *parser {
19+
return &parser{&dns.Resolver{
20+
LookupSRV: func(_, _, _ string) (string, []*net.SRV, error) {
21+
return "", []*net.SRV{
22+
{
23+
Target: record,
24+
Port: 27017,
25+
},
26+
}, nil
27+
},
28+
LookupTXT: func(string) ([]string, error) {
29+
return nil, nil
30+
},
31+
}}
32+
}
33+
34+
t.Run("1. Allow SRVs with fewer than 3 . separated parts", func(t *testing.T) {
35+
t.Parallel()
36+
37+
cases := []struct {
38+
record string
39+
uri string
40+
}{
41+
{"test_1.localhost", "mongodb+srv://localhost"},
42+
{"test_1.mongo.local", "mongodb+srv://mongo.local"},
43+
}
44+
for _, c := range cases {
45+
c := c
46+
t.Run(c.uri, func(t *testing.T) {
47+
t.Parallel()
48+
49+
_, err := newTestParser(c.record).parse(c.uri)
50+
assert.NoError(t, err, "expected no URI parsing error, got %v", err)
51+
})
52+
}
53+
})
54+
t.Run("2. Throw when return address does not end with SRV domain", func(t *testing.T) {
55+
t.Parallel()
56+
57+
cases := []struct {
58+
record string
59+
uri string
60+
}{
61+
{"localhost.mongodb", "mongodb+srv://localhost"},
62+
{"test_1.evil.local", "mongodb+srv://mongo.local"},
63+
{"blogs.evil.com", "mongodb+srv://blogs.mongodb.com"},
64+
}
65+
for _, c := range cases {
66+
c := c
67+
t.Run(c.uri, func(t *testing.T) {
68+
t.Parallel()
69+
70+
_, err := newTestParser(c.record).parse(c.uri)
71+
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
72+
})
73+
}
74+
})
75+
t.Run("3. Throw when return address is identical to SRV hostname", func(t *testing.T) {
76+
t.Parallel()
77+
78+
cases := []struct {
79+
record string
80+
uri string
81+
}{
82+
{"localhost", "mongodb+srv://localhost"},
83+
{"mongo.local", "mongodb+srv://mongo.local"},
84+
}
85+
for _, c := range cases {
86+
c := c
87+
t.Run(c.uri, func(t *testing.T) {
88+
t.Parallel()
89+
90+
_, err := newTestParser(c.record).parse(c.uri)
91+
assert.ErrorContains(t, err, "DNS name must contain at least")
92+
})
93+
}
94+
})
95+
t.Run("4. Throw when return address does not contain . separating shared part of domain", func(t *testing.T) {
96+
t.Parallel()
97+
98+
cases := []struct {
99+
record string
100+
uri string
101+
}{
102+
{"test_1.cluster_1localhost", "mongodb+srv://localhost"},
103+
{"test_1.my_hostmongo.local", "mongodb+srv://mongo.local"},
104+
{"cluster.testmongodb.com", "mongodb+srv://blogs.mongodb.com"},
105+
}
106+
for _, c := range cases {
107+
c := c
108+
t.Run(c.uri, func(t *testing.T) {
109+
t.Parallel()
110+
111+
_, err := newTestParser(c.record).parse(c.uri)
112+
assert.ErrorContains(t, err, "Domain suffix from SRV record not matched input domain")
113+
})
114+
}
115+
})
116+
}

x/mongo/driver/dns/dns.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,18 @@ func (r *Resolver) fetchSeedlistFromSRV(host string, srvName string, stopOnErr b
113113
func validateSRVResult(recordFromSRV, inputHostName string) error {
114114
separatedInputDomain := strings.Split(strings.ToLower(inputHostName), ".")
115115
separatedRecord := strings.Split(strings.ToLower(recordFromSRV), ".")
116-
if len(separatedRecord) < 2 {
117-
return errors.New("DNS name must contain at least 2 labels")
116+
if l := len(separatedInputDomain); l < 3 && len(separatedRecord) <= l {
117+
return fmt.Errorf("DNS name must contain at least %d labels", l+1)
118118
}
119119
if len(separatedRecord) < len(separatedInputDomain) {
120120
return errors.New("Domain suffix from SRV record not matched input domain")
121121
}
122122

123-
inputDomainSuffix := separatedInputDomain[1:]
124-
domainSuffixOffset := len(separatedRecord) - (len(separatedInputDomain) - 1)
123+
inputDomainSuffix := separatedInputDomain
124+
if len(inputDomainSuffix) > 2 {
125+
inputDomainSuffix = inputDomainSuffix[1:]
126+
}
127+
domainSuffixOffset := len(separatedRecord) - len(inputDomainSuffix)
125128

126129
recordDomainSuffix := separatedRecord[domainSuffixOffset:]
127130
for ix, label := range inputDomainSuffix {

0 commit comments

Comments
 (0)