@@ -8,8 +8,11 @@ import (
88 "fmt"
99 "io"
1010 "log"
11+ "net"
1112 "net/http"
13+ "net/url"
1214 "os"
15+ "regexp"
1316 "strconv"
1417 "strings"
1518 "time"
@@ -94,6 +97,58 @@ func putURL(key, url string) error {
9497 })
9598}
9699
100+ var domainRegex = regexp .MustCompile (`^[a-zA-Z0-9\.-]+$` )
101+
102+ func isValidStrictURL (s string ) bool {
103+ // Avoid large strings
104+ if len (s ) > 2048 {
105+ return false
106+ }
107+ // Disallow raw characters, these should already be encoded
108+ if strings .ContainsAny (s , "<>\" '`" ) {
109+ return false
110+ }
111+
112+ u , err := url .ParseRequestURI (s )
113+ if err != nil || u .Host == "" {
114+ return false
115+ }
116+
117+ // Only allow http and https schemes
118+ if u .Scheme != "http" && u .Scheme != "https" {
119+ return false
120+ }
121+
122+ // Reject URLs with user info (username:password@)
123+ if u .User != nil {
124+ return false
125+ }
126+
127+ hostname := u .Hostname ()
128+ if hostname == "localhost" {
129+ return true
130+ }
131+
132+ if ip := net .ParseIP (hostname ); ip != nil {
133+ return true
134+ }
135+
136+ // Make sure hostname is a valid domain name
137+ if ! domainRegex .MatchString (hostname ) {
138+ return false
139+ }
140+ if ! strings .Contains (hostname , "." ) {
141+ return false
142+ }
143+ parts := strings .Split (hostname , "." )
144+ tld := parts [len (parts )- 1 ]
145+ if len (tld ) < 2 {
146+ return false
147+ }
148+
149+ return true
150+ }
151+
97152func deleteURL (key string ) error {
98153 return db .Update (func (tx * bbolt.Tx ) error {
99154 b := tx .Bucket ([]byte (bucketName ))
@@ -169,6 +224,11 @@ func handleCreateShortURL(w http.ResponseWriter, r *http.Request) {
169224 return
170225 }
171226
227+ if ! isValidStrictURL (url ) {
228+ http .Error (w , "Invalid URL" , http .StatusBadRequest )
229+ return
230+ }
231+
172232 key := generateRandomKey ()
173233
174234 if err := putURL (key , url ); err != nil {
@@ -262,6 +322,11 @@ func handlePutCustomURL(w http.ResponseWriter, r *http.Request, path string) {
262322 return
263323 }
264324
325+ if ! isValidStrictURL (url ) {
326+ http .Error (w , "Invalid URL" , http .StatusBadRequest )
327+ return
328+ }
329+
265330 if err := putURL (path , url ); err != nil {
266331 http .Error (w , "Failed to store URL" , http .StatusInternalServerError )
267332 return
@@ -280,62 +345,62 @@ func handleDeleteURL(w http.ResponseWriter, r *http.Request, path string) {
280345}
281346
282347func handlePageView (r * http.Request ) {
283- if rybbitSiteID == "" || rybbitSiteKey == "" || rybbitSiteURL == "" {
284- return
285- }
286-
287- ip := r .Header .Get ("X-Forwarded-For" )
288- if ip == "" {
289- ip = strings .Split (r .RemoteAddr , ":" )[0 ]
290- } else {
291- ip = strings .TrimSpace (strings .Split (ip , "," )[0 ])
292- }
293-
294- hostname := r .Host
295- if hostname == "" {
296- if h := r .URL .Hostname (); h != "" {
297- hostname = h
298- }
299- }
300-
301- data := map [string ]string {
302- "site_id" : rybbitSiteID ,
303- "type" : "pageview" ,
304- "pathname" : r .URL .Path ,
305- "hostname" : hostname ,
306- "referrer" : r .Referer (),
307- "language" : r .Header .Get ("Accept-Language" ),
308- "user_agent" : r .UserAgent (),
309- "ip_address" : ip ,
310- }
311-
312- jsonData , err := json .Marshal (data )
313- if err != nil {
314- return
315- }
316-
317- go func (body []byte ) {
318- ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
319- defer cancel ()
320-
321- req , err := http .NewRequestWithContext (ctx , "POST" , strings .TrimRight (rybbitSiteURL , "/" )+ "/api/track" , bytes .NewReader (body ))
322- if err != nil {
323- return
324- }
325- req .Header .Set ("Content-Type" , "application/json" )
326- req .Header .Set ("Authorization" , "Bearer " + rybbitSiteKey )
327-
328- resp , err := rybbitClient .Do (req )
329- if err != nil {
330- return
331- }
332- defer resp .Body .Close ()
333-
334- if resp .StatusCode < 200 || resp .StatusCode >= 300 {
335- b , _ := io .ReadAll (resp .Body )
336- log .Printf ("Rybbit tracking error: %s" , string (b ))
337- }
338- }(jsonData )
348+ if rybbitSiteID == "" || rybbitSiteKey == "" || rybbitSiteURL == "" {
349+ return
350+ }
351+
352+ ip := r .Header .Get ("X-Forwarded-For" )
353+ if ip == "" {
354+ ip = strings .Split (r .RemoteAddr , ":" )[0 ]
355+ } else {
356+ ip = strings .TrimSpace (strings .Split (ip , "," )[0 ])
357+ }
358+
359+ hostname := r .Host
360+ if hostname == "" {
361+ if h := r .URL .Hostname (); h != "" {
362+ hostname = h
363+ }
364+ }
365+
366+ data := map [string ]string {
367+ "site_id" : rybbitSiteID ,
368+ "type" : "pageview" ,
369+ "pathname" : r .URL .Path ,
370+ "hostname" : hostname ,
371+ "referrer" : r .Referer (),
372+ "language" : r .Header .Get ("Accept-Language" ),
373+ "user_agent" : r .UserAgent (),
374+ "ip_address" : ip ,
375+ }
376+
377+ jsonData , err := json .Marshal (data )
378+ if err != nil {
379+ return
380+ }
381+
382+ go func (body []byte ) {
383+ ctx , cancel := context .WithTimeout (context .Background (), 2 * time .Second )
384+ defer cancel ()
385+
386+ req , err := http .NewRequestWithContext (ctx , "POST" , strings .TrimRight (rybbitSiteURL , "/" )+ "/api/track" , bytes .NewReader (body ))
387+ if err != nil {
388+ return
389+ }
390+ req .Header .Set ("Content-Type" , "application/json" )
391+ req .Header .Set ("Authorization" , "Bearer " + rybbitSiteKey )
392+
393+ resp , err := rybbitClient .Do (req )
394+ if err != nil {
395+ return
396+ }
397+ defer resp .Body .Close ()
398+
399+ if resp .StatusCode < 200 || resp .StatusCode >= 300 {
400+ b , _ := io .ReadAll (resp .Body )
401+ log .Printf ("Rybbit tracking error: %s" , string (b ))
402+ }
403+ }(jsonData )
339404}
340405
341406func main () {
0 commit comments