Skip to content

Commit 6cd2749

Browse files
authored
fix(utils): santize host when target has host port (#6759)
1 parent bfffb3b commit 6cd2749

File tree

2 files changed

+180
-6
lines changed

2 files changed

+180
-6
lines changed

pkg/protocols/utils/fields.go

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package utils
22

33
import (
4+
"net"
5+
"strings"
6+
47
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
58
iputil "github.com/projectdiscovery/utils/ip"
69
urlutil "github.com/projectdiscovery/utils/url"
@@ -28,17 +31,21 @@ func GetJsonFieldsFromURL(URL string) JsonFields {
2831
URL: parsed.String(),
2932
Path: parsed.Path,
3033
}
34+
35+
host := parsed.Host
36+
host, fields.Port = extractHostPort(host, fields.Port)
37+
3138
if fields.Port == "" {
3239
fields.Port = "80"
3340
if fields.Scheme == "https" {
3441
fields.Port = "443"
3542
}
3643
}
37-
if iputil.IsIP(parsed.Host) {
38-
fields.Ip = parsed.Host
44+
if iputil.IsIP(host) {
45+
fields.Ip = host
3946
}
4047

41-
fields.Host = parsed.Host
48+
fields.Host = host
4249
return fields
4350
}
4451

@@ -56,16 +63,45 @@ func GetJsonFieldsFromMetaInput(ctx *contextargs.MetaInput) JsonFields {
5663
fields.Scheme = parsed.Scheme
5764
fields.URL = parsed.String()
5865
fields.Path = parsed.Path
66+
67+
host := parsed.Host
68+
host, fields.Port = extractHostPort(host, fields.Port)
69+
5970
if fields.Port == "" {
6071
fields.Port = "80"
6172
if fields.Scheme == "https" {
6273
fields.Port = "443"
6374
}
6475
}
65-
if iputil.IsIP(parsed.Host) {
66-
fields.Ip = parsed.Host
76+
if iputil.IsIP(host) {
77+
fields.Ip = host
6778
}
6879

69-
fields.Host = parsed.Host
80+
fields.Host = host
7081
return fields
7182
}
83+
84+
func extractHostPort(host, port string) (string, string) {
85+
if !strings.Contains(host, ":") {
86+
return host, port
87+
}
88+
if strings.HasPrefix(host, "[") {
89+
if idx := strings.Index(host, "]:"); idx != -1 {
90+
if port == "" {
91+
port = host[idx+2:]
92+
}
93+
return host[1:idx], port
94+
}
95+
if strings.HasSuffix(host, "]") {
96+
return host[1 : len(host)-1], port
97+
}
98+
return host, port
99+
}
100+
if h, p, err := net.SplitHostPort(host); err == nil {
101+
if port == "" {
102+
port = p
103+
}
104+
return h, port
105+
}
106+
return host, port
107+
}

pkg/protocols/utils/fields_test.go

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
package utils
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestGetJsonFieldsFromURL_HostPortExtraction(t *testing.T) {
10+
t.Parallel()
11+
12+
tests := []struct {
13+
name string
14+
input string
15+
expectedHost string
16+
expectedPort string
17+
}{
18+
{
19+
name: "URL with scheme and port",
20+
input: "http://example.com:8080/path",
21+
expectedHost: "example.com",
22+
expectedPort: "8080",
23+
},
24+
{
25+
name: "URL with scheme no port",
26+
input: "https://example.com/path",
27+
expectedHost: "example.com",
28+
expectedPort: "443",
29+
},
30+
{
31+
name: "host:port without scheme",
32+
input: "example.com:8080",
33+
expectedHost: "example.com",
34+
expectedPort: "8080",
35+
},
36+
{
37+
name: "host:port with standard HTTPS port",
38+
input: "example.com:443",
39+
expectedHost: "example.com",
40+
expectedPort: "443",
41+
},
42+
{
43+
name: "IPv4 with port",
44+
input: "192.168.1.1:8080",
45+
expectedHost: "192.168.1.1",
46+
expectedPort: "8080",
47+
},
48+
{
49+
name: "IPv6 with port",
50+
input: "[2001:db8::1]:8080",
51+
expectedHost: "2001:db8::1",
52+
expectedPort: "8080",
53+
},
54+
{
55+
name: "localhost with port",
56+
input: "localhost:3000",
57+
expectedHost: "localhost",
58+
expectedPort: "3000",
59+
},
60+
}
61+
62+
for _, tt := range tests {
63+
t.Run(tt.name, func(t *testing.T) {
64+
t.Parallel()
65+
66+
fields := GetJsonFieldsFromURL(tt.input)
67+
68+
assert.Equal(t, tt.expectedHost, fields.Host)
69+
assert.Equal(t, tt.expectedPort, fields.Port)
70+
})
71+
}
72+
}
73+
74+
func TestExtractHostPort(t *testing.T) {
75+
t.Parallel()
76+
77+
tests := []struct {
78+
name string
79+
host string
80+
port string
81+
expectedHost string
82+
expectedPort string
83+
}{
84+
{
85+
name: "host without port",
86+
host: "example.com",
87+
port: "",
88+
expectedHost: "example.com",
89+
expectedPort: "",
90+
},
91+
{
92+
name: "host with port",
93+
host: "example.com:8080",
94+
port: "",
95+
expectedHost: "example.com",
96+
expectedPort: "8080",
97+
},
98+
{
99+
name: "port already set",
100+
host: "example.com:8080",
101+
port: "443",
102+
expectedHost: "example.com",
103+
expectedPort: "443",
104+
},
105+
{
106+
name: "IPv6 with port",
107+
host: "[::1]:8080",
108+
port: "",
109+
expectedHost: "::1",
110+
expectedPort: "8080",
111+
},
112+
{
113+
name: "IPv6 without port",
114+
host: "[::1]",
115+
port: "",
116+
expectedHost: "::1",
117+
expectedPort: "",
118+
},
119+
{
120+
name: "IPv4 with port",
121+
host: "192.168.1.1:8080",
122+
port: "",
123+
expectedHost: "192.168.1.1",
124+
expectedPort: "8080",
125+
},
126+
}
127+
128+
for _, tt := range tests {
129+
t.Run(tt.name, func(t *testing.T) {
130+
t.Parallel()
131+
132+
host, port := extractHostPort(tt.host, tt.port)
133+
134+
assert.Equal(t, tt.expectedHost, host)
135+
assert.Equal(t, tt.expectedPort, port)
136+
})
137+
}
138+
}

0 commit comments

Comments
 (0)