Skip to content

Commit 73421dd

Browse files
committed
Improve DSN env vairable expansion
1 parent 6b042da commit 73421dd

File tree

2 files changed

+666
-1
lines changed

2 files changed

+666
-1
lines changed

pkg/database/database.go

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/url"
99
"os"
1010
"regexp"
11+
"strings"
1112

1213
"github.com/conductorone/baton-sql/pkg/database/hdb"
1314
"github.com/conductorone/baton-sql/pkg/database/mysql"
@@ -50,8 +51,153 @@ func updateFromEnv(dsn string) (string, error) {
5051
return result, nil
5152
}
5253

54+
// extractPlaceholders replaces ${...} placeholders with unique numeric sentinels
55+
// and looks up the environment variable values immediately.
56+
// Numeric sentinels (999000, 999001, etc.) are valid in all URL components:
57+
// ports (must be numeric), hostnames, userinfo, paths, and query strings.
58+
// This allows us to parse the URL structure before expanding environment variables.
59+
// Returns: the string with sentinels, a mapping of sentinel->value, and any error.
60+
func extractPlaceholders(s string) (string, map[string]string, error) {
61+
mapping := make(map[string]string)
62+
counter := 0
63+
var err error
64+
65+
result := DSNREnvRegex.ReplaceAllStringFunc(s, func(match string) string {
66+
sentinel := fmt.Sprintf("999%03d", counter)
67+
varName := match[2 : len(match)-1] // Extract VAR from ${VAR}
68+
69+
// Look up the environment variable immediately
70+
value, exists := os.LookupEnv(varName)
71+
if !exists {
72+
err = errors.Join(err, fmt.Errorf("environment variable %s is not set", varName))
73+
return sentinel // Return sentinel anyway to allow URL parsing for better error messages
74+
}
75+
76+
mapping[sentinel] = value
77+
counter++
78+
return sentinel
79+
})
80+
81+
if err != nil {
82+
return "", nil, err
83+
}
84+
85+
return result, mapping, nil
86+
}
87+
88+
// expandWithMapping expands sentinels in a string by replacing them with their
89+
// corresponding values from the mapping.
90+
func expandWithMapping(s string, mapping map[string]string) string {
91+
result := s
92+
for sentinel, value := range mapping {
93+
result = strings.ReplaceAll(result, sentinel, value)
94+
}
95+
return result
96+
}
97+
98+
// expandUserInfo expands environment variable placeholders in the URL's user info component.
99+
// It handles the special case where an entire "user:password" string might be in a single variable.
100+
// The url.UserPassword function automatically handles URL encoding of special characters.
101+
func expandUserInfo(parsedUrl *url.URL, mapping map[string]string) {
102+
if parsedUrl.User == nil {
103+
return
104+
}
105+
106+
username := parsedUrl.User.Username()
107+
password, hasPass := parsedUrl.User.Password()
108+
109+
// Expand sentinels in username
110+
expandedUser := expandWithMapping(username, mapping)
111+
112+
// Expand sentinels in password
113+
var expandedPass string
114+
if hasPass {
115+
expandedPass = expandWithMapping(password, mapping)
116+
}
117+
118+
// Handle the case where the entire userinfo (user:password) is in a single variable.
119+
// For example: ${CREDENTIALS} where CREDENTIALS="admin:p@ss#word"
120+
if strings.Contains(expandedUser, ":") && !hasPass {
121+
parts := strings.SplitN(expandedUser, ":", 2)
122+
expandedUser = parts[0]
123+
expandedPass = parts[1]
124+
hasPass = true
125+
}
126+
127+
// url.UserPassword automatically handles URL encoding of special characters
128+
if hasPass {
129+
parsedUrl.User = url.UserPassword(expandedUser, expandedPass)
130+
} else {
131+
parsedUrl.User = url.User(expandedUser)
132+
}
133+
}
134+
135+
// expandHost expands environment variable placeholders in the URL's host component.
136+
func expandHost(parsedUrl *url.URL, mapping map[string]string) {
137+
if parsedUrl.Host != "" {
138+
parsedUrl.Host = expandWithMapping(parsedUrl.Host, mapping)
139+
}
140+
}
141+
142+
// expandPath expands environment variable placeholders in the URL's path component.
143+
func expandPath(parsedUrl *url.URL, mapping map[string]string) {
144+
if parsedUrl.Path != "" {
145+
parsedUrl.Path = expandWithMapping(parsedUrl.Path, mapping)
146+
}
147+
}
148+
149+
// expandQuery expands environment variable placeholders in the URL's query string.
150+
func expandQuery(parsedUrl *url.URL, mapping map[string]string) {
151+
if parsedUrl.RawQuery != "" {
152+
parsedUrl.RawQuery = expandWithMapping(parsedUrl.RawQuery, mapping)
153+
}
154+
}
155+
156+
// expandFragment expands environment variable placeholders in the URL's fragment.
157+
func expandFragment(parsedUrl *url.URL, mapping map[string]string) {
158+
if parsedUrl.Fragment != "" {
159+
parsedUrl.Fragment = expandWithMapping(parsedUrl.Fragment, mapping)
160+
}
161+
}
162+
163+
// expandDSN expands environment variable placeholders in a DSN using a three-phase approach:
164+
// 1. Replace ${...} with safe sentinel values and lookup env vars
165+
// 2. Parse the URL structure with sentinels
166+
// 3. Expand sentinels component-by-component with appropriate encoding
167+
//
168+
// This approach ensures that special characters in environment variables (like #, @, :)
169+
// don't break URL parsing, since they are expanded after the URL structure is established.
170+
func expandDSN(dsn string) (string, error) {
171+
// Phase 1: Replace ${...} with sentinels and lookup env vars
172+
sentinelDSN, mapping, err := extractPlaceholders(dsn)
173+
if err != nil {
174+
return "", err
175+
}
176+
177+
// If there are no placeholders, return as-is
178+
if len(mapping) == 0 {
179+
return dsn, nil
180+
}
181+
182+
// Phase 2: Parse with sentinels
183+
parsedUrl, err := url.Parse(sentinelDSN)
184+
if err != nil {
185+
return "", fmt.Errorf("invalid DSN structure: %w", err)
186+
}
187+
188+
// Phase 3: Expand by component with appropriate encoding
189+
expandUserInfo(parsedUrl, mapping)
190+
expandHost(parsedUrl, mapping)
191+
expandPath(parsedUrl, mapping)
192+
expandQuery(parsedUrl, mapping)
193+
expandFragment(parsedUrl, mapping)
194+
195+
return parsedUrl.String(), nil
196+
}
197+
53198
func Connect(ctx context.Context, dsn string, user string, password string) (*sql.DB, DbEngine, error) {
54-
populatedDSN, err := updateFromEnv(dsn)
199+
// Use the new expandDSN function which handles special characters correctly
200+
populatedDSN, err := expandDSN(dsn)
55201
if err != nil {
56202
return nil, Unknown, err
57203
}

0 commit comments

Comments
 (0)