Skip to content

Commit 2b36b3d

Browse files
authored
Merge pull request #1667 from pafuent/listener_network_configurable
Adding Echo#ListenerNetwork as configuration
2 parents 06a9480 + 78fe222 commit 2b36b3d

File tree

2 files changed

+78
-8
lines changed

2 files changed

+78
-8
lines changed

echo.go

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ type (
9292
Renderer Renderer
9393
Logger Logger
9494
IPExtractor IPExtractor
95+
ListenerNetwork string
9596
}
9697

9798
// Route contains a handler and information for matching against requests.
@@ -281,6 +282,7 @@ var (
281282
ErrInvalidRedirectCode = errors.New("invalid redirect status code")
282283
ErrCookieNotFound = errors.New("cookie not found")
283284
ErrInvalidCertOrKeyType = errors.New("invalid cert or key type, must be string or []byte")
285+
ErrInvalidListenerNetwork = errors.New("invalid listener network")
284286
)
285287

286288
// Error handlers
@@ -302,9 +304,10 @@ func New() (e *Echo) {
302304
AutoTLSManager: autocert.Manager{
303305
Prompt: autocert.AcceptTOS,
304306
},
305-
Logger: log.New("echo"),
306-
colorer: color.New(),
307-
maxParam: new(int),
307+
Logger: log.New("echo"),
308+
colorer: color.New(),
309+
maxParam: new(int),
310+
ListenerNetwork: "tcp",
308311
}
309312
e.Server.Handler = e
310313
e.TLSServer.Handler = e
@@ -714,7 +717,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
714717

715718
if s.TLSConfig == nil {
716719
if e.Listener == nil {
717-
e.Listener, err = newListener(s.Addr)
720+
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
718721
if err != nil {
719722
return err
720723
}
@@ -725,7 +728,7 @@ func (e *Echo) StartServer(s *http.Server) (err error) {
725728
return s.Serve(e.Listener)
726729
}
727730
if e.TLSListener == nil {
728-
l, err := newListener(s.Addr)
731+
l, err := newListener(s.Addr, e.ListenerNetwork)
729732
if err != nil {
730733
return err
731734
}
@@ -754,7 +757,7 @@ func (e *Echo) StartH2CServer(address string, h2s *http2.Server) (err error) {
754757
}
755758

756759
if e.Listener == nil {
757-
e.Listener, err = newListener(s.Addr)
760+
e.Listener, err = newListener(s.Addr, e.ListenerNetwork)
758761
if err != nil {
759762
return err
760763
}
@@ -875,8 +878,11 @@ func (ln tcpKeepAliveListener) Accept() (c net.Conn, err error) {
875878
return
876879
}
877880

878-
func newListener(address string) (*tcpKeepAliveListener, error) {
879-
l, err := net.Listen("tcp", address)
881+
func newListener(address, network string) (*tcpKeepAliveListener, error) {
882+
if network != "tcp" && network != "tcp4" && network != "tcp6" {
883+
return nil, ErrInvalidListenerNetwork
884+
}
885+
l, err := net.Listen(network, address)
880886
if err != nil {
881887
return nil, err
882888
}

echo_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
stdContext "context"
66
"errors"
7+
"fmt"
78
"io/ioutil"
89
"net/http"
910
"net/http/httptest"
@@ -658,6 +659,69 @@ func TestEchoShutdown(t *testing.T) {
658659
assert.Equal(t, err.Error(), "http: Server closed")
659660
}
660661

662+
var listenerNetworkTests = []struct {
663+
test string
664+
network string
665+
address string
666+
}{
667+
{"tcp ipv4 address", "tcp", "127.0.0.1:1323"},
668+
{"tcp ipv6 address", "tcp", "[::1]:1323"},
669+
{"tcp4 ipv4 address", "tcp4", "127.0.0.1:1323"},
670+
{"tcp6 ipv6 address", "tcp6", "[::1]:1323"},
671+
}
672+
673+
func TestEchoListenerNetwork(t *testing.T) {
674+
for _, tt := range listenerNetworkTests {
675+
t.Run(tt.test, func(t *testing.T) {
676+
e := New()
677+
e.ListenerNetwork = tt.network
678+
679+
// HandlerFunc
680+
e.GET("/ok", func(c Context) error {
681+
return c.String(http.StatusOK, "OK")
682+
})
683+
684+
errCh := make(chan error)
685+
686+
go func() {
687+
errCh <- e.Start(tt.address)
688+
}()
689+
690+
time.Sleep(200 * time.Millisecond)
691+
692+
if resp, err := http.Get(fmt.Sprintf("http://%s/ok", tt.address)); err == nil {
693+
defer resp.Body.Close()
694+
assert.Equal(t, http.StatusOK, resp.StatusCode)
695+
696+
if body, err := ioutil.ReadAll(resp.Body); err == nil {
697+
assert.Equal(t, "OK", string(body))
698+
} else {
699+
assert.Fail(t, err.Error())
700+
}
701+
702+
} else {
703+
assert.Fail(t, err.Error())
704+
}
705+
706+
if err := e.Close(); err != nil {
707+
t.Fatal(err)
708+
}
709+
})
710+
}
711+
}
712+
713+
func TestEchoListenerNetworkInvalid(t *testing.T) {
714+
e := New()
715+
e.ListenerNetwork = "unix"
716+
717+
// HandlerFunc
718+
e.GET("/ok", func(c Context) error {
719+
return c.String(http.StatusOK, "OK")
720+
})
721+
722+
assert.Equal(t, ErrInvalidListenerNetwork, e.Start(":1323"))
723+
}
724+
661725
func TestEchoReverse(t *testing.T) {
662726
assert := assert.New(t)
663727

0 commit comments

Comments
 (0)