Skip to content

Commit 9841cb4

Browse files
authored
Check the provided port in the configuration (#31)
* Check port number and print diagnostic message when the app is running Signed-off-by: Doğukan Teber <[email protected]> * Write unit tests and update the existing unit test cases Signed-off-by: Doğukan Teber <[email protected]> * Update tests Signed-off-by: Doğukan Teber <[email protected]> * Change default server port Signed-off-by: Doğukan Teber <[email protected]> * Update tests Signed-off-by: Doğukan Teber <[email protected]> --------- Signed-off-by: Doğukan Teber <[email protected]>
1 parent 22696db commit 9841cb4

File tree

3 files changed

+100
-10
lines changed

3 files changed

+100
-10
lines changed

gateway/gateway_test.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@ import (
1414
)
1515

1616
func TestNewGateway(t *testing.T) {
17-
srv, err := server.New(server.Config{})
17+
srv, err := server.New(server.Config{
18+
HTTPListenAddr: "localhost",
19+
HTTPListenPort: 1234,
20+
UnAuthorizedHTTPListenAddr: "localhost",
21+
UnAuthorizedHTTPListenPort: 1235,
22+
})
1823
if err != nil {
1924
t.Fatal(err)
2025
}
@@ -44,6 +49,7 @@ func TestNewGateway(t *testing.T) {
4449
}
4550

4651
assert.NotNil(t, gw)
52+
srv.Shutdown()
4753
}
4854

4955
func TestStartGateway(t *testing.T) {
@@ -316,7 +322,7 @@ func TestStartGateway(t *testing.T) {
316322

317323
for _, tc := range testCases {
318324
t.Run(tc.name, func(t *testing.T) {
319-
gw, err := createMockGateway("localhost", 8080, 8081, tc.config)
325+
gw, err := createMockGateway("localhost", 8010, 8011, tc.config)
320326
if tc.expectedErr == nil && err != nil {
321327
t.Fatalf("Unexpected error when creating the gateway: %v\n", err)
322328
}

server/server.go

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ import (
1818
)
1919

2020
const (
21-
AUTH = "auth"
22-
UNAUTH = "unauth"
23-
DefaultNetwork = "tcp"
21+
AUTH = "auth"
22+
UNAUTH = "unauth"
23+
DefaultNetwork = "tcp"
24+
DefaultAuthPort = 80
25+
DefaultUnauthPort = 8081
2426
)
2527

2628
type Config struct {
@@ -58,7 +60,12 @@ type server struct {
5860
}
5961

6062
func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) {
61-
listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, cfg.HTTPListenPort)
63+
port, err := checkPort(cfg.HTTPListenAddr, cfg.HTTPListenPort, DefaultAuthPort, DefaultNetwork)
64+
if err != nil {
65+
return nil, err
66+
}
67+
cfg.HTTPListenPort = port
68+
listenAddr := fmt.Sprintf("%s:%d", cfg.HTTPListenAddr, port)
6269
httpListener, err := net.Listen(DefaultNetwork, listenAddr)
6370
if err != nil {
6471
return nil, err
@@ -115,6 +122,11 @@ func initAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, e
115122
}
116123

117124
func initUnAuthServer(cfg *Config, middlewares []middleware.Interface) (*server, error) {
125+
port, err := checkPort(cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort, DefaultUnauthPort, DefaultNetwork)
126+
if err != nil {
127+
return nil, err
128+
}
129+
cfg.UnAuthorizedHTTPListenPort = port
118130
listenAddr := fmt.Sprintf("%s:%d", cfg.UnAuthorizedHTTPListenAddr, cfg.UnAuthorizedHTTPListenPort)
119131
unauthHttpListener, err := net.Listen(DefaultNetwork, listenAddr)
120132
if err != nil {
@@ -222,7 +234,7 @@ func New(cfg Config) (*Server, error) {
222234
}
223235

224236
func (s *Server) Run() error {
225-
logrus.Infof("the server has started listening on %v", s.authServer.httpServer.Addr)
237+
logrus.Infof("the main server has started listening on %v", s.authServer.httpServer.Addr)
226238
errChan := make(chan error, 1)
227239

228240
go func() {
@@ -237,6 +249,7 @@ func (s *Server) Run() error {
237249
}
238250
}()
239251

252+
logrus.Infof("the admin server has started listening on %v", s.unAuthServer.httpServer.Addr)
240253
go func() {
241254
err := s.unAuthServer.run()
242255
if err == http.ErrServerClosed {
@@ -292,3 +305,25 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) {
292305
func (s *Server) GetHTTPHandlers() (http.Handler, http.Handler) {
293306
return s.authServer.http, s.unAuthServer.http
294307
}
308+
309+
func checkPortAvailable(addr string, port int, network string) bool {
310+
l, err := net.Listen(network, fmt.Sprintf("%s:%d", addr, port))
311+
if err != nil {
312+
return false
313+
}
314+
l.Close()
315+
return true
316+
}
317+
318+
func checkPort(addr string, port int, defaultPort int, network string) (int, error) {
319+
p := port
320+
if port == 0 {
321+
logrus.Info(fmt.Sprintf("port not specified, trying default port %d", defaultPort))
322+
if checkPortAvailable(addr, defaultPort, network) {
323+
p = defaultPort
324+
} else {
325+
return 0, fmt.Errorf(fmt.Sprintf("port %d is not available, please specify a port", defaultPort))
326+
}
327+
}
328+
return p, nil
329+
}

server/server_test.go

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ func TestNew(t *testing.T) {
2525
config: Config{
2626
HTTPListenAddr: "http://localhost",
2727
HTTPListenPort: 8080,
28+
UnAuthorizedHTTPListenAddr: "localhost",
29+
UnAuthorizedHTTPListenPort: 1111,
2830
ServerGracefulShutdownTimeout: time.Second * 5,
2931
HTTPServerReadTimeout: time.Second * 10,
3032
HTTPServerWriteTimeout: time.Second * 10,
@@ -35,8 +37,10 @@ func TestNew(t *testing.T) {
3537
{
3638
name: "invalid address for unauth",
3739
config: Config{
38-
UnAuthorizedHTTPListenAddr: "http://localhost",
40+
HTTPListenAddr: "localhost",
3941
HTTPListenPort: 8080,
42+
UnAuthorizedHTTPListenAddr: "http://localhost",
43+
UnAuthorizedHTTPListenPort: 8081,
4044
ServerGracefulShutdownTimeout: time.Second * 5,
4145
HTTPServerReadTimeout: time.Second * 10,
4246
HTTPServerWriteTimeout: time.Second * 10,
@@ -49,6 +53,8 @@ func TestNew(t *testing.T) {
4953
config: Config{
5054
HTTPListenAddr: "localhost",
5155
HTTPListenPort: 8080,
56+
UnAuthorizedHTTPListenAddr: "localhost",
57+
UnAuthorizedHTTPListenPort: 8081,
5258
ServerGracefulShutdownTimeout: time.Second * 5,
5359
HTTPServerReadTimeout: time.Second * 10,
5460
HTTPServerWriteTimeout: time.Second * 10,
@@ -61,6 +67,8 @@ func TestNew(t *testing.T) {
6167
config: Config{
6268
HTTPListenAddr: "localhost",
6369
HTTPListenPort: 8080,
70+
UnAuthorizedHTTPListenAddr: "localhost",
71+
UnAuthorizedHTTPListenPort: 8081,
6472
ServerGracefulShutdownTimeout: time.Second * 5,
6573
HTTPServerReadTimeout: time.Second * 10,
6674
HTTPServerWriteTimeout: time.Second * 10,
@@ -94,6 +102,7 @@ func TestNew(t *testing.T) {
94102
t.Errorf("Expected server address to be %s:%d, but got %s", tc.config.HTTPListenAddr, tc.config.HTTPListenPort, server.authServer.httpServer.Addr)
95103
}
96104
}
105+
server.Shutdown()
97106
})
98107
}
99108
}
@@ -115,7 +124,6 @@ func TestServer_RegisterTo(t *testing.T) {
115124
s.RegisterTo("/test_auth", testHandler, AUTH)
116125
s.RegisterTo("/test_unauth", testHandler, UNAUTH)
117126

118-
// Test authorized server.
119127
req := httptest.NewRequest(http.MethodGet, "/test_auth", nil)
120128
w := httptest.NewRecorder()
121129

@@ -126,7 +134,6 @@ func TestServer_RegisterTo(t *testing.T) {
126134
t.Errorf("Expected status code %d for AUTH server, but got %d", http.StatusOK, resp.StatusCode)
127135
}
128136

129-
// Test unauthorized server.
130137
req = httptest.NewRequest(http.MethodGet, "/test_unauth", nil)
131138
w = httptest.NewRecorder()
132139

@@ -227,6 +234,10 @@ func TestRun(t *testing.T) {
227234

228235
func TestReadyHandler(t *testing.T) {
229236
cfg := Config{
237+
HTTPListenAddr: "localhost",
238+
HTTPListenPort: 1234,
239+
UnAuthorizedHTTPListenAddr: "localhost",
240+
UnAuthorizedHTTPListenPort: 1235,
230241
HTTPServerReadTimeout: 5 * time.Second,
231242
HTTPServerWriteTimeout: 5 * time.Second,
232243
HTTPServerIdleTimeout: 5 * time.Second,
@@ -278,4 +289,42 @@ func TestReadyHandler(t *testing.T) {
278289
}
279290
})
280291
}
292+
s.Shutdown()
293+
}
294+
295+
func TestCheckPortAvailable(t *testing.T) {
296+
tests := []struct {
297+
name string
298+
listenAddr string
299+
listenPort int
300+
wantAvailable bool
301+
}{
302+
{
303+
name: "port available",
304+
listenAddr: "localhost",
305+
listenPort: 8080,
306+
wantAvailable: true,
307+
},
308+
{
309+
name: "port unavailable",
310+
listenAddr: "localhost",
311+
listenPort: 1234,
312+
wantAvailable: false,
313+
},
314+
}
315+
316+
listener, err := net.Listen(DefaultNetwork, fmt.Sprintf("%s:%d", "localhost", 1234))
317+
if err != nil {
318+
t.Fatalf("Failed to create a listener: %v", err)
319+
}
320+
defer listener.Close()
321+
for _, tt := range tests {
322+
t.Run(tt.name, func(t *testing.T) {
323+
324+
available := checkPortAvailable(tt.listenAddr, tt.listenPort, DefaultNetwork)
325+
if available != tt.wantAvailable {
326+
t.Errorf("Expected port %d to be available: %v", tt.listenPort, tt.wantAvailable)
327+
}
328+
})
329+
}
281330
}

0 commit comments

Comments
 (0)