Skip to content

Commit 17be025

Browse files
fleandreiCopilot
andauthored
move http domain utils from middelware to pkg folder (#1437)
* move http domain utils from middelware to pkg folder * Update pkg/http/domain.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update pkg/http/domain.go Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 00b5e41 commit 17be025

File tree

4 files changed

+173
-163
lines changed

4 files changed

+173
-163
lines changed

api/domain_restriction.go

Lines changed: 4 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package api
22

33
import (
44
"net/http"
5-
"net/url"
65
"strings"
76

7+
stationHttp "github.com/massalabs/station/pkg/http"
88
"github.com/massalabs/station/pkg/logger"
99
)
1010

@@ -18,7 +18,7 @@ func DomainRestrictionMiddleware(handler http.Handler) http.Handler {
1818
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1919
if r.Method != "GET" && IsRestrictedPath(r.URL.Path) {
2020
if !isRequestFromAllowedDomain(r) {
21-
logger.Warnf("Blocked operation from unauthorized domain: %s", getRequestOrigin(r))
21+
logger.Warnf("Blocked operation from unauthorized domain: %s", stationHttp.GetRequestOrigin(r))
2222
http.Error(w, "Forbidden: Operations restricted to authorized domains", http.StatusForbidden)
2323
return
2424
}
@@ -39,8 +39,8 @@ func IsRestrictedPath(path string) bool {
3939
}
4040

4141
func isRequestFromAllowedDomain(r *http.Request) bool {
42-
origin := getRequestOrigin(r)
43-
hostname := extractHostname(origin)
42+
origin := stationHttp.GetRequestOrigin(r)
43+
hostname := stationHttp.ExtractHostname(origin)
4444

4545
for _, allowedDomain := range allowedDomains() {
4646
if hostname == allowedDomain {
@@ -50,80 +50,3 @@ func isRequestFromAllowedDomain(r *http.Request) bool {
5050

5151
return false
5252
}
53-
54-
func getRequestOrigin(r *http.Request) string {
55-
if origin := r.Header.Get("Origin"); origin != "" {
56-
return origin
57-
}
58-
59-
// Check Referer header as fallback
60-
if referer := r.Header.Get("Referer"); referer != "" {
61-
return referer
62-
}
63-
64-
// Check Host header for local requests
65-
if host := r.Header.Get("Host"); host != "" {
66-
return host
67-
}
68-
69-
return "unknown"
70-
}
71-
72-
// extractHostname safely extracts the hostname from a URL string
73-
func extractHostname(origin string) string {
74-
if origin == "" {
75-
return ""
76-
}
77-
78-
// Handle cases where the origin might just be a hostname without protocol
79-
if !strings.Contains(origin, "://") {
80-
if parsed, err := url.Parse("http://" + origin); err == nil {
81-
return parsed.Hostname()
82-
}
83-
84-
// Fallback: if parsing fails but origin looks like a simple hostname,
85-
// check if it matches any allowed domain directly
86-
if isSimpleHostname(origin) {
87-
return origin
88-
}
89-
return ""
90-
}
91-
92-
parsed, err := url.Parse(origin)
93-
if err != nil {
94-
return ""
95-
}
96-
97-
return parsed.Hostname()
98-
}
99-
100-
// isSimpleHostname checks if a string looks like a simple hostname
101-
// (contains only valid hostname characters and no suspicious patterns)
102-
func isSimpleHostname(s string) bool {
103-
if s == "" {
104-
return false
105-
}
106-
107-
// Only allow valid hostname characters: letters, digits, hyphens, dots, and colons (for ports)
108-
for _, r := range s {
109-
isLetter := (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
110-
isDigit := r >= '0' && r <= '9'
111-
isAllowedPunct := r == '-' || r == '.' || r == ':'
112-
if !isLetter && !isDigit && !isAllowedPunct {
113-
return false
114-
}
115-
}
116-
117-
// Must not start or end with hyphen or dot
118-
if strings.HasPrefix(s, "-") || strings.HasSuffix(s, "-") ||
119-
strings.HasPrefix(s, ".") || strings.HasSuffix(s, ".") {
120-
return false
121-
}
122-
123-
// Check for consecutive dots
124-
if strings.Contains(s, "..") {
125-
return false
126-
}
127-
128-
return true
129-
}

api/middleware_test.go

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -234,85 +234,3 @@ func TestIsRequestFromAllowedDomain(t *testing.T) {
234234
})
235235
}
236236
}
237-
238-
func TestExtractHostname(t *testing.T) {
239-
tests := []struct {
240-
name string
241-
origin string
242-
expected string
243-
}{
244-
{"HTTPS URL", "https://station.massa", "station.massa"},
245-
{"HTTP URL", "http://localhost:3000", "localhost"},
246-
{"URL with port", "https://127.0.0.1:8080", "127.0.0.1"},
247-
{"Just hostname", "station.massa", "station.massa"},
248-
{"Just localhost", "localhost", "localhost"},
249-
{"Just IP", "127.0.0.1", "127.0.0.1"},
250-
{"Empty string", "", ""},
251-
{"Malformed URL", "not-a-url", "not-a-url"},
252-
{"Invalid control chars", string([]byte{0x7f, 0x80, 0x81}), ""},
253-
{"URL with path", "https://station.massa/path", "station.massa"},
254-
{"URL with query", "https://station.massa?query=value", "station.massa"},
255-
{"Malicious domain", "https://malicious-station.massa.com", "malicious-station.massa.com"},
256-
}
257-
258-
for _, tt := range tests {
259-
t.Run(tt.name, func(t *testing.T) {
260-
result := extractHostname(tt.origin)
261-
if result != tt.expected {
262-
t.Errorf("extractHostname(%s) = %s, expected %s", tt.origin, result, tt.expected)
263-
}
264-
})
265-
}
266-
}
267-
268-
func TestIsSimpleHostname(t *testing.T) {
269-
tests := []struct {
270-
name string
271-
hostname string
272-
expected bool
273-
}{
274-
{"Valid hostname", "station.massa", true},
275-
{"Valid localhost", "localhost", true},
276-
{"Valid IP", "127.0.0.1", true},
277-
{"Valid with port", "localhost:3000", true},
278-
{"Empty string", "", false},
279-
{"With spaces", "station massa", false},
280-
{"With tabs", "station\tmassa", false},
281-
{"With newlines", "station\nmassa", false},
282-
{"With angle brackets", "station<massa", false},
283-
{"With quotes", "station\"massa", false},
284-
{"With backticks", "station`massa", false},
285-
{"With braces", "station{massa", false},
286-
{"With control chars", string([]byte{0x7f, 0x80, 0x81}), false},
287-
{"Starting with hyphen", "-station.massa", false},
288-
{"Ending with hyphen", "station.massa-", false},
289-
{"Starting with dot", ".station.massa", false},
290-
{"Ending with dot", "station.massa.", false},
291-
{"Consecutive dots", "station..massa", false},
292-
{"Valid subdomain", "api.station.massa", true},
293-
{"Valid with numbers", "station1.massa2", true},
294-
// Additional security tests for new restrictive validation
295-
{"With at symbol", "station@massa", false},
296-
{"With hash", "station#massa", false},
297-
{"With dollar", "station$massa", false},
298-
{"With percent", "station%massa", false},
299-
{"With ampersand", "station&massa", false},
300-
{"With asterisk", "station*massa", false},
301-
{"With plus", "station+massa", false},
302-
{"With equals", "station=massa", false},
303-
{"With question mark", "station?massa", false},
304-
{"With underscore", "station_massa", false},
305-
{"With tilde", "station~massa", false},
306-
{"With pipe", "station|massa", false},
307-
{"With backslash", "station\\massa", false},
308-
}
309-
310-
for _, tt := range tests {
311-
t.Run(tt.name, func(t *testing.T) {
312-
result := isSimpleHostname(tt.hostname)
313-
if result != tt.expected {
314-
t.Errorf("isSimpleHostname(%s) = %v, expected %v", tt.hostname, result, tt.expected)
315-
}
316-
})
317-
}
318-
}

pkg/http/domain.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package http
2+
3+
import (
4+
"net/http"
5+
"net/url"
6+
"strings"
7+
)
8+
9+
func GetRequestOrigin(r *http.Request) string {
10+
if origin := r.Header.Get("Origin"); origin != "" {
11+
return origin
12+
}
13+
14+
// Check Referer header as fallback
15+
if referer := r.Header.Get("Referer"); referer != "" {
16+
return referer
17+
}
18+
19+
// Check Host header for local requests
20+
if host := r.Header.Get("Host"); host != "" {
21+
return host
22+
}
23+
24+
return "unknown"
25+
}
26+
27+
// ExtractHostname safely extracts the hostname from a URL string
28+
func ExtractHostname(origin string) string {
29+
if origin == "" {
30+
return ""
31+
}
32+
33+
// Handle cases where the origin might just be a hostname without protocol
34+
if !strings.Contains(origin, "://") {
35+
if parsed, err := url.Parse("http://" + origin); err == nil {
36+
return parsed.Hostname()
37+
}
38+
39+
// Fallback: if parsing fails but origin looks like a simple hostname,
40+
// check if it matches any allowed domain directly
41+
if IsSimpleHostname(origin) {
42+
return origin
43+
}
44+
return ""
45+
}
46+
47+
parsed, err := url.Parse(origin)
48+
if err != nil {
49+
return ""
50+
}
51+
52+
return parsed.Hostname()
53+
}
54+
55+
// IsSimpleHostname checks if a string looks like a simple hostname
56+
// (contains only valid hostname characters and no suspicious patterns)
57+
func IsSimpleHostname(s string) bool {
58+
if s == "" {
59+
return false
60+
}
61+
62+
// Only allow valid hostname characters: letters, digits, hyphens, dots, and colons (for ports)
63+
for _, r := range s {
64+
isLetter := (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z')
65+
isDigit := r >= '0' && r <= '9'
66+
isAllowedPunct := r == '-' || r == '.' || r == ':'
67+
if !isLetter && !isDigit && !isAllowedPunct {
68+
return false
69+
}
70+
}
71+
72+
// Must not start or end with hyphen or dot
73+
if strings.HasPrefix(s, "-") || strings.HasSuffix(s, "-") ||
74+
strings.HasPrefix(s, ".") || strings.HasSuffix(s, ".") {
75+
return false
76+
}
77+
78+
// Check for consecutive dots
79+
if strings.Contains(s, "..") {
80+
return false
81+
}
82+
83+
return true
84+
}

pkg/http/domain_test.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package http
2+
3+
import "testing"
4+
5+
func TestExtractHostname(t *testing.T) {
6+
tests := []struct {
7+
name string
8+
origin string
9+
expected string
10+
}{
11+
{"HTTPS URL", "https://station.massa", "station.massa"},
12+
{"HTTP URL", "http://localhost:3000", "localhost"},
13+
{"URL with port", "https://127.0.0.1:8080", "127.0.0.1"},
14+
{"Just hostname", "station.massa", "station.massa"},
15+
{"Just localhost", "localhost", "localhost"},
16+
{"Just IP", "127.0.0.1", "127.0.0.1"},
17+
{"Empty string", "", ""},
18+
{"Malformed URL", "not-a-url", "not-a-url"},
19+
{"Invalid control chars", string([]byte{0x7f, 0x80, 0x81}), ""},
20+
{"URL with path", "https://station.massa/path", "station.massa"},
21+
{"URL with query", "https://station.massa?query=value", "station.massa"},
22+
{"Malicious domain", "https://malicious-station.massa.com", "malicious-station.massa.com"},
23+
}
24+
25+
for _, tt := range tests {
26+
t.Run(tt.name, func(t *testing.T) {
27+
result := ExtractHostname(tt.origin)
28+
if result != tt.expected {
29+
t.Errorf("extractHostname(%s) = %s, expected %s", tt.origin, result, tt.expected)
30+
}
31+
})
32+
}
33+
}
34+
35+
func TestIsSimpleHostname(t *testing.T) {
36+
tests := []struct {
37+
name string
38+
hostname string
39+
expected bool
40+
}{
41+
{"Valid hostname", "station.massa", true},
42+
{"Valid localhost", "localhost", true},
43+
{"Valid IP", "127.0.0.1", true},
44+
{"Valid with port", "localhost:3000", true},
45+
{"Empty string", "", false},
46+
{"With spaces", "station massa", false},
47+
{"With tabs", "station\tmassa", false},
48+
{"With newlines", "station\nmassa", false},
49+
{"With angle brackets", "station<massa", false},
50+
{"With quotes", "station\"massa", false},
51+
{"With backticks", "station`massa", false},
52+
{"With braces", "station{massa", false},
53+
{"With control chars", string([]byte{0x7f, 0x80, 0x81}), false},
54+
{"Starting with hyphen", "-station.massa", false},
55+
{"Ending with hyphen", "station.massa-", false},
56+
{"Starting with dot", ".station.massa", false},
57+
{"Ending with dot", "station.massa.", false},
58+
{"Consecutive dots", "station..massa", false},
59+
{"Valid subdomain", "api.station.massa", true},
60+
{"Valid with numbers", "station1.massa2", true},
61+
// Additional security tests for new restrictive validation
62+
{"With at symbol", "station@massa", false},
63+
{"With hash", "station#massa", false},
64+
{"With dollar", "station$massa", false},
65+
{"With percent", "station%massa", false},
66+
{"With ampersand", "station&massa", false},
67+
{"With asterisk", "station*massa", false},
68+
{"With plus", "station+massa", false},
69+
{"With equals", "station=massa", false},
70+
{"With question mark", "station?massa", false},
71+
{"With underscore", "station_massa", false},
72+
{"With tilde", "station~massa", false},
73+
{"With pipe", "station|massa", false},
74+
{"With backslash", "station\\massa", false},
75+
}
76+
77+
for _, tt := range tests {
78+
t.Run(tt.name, func(t *testing.T) {
79+
result := IsSimpleHostname(tt.hostname)
80+
if result != tt.expected {
81+
t.Errorf("isSimpleHostname(%s) = %v, expected %v", tt.hostname, result, tt.expected)
82+
}
83+
})
84+
}
85+
}

0 commit comments

Comments
 (0)