Skip to content

Commit a2c4ff6

Browse files
committed
Add checking to avoid XSS exploits or malformed URLs
1 parent ba76ad0 commit a2c4ff6

File tree

2 files changed

+151
-56
lines changed

2 files changed

+151
-56
lines changed

cmd/short-it/main.go

Lines changed: 121 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,11 @@ import (
88
"fmt"
99
"io"
1010
"log"
11+
"net"
1112
"net/http"
13+
"net/url"
1214
"os"
15+
"regexp"
1316
"strconv"
1417
"strings"
1518
"time"
@@ -94,6 +97,58 @@ func putURL(key, url string) error {
9497
})
9598
}
9699

100+
var domainRegex = regexp.MustCompile(`^[a-zA-Z0-9\.-]+$`)
101+
102+
func isValidStrictURL(s string) bool {
103+
// Avoid large strings
104+
if len(s) > 2048 {
105+
return false
106+
}
107+
// Disallow raw characters, these should already be encoded
108+
if strings.ContainsAny(s, "<>\"'`") {
109+
return false
110+
}
111+
112+
u, err := url.ParseRequestURI(s)
113+
if err != nil || u.Host == "" {
114+
return false
115+
}
116+
117+
// Only allow http and https schemes
118+
if u.Scheme != "http" && u.Scheme != "https" {
119+
return false
120+
}
121+
122+
// Reject URLs with user info (username:password@)
123+
if u.User != nil {
124+
return false
125+
}
126+
127+
hostname := u.Hostname()
128+
if hostname == "localhost" {
129+
return true
130+
}
131+
132+
if ip := net.ParseIP(hostname); ip != nil {
133+
return true
134+
}
135+
136+
// Make sure hostname is a valid domain name
137+
if !domainRegex.MatchString(hostname) {
138+
return false
139+
}
140+
if !strings.Contains(hostname, ".") {
141+
return false
142+
}
143+
parts := strings.Split(hostname, ".")
144+
tld := parts[len(parts)-1]
145+
if len(tld) < 2 {
146+
return false
147+
}
148+
149+
return true
150+
}
151+
97152
func deleteURL(key string) error {
98153
return db.Update(func(tx *bbolt.Tx) error {
99154
b := tx.Bucket([]byte(bucketName))
@@ -169,6 +224,11 @@ func handleCreateShortURL(w http.ResponseWriter, r *http.Request) {
169224
return
170225
}
171226

227+
if !isValidStrictURL(url) {
228+
http.Error(w, "Invalid URL", http.StatusBadRequest)
229+
return
230+
}
231+
172232
key := generateRandomKey()
173233

174234
if err := putURL(key, url); err != nil {
@@ -262,6 +322,11 @@ func handlePutCustomURL(w http.ResponseWriter, r *http.Request, path string) {
262322
return
263323
}
264324

325+
if !isValidStrictURL(url) {
326+
http.Error(w, "Invalid URL", http.StatusBadRequest)
327+
return
328+
}
329+
265330
if err := putURL(path, url); err != nil {
266331
http.Error(w, "Failed to store URL", http.StatusInternalServerError)
267332
return
@@ -280,62 +345,62 @@ func handleDeleteURL(w http.ResponseWriter, r *http.Request, path string) {
280345
}
281346

282347
func handlePageView(r *http.Request) {
283-
if rybbitSiteID == "" || rybbitSiteKey == "" || rybbitSiteURL == "" {
284-
return
285-
}
286-
287-
ip := r.Header.Get("X-Forwarded-For")
288-
if ip == "" {
289-
ip = strings.Split(r.RemoteAddr, ":")[0]
290-
} else {
291-
ip = strings.TrimSpace(strings.Split(ip, ",")[0])
292-
}
293-
294-
hostname := r.Host
295-
if hostname == "" {
296-
if h := r.URL.Hostname(); h != "" {
297-
hostname = h
298-
}
299-
}
300-
301-
data := map[string]string{
302-
"site_id": rybbitSiteID,
303-
"type": "pageview",
304-
"pathname": r.URL.Path,
305-
"hostname": hostname,
306-
"referrer": r.Referer(),
307-
"language": r.Header.Get("Accept-Language"),
308-
"user_agent": r.UserAgent(),
309-
"ip_address": ip,
310-
}
311-
312-
jsonData, err := json.Marshal(data)
313-
if err != nil {
314-
return
315-
}
316-
317-
go func(body []byte) {
318-
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
319-
defer cancel()
320-
321-
req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(rybbitSiteURL, "/")+"/api/track", bytes.NewReader(body))
322-
if err != nil {
323-
return
324-
}
325-
req.Header.Set("Content-Type", "application/json")
326-
req.Header.Set("Authorization", "Bearer "+rybbitSiteKey)
327-
328-
resp, err := rybbitClient.Do(req)
329-
if err != nil {
330-
return
331-
}
332-
defer resp.Body.Close()
333-
334-
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
335-
b, _ := io.ReadAll(resp.Body)
336-
log.Printf("Rybbit tracking error: %s", string(b))
337-
}
338-
}(jsonData)
348+
if rybbitSiteID == "" || rybbitSiteKey == "" || rybbitSiteURL == "" {
349+
return
350+
}
351+
352+
ip := r.Header.Get("X-Forwarded-For")
353+
if ip == "" {
354+
ip = strings.Split(r.RemoteAddr, ":")[0]
355+
} else {
356+
ip = strings.TrimSpace(strings.Split(ip, ",")[0])
357+
}
358+
359+
hostname := r.Host
360+
if hostname == "" {
361+
if h := r.URL.Hostname(); h != "" {
362+
hostname = h
363+
}
364+
}
365+
366+
data := map[string]string{
367+
"site_id": rybbitSiteID,
368+
"type": "pageview",
369+
"pathname": r.URL.Path,
370+
"hostname": hostname,
371+
"referrer": r.Referer(),
372+
"language": r.Header.Get("Accept-Language"),
373+
"user_agent": r.UserAgent(),
374+
"ip_address": ip,
375+
}
376+
377+
jsonData, err := json.Marshal(data)
378+
if err != nil {
379+
return
380+
}
381+
382+
go func(body []byte) {
383+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
384+
defer cancel()
385+
386+
req, err := http.NewRequestWithContext(ctx, "POST", strings.TrimRight(rybbitSiteURL, "/")+"/api/track", bytes.NewReader(body))
387+
if err != nil {
388+
return
389+
}
390+
req.Header.Set("Content-Type", "application/json")
391+
req.Header.Set("Authorization", "Bearer "+rybbitSiteKey)
392+
393+
resp, err := rybbitClient.Do(req)
394+
if err != nil {
395+
return
396+
}
397+
defer resp.Body.Close()
398+
399+
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
400+
b, _ := io.ReadAll(resp.Body)
401+
log.Printf("Rybbit tracking error: %s", string(b))
402+
}
403+
}(jsonData)
339404
}
340405

341406
func main() {

cmd/short-it/main_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,36 @@ func TestListURLs(t *testing.T) {
315315
}
316316
}
317317

318+
func TestIsValidStrictURL(t *testing.T) {
319+
tests := []struct {
320+
url string
321+
want bool
322+
}{
323+
// Valid cases
324+
{"http://google.com", true},
325+
{"https://sub.domain.co.uk", true},
326+
{"http://localhost", true},
327+
{"http://127.0.0.1", true},
328+
{"http://[::1]", true},
329+
330+
// Invalid / Malicious cases
331+
{"javascript:alert(1)", false},
332+
{"http://foo.com/?q=<script>", false},
333+
{"http://user:pass@evil.com", false},
334+
{"http://internal", false},
335+
{"http://google.c", false},
336+
{"ftp://google.com", false},
337+
{"http://exa mple.com", false},
338+
{"http://ex$ample.com", false},
339+
}
340+
341+
for _, tt := range tests {
342+
if got := isValidStrictURL(tt.url); got != tt.want {
343+
t.Errorf("isValidStrictURL(%q) = %v, want %v", tt.url, got, tt.want)
344+
}
345+
}
346+
}
347+
318348
func BenchmarkGenerateRandomKey(b *testing.B) {
319349
testDB := setupTestDB(&testing.T{})
320350
defer teardownTestDB(testDB, &testing.T{})

0 commit comments

Comments
 (0)