Skip to content

Commit 9873403

Browse files
authored
[+] add tests for webserver (#754)
* [+] add tests for `webserver` * [+] add tests for metric and preset endpoints * [+] add tests for `handleStatic` * [+] add tests for source endpoints * [+] add tests for `/log` endpoint * [+] synchronize web-socket log reader and writer * [+] add jwt-related tests
1 parent d411ddd commit 9873403

File tree

17 files changed

+1160
-223
lines changed

17 files changed

+1160
-223
lines changed

cmd/pgwatch/main.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ func setupCloseHandler(cancel context.CancelFunc) {
3232
}
3333

3434
var (
35-
exitCode atomic.Int32 // Exit code to be returned to the OS
36-
mainCtx context.Context // Main context for the application
37-
cancel context.CancelFunc // Cancel function to stop the main context
38-
logger log.LoggerHookerIface // Logger for the application
39-
opts *cmdopts.Options // Command line options for the application
35+
exitCode atomic.Int32 // Exit code to be returned to the OS
36+
mainCtx context.Context // Main context for the application
37+
cancel context.CancelFunc // Cancel function to stop the main context
38+
logger log.LoggerHooker // Logger for the application
39+
opts *cmdopts.Options // Command line options for the application
4040
err error
4141
)
4242

internal/log/log.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ import (
1111
)
1212

1313
type (
14-
// LoggerIface is the interface used by all components
15-
LoggerIface logrus.FieldLogger
16-
//LoggerHookerIface adds AddHook method to LoggerIface for database logging hook
17-
LoggerHookerIface interface {
18-
LoggerIface
14+
// Logger is the interface used by all components
15+
Logger logrus.FieldLogger
16+
//LoggerHooker adds AddHook method to LoggerIface for database logging hook
17+
LoggerHooker interface {
18+
Logger
1919
AddHook(hook logrus.Hook)
2020
AddSubscriber(msgCh MessageChanType)
2121
RemoveSubscriber(msgCh MessageChanType)
@@ -54,7 +54,7 @@ func getLogFileFormatter(opts CmdOpts) logrus.Formatter {
5454
}
5555

5656
// Init creates logging facilities for the application
57-
func Init(opts CmdOpts) LoggerHookerIface {
57+
func Init(opts CmdOpts) LoggerHooker {
5858
var err error
5959
l := logger{logrus.New(), NewBrokerHook(context.Background(), opts.LogLevel)}
6060
l.AddHook(l.BrokerHook)
@@ -74,11 +74,11 @@ func Init(opts CmdOpts) LoggerHookerIface {
7474

7575
// PgxLogger is the struct used to log using pgx postgres driver
7676
type PgxLogger struct {
77-
l LoggerIface
77+
l Logger
7878
}
7979

8080
// NewPgxLogger returns a new instance of PgxLogger
81-
func NewPgxLogger(l LoggerIface) *PgxLogger {
81+
func NewPgxLogger(l Logger) *PgxLogger {
8282
return &PgxLogger{l}
8383
}
8484

@@ -107,7 +107,7 @@ func (pgxlogger *PgxLogger) Log(ctx context.Context, level tracelog.LogLevel, ms
107107

108108
// WithLogger returns a new context with the provided logger. Use in
109109
// combination with logger.WithField(s) for great effect
110-
func WithLogger(ctx context.Context, logger LoggerIface) context.Context {
110+
func WithLogger(ctx context.Context, logger Logger) context.Context {
111111
return context.WithValue(ctx, loggerKey{}, logger)
112112
}
113113

@@ -116,10 +116,10 @@ var FallbackLogger = Init(CmdOpts{})
116116

117117
// GetLogger retrieves the current logger from the context. If no logger is
118118
// available, the default logger is returned
119-
func GetLogger(ctx context.Context) LoggerIface {
119+
func GetLogger(ctx context.Context) Logger {
120120
logger := ctx.Value(loggerKey{})
121121
if logger == nil {
122122
return FallbackLogger
123123
}
124-
return logger.(LoggerIface)
124+
return logger.(Logger)
125125
}

internal/reaper/reaper.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ type Reaper struct {
2626
ready atomic.Bool
2727
measurementCh chan []metrics.MeasurementEnvelope
2828
measurementCache *InstanceMetricCache
29-
logger log.LoggerIface
29+
logger log.Logger
3030
monitoredSources sources.SourceConns
3131
prevLoopMonitoredDBs sources.SourceConns
3232
cancelFuncs map[string]context.CancelFunc
@@ -191,7 +191,7 @@ func (r *Reaper) Reap(ctx context.Context) {
191191
}
192192

193193
// CreateSourceHelpers creates the extensions and metric helpers for the monitored source
194-
func (r *Reaper) CreateSourceHelpers(ctx context.Context, srcL log.LoggerIface, monitoredSource *sources.SourceConn) {
194+
func (r *Reaper) CreateSourceHelpers(ctx context.Context, srcL log.Logger, monitoredSource *sources.SourceConn) {
195195
if r.prevLoopMonitoredDBs.GetMonitoredDatabase(monitoredSource.Name) != nil {
196196
return // already created
197197
}

internal/sources/resolver.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ type PatroniClusterMember struct {
5656
Role string
5757
}
5858

59-
var logger log.LoggerIface = log.FallbackLogger
59+
var logger log.Logger = log.FallbackLogger
6060

6161
var lastFoundClusterMembers = make(map[string][]PatroniClusterMember) // needed for cases where DCS is temporarily down
6262
// don't want to immediately remove monitoring of DBs

internal/webserver/api.go

Lines changed: 0 additions & 92 deletions
This file was deleted.

internal/webserver/cors.go

Lines changed: 0 additions & 18 deletions
This file was deleted.

internal/webserver/jwt_test.go

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
package webserver
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net/http"
8+
"net/http/httptest"
9+
"testing"
10+
"time"
11+
12+
"github.com/golang-jwt/jwt/v5"
13+
"github.com/stretchr/testify/assert"
14+
)
15+
16+
func TestIsCorrectPassword(t *testing.T) {
17+
ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
18+
assert.True(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "pass"}))
19+
assert.False(t, ts.IsCorrectPassword(loginReq{Username: "user", Password: "wrong"}))
20+
assert.True(t, (&WebUIServer{}).IsCorrectPassword(loginReq{})) // empty user/pass disables auth
21+
}
22+
23+
func TestHandleLogin_POST_Success(t *testing.T) {
24+
ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
25+
body, _ := json.Marshal(map[string]string{"user": "user", "password": "pass"})
26+
r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
27+
w := httptest.NewRecorder()
28+
ts.handleLogin(w, r)
29+
resp := w.Result()
30+
assert.Equal(t, http.StatusOK, resp.StatusCode)
31+
token, _ := io.ReadAll(resp.Body)
32+
assert.NotEmpty(t, string(token))
33+
}
34+
35+
func TestHandleLogin_POST_Fail(t *testing.T) {
36+
ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
37+
body, _ := json.Marshal(map[string]string{"user": "user", "password": "wrong"})
38+
r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader(body))
39+
w := httptest.NewRecorder()
40+
ts.handleLogin(w, r)
41+
resp := w.Result()
42+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
43+
}
44+
45+
func TestHandleLogin_POST_BadJSON(t *testing.T) {
46+
ts := &WebUIServer{CmdOpts: CmdOpts{WebUser: "user", WebPassword: "pass"}}
47+
r := httptest.NewRequest(http.MethodPost, "/login", bytes.NewReader([]byte("notjson")))
48+
w := httptest.NewRecorder()
49+
ts.handleLogin(w, r)
50+
resp := w.Result()
51+
assert.Equal(t, http.StatusInternalServerError, resp.StatusCode)
52+
}
53+
54+
func TestHandleLogin_GET(t *testing.T) {
55+
ts := &WebUIServer{}
56+
r := httptest.NewRequest(http.MethodGet, "/login", nil)
57+
w := httptest.NewRecorder()
58+
ts.handleLogin(w, r)
59+
resp := w.Result()
60+
assert.Equal(t, http.StatusOK, resp.StatusCode)
61+
body, _ := io.ReadAll(resp.Body)
62+
assert.Equal(t, "only POST methods is allowed.", string(body))
63+
}
64+
65+
func TestGenerateAndValidateJWT(t *testing.T) {
66+
token, err := generateJWT("user1")
67+
assert.NoError(t, err)
68+
r := httptest.NewRequest(http.MethodGet, "/", nil)
69+
r.Header.Set("Token", token)
70+
assert.NoError(t, validateToken(r))
71+
}
72+
73+
func TestValidateToken_MissingToken(t *testing.T) {
74+
r := httptest.NewRequest(http.MethodGet, "/", nil)
75+
err := validateToken(r)
76+
assert.Error(t, err)
77+
assert.Contains(t, err.Error(), "can not find token")
78+
}
79+
80+
func TestValidateToken_InvalidToken(t *testing.T) {
81+
r := httptest.NewRequest(http.MethodGet, "/", nil)
82+
r.Header.Set("Token", "invalidtoken")
83+
err := validateToken(r)
84+
assert.Error(t, err)
85+
}
86+
87+
func TestEnsureAuth_ServeHTTP(t *testing.T) {
88+
called := false
89+
h := func(w http.ResponseWriter, _ *http.Request) {
90+
called = true
91+
w.WriteHeader(http.StatusTeapot)
92+
}
93+
token, _ := generateJWT("user1")
94+
r := httptest.NewRequest(http.MethodGet, "/", nil)
95+
r.Header.Set("Token", token)
96+
w := httptest.NewRecorder()
97+
NewEnsureAuth(h).ServeHTTP(w, r)
98+
resp := w.Result()
99+
assert.Equal(t, http.StatusTeapot, resp.StatusCode)
100+
assert.True(t, called)
101+
}
102+
103+
func TestEnsureAuth_ServeHTTP_InvalidToken(t *testing.T) {
104+
h := func(w http.ResponseWriter, _ *http.Request) {
105+
w.WriteHeader(http.StatusTeapot)
106+
}
107+
r := httptest.NewRequest(http.MethodGet, "/", nil)
108+
r.Header.Set("Token", "invalidtoken")
109+
w := httptest.NewRecorder()
110+
NewEnsureAuth(h).ServeHTTP(w, r)
111+
resp := w.Result()
112+
assert.Equal(t, http.StatusUnauthorized, resp.StatusCode)
113+
}
114+
115+
func TestJWT_Expiration(t *testing.T) {
116+
tok := jwt.New(jwt.SigningMethodHS256)
117+
claims := tok.Claims.(jwt.MapClaims)
118+
claims["authorized"] = true
119+
claims["username"] = "user"
120+
claims["exp"] = time.Now().Add(-time.Hour).Unix() // expired
121+
token, _ := tok.SignedString(sampleSecretKey)
122+
r := httptest.NewRequest(http.MethodGet, "/", nil)
123+
r.Header.Set("Token", token)
124+
err := validateToken(r)
125+
assert.Error(t, err)
126+
assert.Contains(t, err.Error(), "token is expired")
127+
}

0 commit comments

Comments
 (0)