Skip to content

Commit 994120a

Browse files
elianddbclaude
andcommitted
Clean up wildcard authentication implementation
- Simplify error handling with single return statement - Avoid redundant matchesHostPattern calls when originalHost == host - Reduce verbose comments for cleaner code - Maintain same functionality with better performance 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
1 parent 73b87b7 commit 994120a

File tree

2 files changed

+20
-24
lines changed

2 files changed

+20
-24
lines changed

sql/mysql_db/mysql_db.go

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -599,18 +599,14 @@ func matchesHostPattern(host, pattern string) bool {
599599
if !strings.Contains(pattern, "%") {
600600
return false
601601
}
602-
603-
// Replace % with .* for regex matching, but first escape other regex metacharacters
604-
// We need to escape everything except % first, then replace % with .*
605-
regexPattern := regexp.QuoteMeta(pattern) // This escapes everything including %
606-
regexPattern = strings.ReplaceAll(regexPattern, "%", ".*") // Replace % with .*
602+
603+
// Escape regex metacharacters, then replace % with .*
604+
regexPattern := regexp.QuoteMeta(pattern)
605+
regexPattern = strings.ReplaceAll(regexPattern, "%", ".*")
607606
regexPattern = "^" + regexPattern + "$"
608-
607+
609608
matched, err := regexp.MatchString(regexPattern, host)
610-
if err != nil {
611-
return false
612-
}
613-
return matched
609+
return err == nil && matched
614610
}
615611

616612
// GetUser returns a user matching the given user and host if it exists. Due to the slight difference between users and
@@ -631,7 +627,7 @@ func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSe
631627

632628
// Store the original host for pattern matching against IP patterns
633629
originalHost := host
634-
630+
635631
if "127.0.0.1" == host || "::1" == host {
636632
host = "localhost"
637633
}
@@ -653,7 +649,7 @@ func (db *MySQLDb) GetUser(fetcher UserFetcher, user string, host string, roleSe
653649
(host == "localhost" && user.Host == "127.0.0.1") ||
654650
(user.Host == "%" && (!roleSearch || host == "")) ||
655651
matchesHostPattern(host, user.Host) ||
656-
matchesHostPattern(originalHost, user.Host) {
652+
(originalHost != host && matchesHostPattern(originalHost, user.Host)) {
657653
return user
658654
}
659655
}

sql/mysql_db/mysql_db_test.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -99,32 +99,32 @@ func TestMatchesHostPattern(t *testing.T) {
9999
{"IP wildcard - different last octet", "127.0.0.255", "127.0.0.%", true},
100100
{"IP wildcard - no match", "192.168.1.1", "127.0.0.%", false},
101101
{"IP wildcard - partial match", "127.0.1.1", "127.0.0.%", false},
102-
102+
103103
// Multiple wildcards
104104
{"Multiple wildcards", "192.168.1.100", "192.168.%.%", true},
105105
{"Multiple wildcards - no match", "10.0.1.100", "192.168.%.%", false},
106-
106+
107107
// Single wildcard at different positions
108108
{"Wildcard first octet", "10.0.0.1", "%.0.0.1", true},
109109
{"Wildcard middle octet", "192.168.50.1", "192.%.50.1", true},
110110
{"Wildcard last octet", "192.168.1.255", "192.168.1.%", true},
111-
111+
112112
// Non-IP patterns
113113
{"Hostname wildcard", "server1.example.com", "server%.example.com", true},
114114
{"Hostname wildcard - no match", "db1.example.com", "server%.example.com", false},
115115
{"Domain wildcard", "host.subdomain.example.com", "%.example.com", true},
116-
116+
117117
// Edge cases
118118
{"Empty pattern", "127.0.0.1", "", false},
119119
{"Pattern without wildcard", "127.0.0.1", "127.0.0.1", false}, // Should return false as it's not a wildcard pattern
120120
{"Just wildcard", "anything", "%", true},
121121
{"Multiple wildcards together", "test", "%%", true},
122-
122+
123123
// Special characters in patterns (should be escaped)
124124
{"Pattern with dots", "test.host", "test.%", true},
125125
{"Pattern with regex chars", "test[1]", "test[%]", true},
126126
}
127-
127+
128128
for _, tt := range tests {
129129
t.Run(tt.name, func(t *testing.T) {
130130
result := matchesHostPattern(tt.host, tt.pattern)
@@ -138,17 +138,17 @@ func TestGetUserWithWildcardAuthentication(t *testing.T) {
138138
db := CreateEmptyMySQLDb()
139139
p := &capturingPersistence{}
140140
db.SetPersister(p)
141-
141+
142142
// Add test users with various host patterns
143143
ed := db.Editor()
144144
db.AddSuperUser(ed, "testuser", "127.0.0.1", "password")
145145
db.AddSuperUser(ed, "localhost_user", "localhost", "password")
146146
db.Persist(ctx, ed)
147147
ed.Close()
148-
148+
149149
rd := db.Reader()
150150
defer rd.Close()
151-
151+
152152
tests := []struct {
153153
name string
154154
username string
@@ -162,16 +162,16 @@ func TestGetUserWithWildcardAuthentication(t *testing.T) {
162162
{"Localhost user - ::1", "localhost_user", "::1", "localhost_user", true},
163163
{"Non-existent user", "nonexistent", "127.0.0.1", "", false},
164164
}
165-
165+
166166
for _, tt := range tests {
167167
t.Run(tt.name, func(t *testing.T) {
168168
user := db.GetUser(rd, tt.username, tt.host, false)
169-
169+
170170
if !tt.shouldFind {
171171
require.Nil(t, user, "Expected no user to be found for %s@%s", tt.username, tt.host)
172172
return
173173
}
174-
174+
175175
require.NotNil(t, user, "Expected user to be found for %s@%s", tt.username, tt.host)
176176
require.Equal(t, tt.expectedUser, user.User, "Expected username %s, got %s", tt.expectedUser, user.User)
177177
})

0 commit comments

Comments
 (0)