Skip to content

Commit 466cdf1

Browse files
authored
Merge pull request #3138 from dolthub/elianddb/9624-fix-wildcard-user-auth
dolthub/dolt#9624 - Fix wildcard user authentication for IP patterns
2 parents 028d9ca + cd7c5f3 commit 466cdf1

File tree

2 files changed

+145
-1
lines changed

2 files changed

+145
-1
lines changed

sql/mysql_db/mysql_db.go

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"encoding/json"
2121
"fmt"
2222
"net"
23+
"regexp"
2324
"sort"
2425
"strings"
2526
"sync"
@@ -591,6 +592,22 @@ func (db *MySQLDb) AddLockedSuperUser(ed *Editor, username string, host string,
591592
}
592593
}
593594

595+
// matchesHostPattern checks if a host matches a host pattern with wildcards.
596+
func matchesHostPattern(host, pattern string) bool {
597+
// No wildcard, not a pattern
598+
if !strings.Contains(pattern, "%") {
599+
return false
600+
}
601+
602+
// Escape regex metacharacters, then replace % with .*
603+
regexPattern := regexp.QuoteMeta(pattern)
604+
regexPattern = strings.ReplaceAll(regexPattern, "%", ".*")
605+
regexPattern = "^" + regexPattern + "$"
606+
607+
matched, err := regexp.MatchString(regexPattern, host)
608+
return err == nil && matched
609+
}
610+
594611
// GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and
595612
// roles, roleSearch changes whether the search matches against user or role rules.
596613
func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSearch bool) *User {
@@ -607,6 +624,9 @@ func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSe
607624
//TODO: Allow for CIDR notation in hostnames
608625
//TODO: Which user do we choose when multiple host names match (e.g. host name with most characters matched, etc.)
609626

627+
// Store the original host for pattern matching against IP patterns
628+
originalHost := host
629+
610630
if "127.0.0.1" == host || "::1" == host {
611631
host = "localhost"
612632
}
@@ -626,7 +646,9 @@ func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSe
626646
if host == user.Host ||
627647
(host == "localhost" && user.Host == "::1") ||
628648
(host == "localhost" && user.Host == "127.0.0.1") ||
629-
(user.Host == "%" && (!roleSearch || host == "")) {
649+
(user.Host == "%" && (!roleSearch || host == "")) ||
650+
matchesHostPattern(host, user.Host) ||
651+
(originalHost != host && matchesHostPattern(originalHost, user.Host)) {
630652
return user
631653
}
632654
}

sql/mysql_db/mysql_db_test.go

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,125 @@ func TestMySQLDbOverwriteUsersAndGrantsData(t *testing.T) {
8686

8787
rd.Close()
8888
}
89+
90+
func TestMatchesHostPattern(t *testing.T) {
91+
tests := []struct {
92+
name string
93+
host string
94+
pattern string
95+
expected bool
96+
}{
97+
// Basic wildcard patterns
98+
{"IP wildcard - exact match", "127.0.0.1", "127.0.0.%", true},
99+
{"IP wildcard - different last octet", "127.0.0.255", "127.0.0.%", true},
100+
{"IP wildcard - no match", "192.168.1.1", "127.0.0.%", false},
101+
{"IP wildcard - partial match", "127.0.1.1", "127.0.0.%", false},
102+
103+
// Multiple wildcards
104+
{"Multiple wildcards", "192.168.1.100", "192.168.%.%", true},
105+
{"Multiple wildcards - no match", "10.0.1.100", "192.168.%.%", false},
106+
107+
// Single wildcard at different positions
108+
{"Wildcard first octet", "10.0.0.1", "%.0.0.1", true},
109+
{"Wildcard middle octet", "192.168.50.1", "192.%.50.1", true},
110+
{"Wildcard last octet", "192.168.1.255", "192.168.1.%", true},
111+
112+
// Non-IP patterns
113+
{"Hostname wildcard", "server1.example.com", "server%.example.com", true},
114+
{"Hostname wildcard - no match", "db1.example.com", "server%.example.com", false},
115+
{"Domain wildcard", "host.subdomain.example.com", "%.example.com", true},
116+
117+
// Edge cases
118+
{"Empty pattern", "127.0.0.1", "", false},
119+
{"Pattern without wildcard", "127.0.0.1", "127.0.0.1", false}, // Should return false as it's not a wildcard pattern
120+
{"Just wildcard", "anything", "%", true},
121+
{"Multiple wildcards together", "test", "%%", true},
122+
123+
// Special characters in patterns (should be escaped)
124+
{"Pattern with dots", "test.host", "test.%", true},
125+
{"Pattern with regex chars", "test[1]", "test[%]", true},
126+
}
127+
128+
for _, tt := range tests {
129+
t.Run(tt.name, func(t *testing.T) {
130+
result := matchesHostPattern(tt.host, tt.pattern)
131+
require.Equal(t, tt.expected, result, "matchesHostPattern(%q, %q) = %v, want %v", tt.host, tt.pattern, result, tt.expected)
132+
})
133+
}
134+
}
135+
136+
func TestGetUserWithWildcardAuthentication(t *testing.T) {
137+
ctx := sql.NewEmptyContext()
138+
db := CreateEmptyMySQLDb()
139+
p := &capturingPersistence{}
140+
db.SetPersister(p)
141+
142+
// Add test users with various host patterns
143+
ed := db.Editor()
144+
db.AddSuperUser(ed, "testuser", "127.0.0.1", "password")
145+
db.AddSuperUser(ed, "localhost_user", "localhost", "password")
146+
db.AddSuperUser(ed, "wildcard_user", "127.0.0.%", "password")
147+
db.AddSuperUser(ed, "subnet_user", "192.168.1.%", "password")
148+
db.AddSuperUser(ed, "hostname_user", "%.example.com", "password")
149+
db.AddSuperUser(ed, "any_user", "%", "password")
150+
db.Persist(ctx, ed)
151+
ed.Close()
152+
153+
rd := db.Reader()
154+
defer rd.Close()
155+
156+
tests := []struct {
157+
name string
158+
username string
159+
host string
160+
expectedUser string
161+
shouldFind bool
162+
}{
163+
// Specific IP tests
164+
{"Specific IP - exact match", "testuser", "127.0.0.1", "testuser", true},
165+
{"Localhost user - normalized", "localhost_user", "127.0.0.1", "localhost_user", true},
166+
{"Localhost user - ::1", "localhost_user", "::1", "localhost_user", true},
167+
{"Non-existent user", "nonexistent", "127.0.0.1", "", false},
168+
169+
// IP wildcard tests
170+
{"Wildcard IP - 127.0.0.% matches 127.0.0.1", "wildcard_user", "127.0.0.1", "wildcard_user", true},
171+
{"Wildcard IP - 127.0.0.% matches 127.0.0.100", "wildcard_user", "127.0.0.100", "wildcard_user", true},
172+
{"Wildcard IP - 127.0.0.% matches 127.0.0.255", "wildcard_user", "127.0.0.255", "wildcard_user", true},
173+
{"Wildcard IP - 127.0.0.% does not match 127.0.1.1", "wildcard_user", "127.0.1.1", "", false},
174+
{"Wildcard IP - 127.0.0.% does not match 192.168.1.1", "wildcard_user", "192.168.1.1", "", false},
175+
176+
// Subnet wildcard tests
177+
{"Subnet wildcard - 192.168.1.% matches 192.168.1.1", "subnet_user", "192.168.1.1", "subnet_user", true},
178+
{"Subnet wildcard - 192.168.1.% matches 192.168.1.100", "subnet_user", "192.168.1.100", "subnet_user", true},
179+
{"Subnet wildcard - 192.168.1.% does not match 192.168.2.1", "subnet_user", "192.168.2.1", "", false},
180+
{"Subnet wildcard - 192.168.1.% does not match 10.0.0.1", "subnet_user", "10.0.0.1", "", false},
181+
182+
// Hostname wildcard tests
183+
{"Hostname wildcard - %.example.com matches host.example.com", "hostname_user", "host.example.com", "hostname_user", true},
184+
{"Hostname wildcard - %.example.com matches www.example.com", "hostname_user", "www.example.com", "hostname_user", true},
185+
{"Hostname wildcard - %.example.com does not match example.com", "hostname_user", "example.com", "", false},
186+
{"Hostname wildcard - %.example.com does not match host.other.com", "hostname_user", "host.other.com", "", false},
187+
188+
// Global wildcard tests
189+
{"Global wildcard - % matches any IP", "any_user", "10.0.0.1", "any_user", true},
190+
{"Global wildcard - % matches any hostname", "any_user", "any.hostname.com", "any_user", true},
191+
{"Global wildcard - % matches localhost", "any_user", "localhost", "any_user", true},
192+
193+
// Issue #9624 scenario
194+
{"Customer scenario - matches connecting IP in range", "wildcard_user", "127.0.0.50", "wildcard_user", true},
195+
}
196+
197+
for _, tt := range tests {
198+
t.Run(tt.name, func(t *testing.T) {
199+
user := db.GetUser(rd, tt.username, tt.host, false)
200+
201+
if !tt.shouldFind {
202+
require.Nil(t, user, "Expected no user to be found for %s@%s", tt.username, tt.host)
203+
return
204+
}
205+
206+
require.NotNil(t, user, "Expected user to be found for %s@%s", tt.username, tt.host)
207+
require.Equal(t, tt.expectedUser, user.User, "Expected username %s, got %s", tt.expectedUser, user.User)
208+
})
209+
}
210+
}

0 commit comments

Comments
 (0)