Skip to content

Commit 6194273

Browse files
authored
use net.JoinHostPort instead of fmt.Sprintf (fatedier#2791)
1 parent b2311e5 commit 6194273

File tree

17 files changed

+61
-49
lines changed

17 files changed

+61
-49
lines changed

client/proxy/proxy.go

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -347,22 +347,18 @@ func (pxy *XTCPProxy) InWorkConn(conn net.Conn, m *msg.StartWorkConn) {
347347
xl.Trace("get natHoleRespMsg, sid [%s], client address [%s] visitor address [%s]", natHoleRespMsg.Sid, natHoleRespMsg.ClientAddr, natHoleRespMsg.VisitorAddr)
348348

349349
// Send detect message
350-
array := strings.Split(natHoleRespMsg.VisitorAddr, ":")
351-
if len(array) <= 1 {
352-
xl.Error("get NatHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
350+
host, portStr, err := net.SplitHostPort(natHoleRespMsg.VisitorAddr)
351+
if err != nil {
352+
xl.Error("get NatHoleResp visitor address [%s] error: %v", natHoleRespMsg.VisitorAddr, err)
353353
}
354354
laddr, _ := net.ResolveUDPAddr("udp", clientConn.LocalAddr().String())
355-
/*
356-
for i := 1000; i < 65000; i++ {
357-
pxy.sendDetectMsg(array[0], int64(i), laddr, "a")
358-
}
359-
*/
360-
port, err := strconv.ParseInt(array[1], 10, 64)
355+
356+
port, err := strconv.ParseInt(portStr, 10, 64)
361357
if err != nil {
362358
xl.Error("get natHoleResp visitor address error: %v", natHoleRespMsg.VisitorAddr)
363359
return
364360
}
365-
pxy.sendDetectMsg(array[0], int(port), laddr, []byte(natHoleRespMsg.Sid))
361+
pxy.sendDetectMsg(host, int(port), laddr, []byte(natHoleRespMsg.Sid))
366362
xl.Trace("send all detect msg done")
367363

368364
msg.WriteMsg(conn, &msg.NatHoleClientDetectOK{})

client/visitor.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"fmt"
2121
"io"
2222
"net"
23+
"strconv"
2324
"sync"
2425
"time"
2526

@@ -85,7 +86,7 @@ type STCPVisitor struct {
8586
}
8687

8788
func (sv *STCPVisitor) Run() (err error) {
88-
sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
89+
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
8990
if err != nil {
9091
return
9192
}
@@ -174,7 +175,7 @@ type XTCPVisitor struct {
174175
}
175176

176177
func (sv *XTCPVisitor) Run() (err error) {
177-
sv.l, err = net.Listen("tcp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
178+
sv.l, err = net.Listen("tcp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
178179
if err != nil {
179180
return
180181
}
@@ -352,7 +353,7 @@ type SUDPVisitor struct {
352353
func (sv *SUDPVisitor) Run() (err error) {
353354
xl := xlog.FromContextSafe(sv.ctx)
354355

355-
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", sv.cfg.BindAddr, sv.cfg.BindPort))
356+
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(sv.cfg.BindAddr, strconv.Itoa(sv.cfg.BindPort)))
356357
if err != nil {
357358
return fmt.Errorf("sudp ResolveUDPAddr error: %v", err)
358359
}

pkg/util/net/udp.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"fmt"
1919
"io"
2020
"net"
21+
"strconv"
2122
"sync"
2223
"time"
2324

@@ -163,7 +164,7 @@ type UDPListener struct {
163164
}
164165

165166
func ListenUDP(bindAddr string, bindPort int) (l *UDPListener, err error) {
166-
udpAddr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
167+
udpAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
167168
if err != nil {
168169
return l, err
169170
}

pkg/util/net/websocket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ package net
22

33
import (
44
"errors"
5-
"fmt"
65
"net"
76
"net/http"
7+
"strconv"
88

99
"golang.org/x/net/websocket"
1010
)
@@ -52,7 +52,7 @@ func NewWebsocketListener(ln net.Listener) (wl *WebsocketListener) {
5252
}
5353

5454
func ListenWebsocket(bindAddr string, bindPort int) (*WebsocketListener, error) {
55-
tcpLn, err := net.Listen("tcp", fmt.Sprintf("%s:%d", bindAddr, bindPort))
55+
tcpLn, err := net.Listen("tcp", net.JoinHostPort(bindAddr, strconv.Itoa(bindPort)))
5656
if err != nil {
5757
return nil, err
5858
}

pkg/util/tcpmux/httpconnect.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func readHTTPConnectRequest(rd io.Reader) (host string, err error) {
4848
return
4949
}
5050

51-
host = util.GetHostFromAddr(req.Host)
51+
host, _ = util.CanonicalHost(req.Host)
5252
return
5353
}
5454

pkg/util/util/http.go

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,6 @@ func OkResponse() *http.Response {
3434
return res
3535
}
3636

37-
// TODO: use "CanonicalHost" func to replace all "GetHostFromAddr" func.
38-
func GetHostFromAddr(addr string) (host string) {
39-
strs := strings.Split(addr, ":")
40-
if len(strs) > 1 {
41-
host = strs[0]
42-
} else {
43-
host = addr
44-
}
45-
return
46-
}
47-
4837
// canonicalHost strips port from host if present and returns the canonicalized
4938
// host name.
5039
func CanonicalHost(host string) (string, error) {

pkg/util/util/util.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"crypto/rand"
2020
"encoding/hex"
2121
"fmt"
22+
"net"
2223
"strconv"
2324
"strings"
2425
)
@@ -52,7 +53,7 @@ func CanonicalAddr(host string, port int) (addr string) {
5253
if port == 80 || port == 443 {
5354
addr = host
5455
} else {
55-
addr = fmt.Sprintf("%s:%d", host, port)
56+
addr = net.JoinHostPort(host, strconv.Itoa(port))
5657
}
5758
return
5859
}

pkg/util/vhost/http.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
5959
Director: func(req *http.Request) {
6060
req.URL.Scheme = "http"
6161
url := req.Context().Value(RouteInfoURL).(string)
62-
oldHost := util.GetHostFromAddr(req.Context().Value(RouteInfoHost).(string))
62+
oldHost, _ := util.CanonicalHost(req.Context().Value(RouteInfoHost).(string))
6363
rc := rp.GetRouteConfig(oldHost, url)
6464
if rc != nil {
6565
if rc.RewriteHost != "" {
@@ -81,7 +81,7 @@ func NewHTTPReverseProxy(option HTTPReverseProxyOptions, vhostRouter *Routers) *
8181
IdleConnTimeout: 60 * time.Second,
8282
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
8383
url := ctx.Value(RouteInfoURL).(string)
84-
host := util.GetHostFromAddr(ctx.Value(RouteInfoHost).(string))
84+
host, _ := util.CanonicalHost(ctx.Value(RouteInfoHost).(string))
8585
remote := ctx.Value(RouteInfoRemote).(string)
8686
return rp.CreateConnection(host, url, remote)
8787
},
@@ -191,7 +191,7 @@ func (rp *HTTPReverseProxy) getVhost(domain string, location string) (vr *Router
191191
}
192192

193193
func (rp *HTTPReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
194-
domain := util.GetHostFromAddr(req.Host)
194+
domain, _ := util.CanonicalHost(req.Host)
195195
location := req.URL.Path
196196
user, passwd, _ := req.BasicAuth()
197197
if !rp.CheckAuth(domain, location, user, passwd) {

server/group/tcp.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
package group
1616

1717
import (
18-
"fmt"
1918
"net"
19+
"strconv"
2020
"sync"
2121

2222
"github.com/fatedier/frp/server/ports"
@@ -101,7 +101,7 @@ func (tg *TCPGroup) Listen(proxyName string, group string, groupKey string, addr
101101
if err != nil {
102102
return
103103
}
104-
tcpLn, errRet := net.Listen("tcp", fmt.Sprintf("%s:%d", addr, port))
104+
tcpLn, errRet := net.Listen("tcp", net.JoinHostPort(addr, strconv.Itoa(port)))
105105
if errRet != nil {
106106
err = errRet
107107
return

server/ports/ports.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ package ports
22

33
import (
44
"errors"
5-
"fmt"
65
"net"
6+
"strconv"
77
"sync"
88
"time"
99
)
@@ -134,7 +134,7 @@ func (pm *Manager) Acquire(name string, port int) (realPort int, err error) {
134134

135135
func (pm *Manager) isPortAvailable(port int) bool {
136136
if pm.netType == "udp" {
137-
addr, err := net.ResolveUDPAddr("udp", fmt.Sprintf("%s:%d", pm.bindAddr, port))
137+
addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
138138
if err != nil {
139139
return false
140140
}
@@ -146,7 +146,7 @@ func (pm *Manager) isPortAvailable(port int) bool {
146146
return true
147147
}
148148

149-
l, err := net.Listen(pm.netType, fmt.Sprintf("%s:%d", pm.bindAddr, port))
149+
l, err := net.Listen(pm.netType, net.JoinHostPort(pm.bindAddr, strconv.Itoa(port)))
150150
if err != nil {
151151
return false
152152
}

0 commit comments

Comments
 (0)