Skip to content

Commit ea71681

Browse files
authored
Merge pull request #107 from go-authgate/worktree-new
test: improve test coverage for handlers, store, and middleware
2 parents 8b9543d + fc2ea0c commit ea71681

4 files changed

Lines changed: 1096 additions & 0 deletions

File tree

internal/handlers/device_test.go

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
package handlers
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"net/http"
7+
"net/http/httptest"
8+
"net/url"
9+
"strings"
10+
"testing"
11+
"time"
12+
13+
"github.com/go-authgate/authgate/internal/config"
14+
"github.com/go-authgate/authgate/internal/metrics"
15+
"github.com/go-authgate/authgate/internal/models"
16+
"github.com/go-authgate/authgate/internal/services"
17+
"github.com/go-authgate/authgate/internal/store"
18+
19+
"github.com/gin-gonic/gin"
20+
"github.com/google/uuid"
21+
"github.com/stretchr/testify/assert"
22+
"github.com/stretchr/testify/require"
23+
)
24+
25+
func setupDeviceTestEnv(t *testing.T) (*gin.Engine, *store.Store) {
26+
t.Helper()
27+
gin.SetMode(gin.TestMode)
28+
29+
cfg := &config.Config{
30+
DeviceCodeExpiration: 30 * time.Minute,
31+
PollingInterval: 5,
32+
BaseURL: "http://localhost:8080",
33+
}
34+
35+
s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{})
36+
require.NoError(t, err)
37+
38+
auditSvc := services.NewAuditService(s, false, 0)
39+
deviceSvc := services.NewDeviceService(s, cfg, auditSvc, metrics.NewNoopMetrics())
40+
userSvc := services.NewUserService(s, nil, nil, "local", false, auditSvc, nil, 0)
41+
authzSvc := services.NewAuthorizationService(s, cfg, auditSvc)
42+
handler := NewDeviceHandler(deviceSvc, userSvc, authzSvc, cfg)
43+
44+
r := gin.New()
45+
r.POST("/oauth/device/code", handler.DeviceCodeRequest)
46+
47+
return r, s
48+
}
49+
50+
func createDeviceFlowClient(
51+
t *testing.T,
52+
s *store.Store,
53+
active bool,
54+
deviceFlowEnabled bool,
55+
) *models.OAuthApplication {
56+
t.Helper()
57+
status := models.ClientStatusActive
58+
if !active {
59+
status = models.ClientStatusInactive
60+
}
61+
client := &models.OAuthApplication{
62+
ClientID: uuid.New().String(),
63+
ClientName: "Device Test Client",
64+
UserID: uuid.New().String(),
65+
Scopes: "email profile",
66+
GrantTypes: "device_code",
67+
EnableDeviceFlow: deviceFlowEnabled,
68+
Status: status,
69+
}
70+
require.NoError(t, s.CreateClient(client))
71+
return client
72+
}
73+
74+
func TestDeviceCodeRequest_Success(t *testing.T) {
75+
r, s := setupDeviceTestEnv(t)
76+
client := createDeviceFlowClient(t, s, true, true)
77+
78+
w := httptest.NewRecorder()
79+
form := url.Values{"client_id": {client.ClientID}, "scope": {"email"}}
80+
req, _ := http.NewRequest(
81+
http.MethodPost,
82+
"/oauth/device/code",
83+
strings.NewReader(form.Encode()),
84+
)
85+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
86+
r.ServeHTTP(w, req)
87+
88+
assert.Equal(t, http.StatusOK, w.Code)
89+
90+
var resp map[string]any
91+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
92+
assert.NotEmpty(t, resp["device_code"])
93+
assert.NotEmpty(t, resp["user_code"])
94+
assert.Contains(t, resp["verification_uri"], "/device")
95+
assert.NotZero(t, resp["expires_in"])
96+
assert.NotZero(t, resp["interval"])
97+
}
98+
99+
func TestDeviceCodeRequest_MissingClientID(t *testing.T) {
100+
r, _ := setupDeviceTestEnv(t)
101+
102+
w := httptest.NewRecorder()
103+
req, _ := http.NewRequest(http.MethodPost, "/oauth/device/code", nil)
104+
r.ServeHTTP(w, req)
105+
106+
assert.Equal(t, http.StatusBadRequest, w.Code)
107+
108+
var resp map[string]string
109+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
110+
assert.Equal(t, "invalid_request", resp["error"])
111+
}
112+
113+
func TestDeviceCodeRequest_UnknownClient(t *testing.T) {
114+
r, _ := setupDeviceTestEnv(t)
115+
116+
w := httptest.NewRecorder()
117+
form := url.Values{"client_id": {"nonexistent-client"}}
118+
req, _ := http.NewRequest(
119+
http.MethodPost,
120+
"/oauth/device/code",
121+
strings.NewReader(form.Encode()),
122+
)
123+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
124+
r.ServeHTTP(w, req)
125+
126+
assert.Equal(t, http.StatusBadRequest, w.Code)
127+
128+
var resp map[string]string
129+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
130+
assert.Equal(t, "invalid_client", resp["error"])
131+
}
132+
133+
func TestDeviceCodeRequest_InactiveClient(t *testing.T) {
134+
r, s := setupDeviceTestEnv(t)
135+
client := createDeviceFlowClient(t, s, false, true)
136+
137+
w := httptest.NewRecorder()
138+
form := url.Values{"client_id": {client.ClientID}}
139+
req, _ := http.NewRequest(
140+
http.MethodPost,
141+
"/oauth/device/code",
142+
strings.NewReader(form.Encode()),
143+
)
144+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
145+
r.ServeHTTP(w, req)
146+
147+
assert.Equal(t, http.StatusBadRequest, w.Code)
148+
149+
var resp map[string]string
150+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
151+
assert.Equal(t, "invalid_client", resp["error"])
152+
}
153+
154+
func TestDeviceCodeRequest_DeviceFlowNotEnabled(t *testing.T) {
155+
r, s := setupDeviceTestEnv(t)
156+
// Create client with device flow enabled first, then disable it
157+
// (GORM default:true skips zero-value false on insert)
158+
client := createDeviceFlowClient(t, s, true, true)
159+
client.EnableDeviceFlow = false
160+
require.NoError(t, s.UpdateClient(client))
161+
162+
w := httptest.NewRecorder()
163+
form := url.Values{"client_id": {client.ClientID}}
164+
req, _ := http.NewRequest(
165+
http.MethodPost,
166+
"/oauth/device/code",
167+
strings.NewReader(form.Encode()),
168+
)
169+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
170+
r.ServeHTTP(w, req)
171+
172+
assert.Equal(t, http.StatusBadRequest, w.Code)
173+
174+
var resp map[string]string
175+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
176+
assert.Equal(t, "unauthorized_client", resp["error"])
177+
}
178+
179+
func TestDeviceCodeRequest_JSONBody(t *testing.T) {
180+
r, s := setupDeviceTestEnv(t)
181+
client := createDeviceFlowClient(t, s, true, true)
182+
183+
w := httptest.NewRecorder()
184+
body := `{"client_id":"` + client.ClientID + `","scope":"profile"}`
185+
req, _ := http.NewRequest(http.MethodPost, "/oauth/device/code", strings.NewReader(body))
186+
req.Header.Set("Content-Type", "application/json")
187+
r.ServeHTTP(w, req)
188+
189+
assert.Equal(t, http.StatusOK, w.Code)
190+
191+
var resp map[string]any
192+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
193+
assert.NotEmpty(t, resp["device_code"])
194+
}
195+
196+
func TestDeviceCodeRequest_DefaultScope(t *testing.T) {
197+
r, s := setupDeviceTestEnv(t)
198+
client := createDeviceFlowClient(t, s, true, true)
199+
200+
w := httptest.NewRecorder()
201+
form := url.Values{"client_id": {client.ClientID}}
202+
req, _ := http.NewRequest(
203+
http.MethodPost,
204+
"/oauth/device/code",
205+
strings.NewReader(form.Encode()),
206+
)
207+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
208+
r.ServeHTTP(w, req)
209+
210+
assert.Equal(t, http.StatusOK, w.Code)
211+
212+
var resp map[string]any
213+
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &resp))
214+
assert.NotEmpty(t, resp["device_code"])
215+
assert.NotEmpty(t, resp["user_code"])
216+
}

internal/handlers/session_test.go

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
package handlers
2+
3+
import (
4+
"context"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"github.com/go-authgate/authgate/internal/config"
11+
"github.com/go-authgate/authgate/internal/metrics"
12+
"github.com/go-authgate/authgate/internal/models"
13+
"github.com/go-authgate/authgate/internal/services"
14+
"github.com/go-authgate/authgate/internal/store"
15+
"github.com/go-authgate/authgate/internal/token"
16+
"github.com/go-authgate/authgate/internal/util"
17+
18+
"github.com/gin-gonic/gin"
19+
"github.com/google/uuid"
20+
"github.com/stretchr/testify/assert"
21+
"github.com/stretchr/testify/require"
22+
)
23+
24+
// setupSessionServices creates the store and services needed for session tests
25+
// without building a Gin router (each test wires its own routes with appropriate middleware).
26+
func setupSessionServices(t *testing.T) (*store.Store, *services.TokenService) {
27+
t.Helper()
28+
gin.SetMode(gin.TestMode)
29+
30+
cfg := &config.Config{
31+
JWTExpiration: 1 * time.Hour,
32+
JWTSecret: "test-secret-32-chars-long!!!!!!!",
33+
BaseURL: "http://localhost:8080",
34+
}
35+
36+
s, err := store.New(context.Background(), "sqlite", ":memory:", &config.Config{})
37+
require.NoError(t, err)
38+
39+
localProvider := token.NewLocalTokenProvider(cfg)
40+
auditSvc := services.NewAuditService(s, false, 0)
41+
deviceSvc := services.NewDeviceService(s, cfg, auditSvc, metrics.NewNoopMetrics())
42+
tokenSvc := services.NewTokenService(
43+
s, cfg, deviceSvc, localProvider, auditSvc, metrics.NewNoopMetrics(),
44+
)
45+
46+
return s, tokenSvc
47+
}
48+
49+
func createTestToken(t *testing.T, s *store.Store, userID, clientID string) *models.AccessToken {
50+
t.Helper()
51+
tok := &models.AccessToken{
52+
ID: uuid.New().String(),
53+
TokenHash: util.SHA256Hex(uuid.New().String()),
54+
TokenCategory: models.TokenCategoryAccess,
55+
Status: models.TokenStatusActive,
56+
UserID: userID,
57+
ClientID: clientID,
58+
Scopes: "email profile",
59+
ExpiresAt: time.Now().Add(1 * time.Hour),
60+
}
61+
require.NoError(t, s.CreateAccessToken(tok))
62+
return tok
63+
}
64+
65+
// newSessionRouter creates a Gin router with the session handler and optional user_id injection.
66+
func newSessionRouter(handler *SessionHandler, userID string) *gin.Engine {
67+
r := gin.New()
68+
if userID != "" {
69+
r.Use(func(c *gin.Context) {
70+
c.Set("user_id", userID)
71+
c.Next()
72+
})
73+
}
74+
r.POST("/account/sessions/:id/revoke", handler.RevokeSession)
75+
r.POST("/account/sessions/:id/disable", handler.DisableSession)
76+
r.POST("/account/sessions/:id/enable", handler.EnableSession)
77+
r.POST("/account/sessions/revoke-all", handler.RevokeAllSessions)
78+
return r
79+
}
80+
81+
func TestRevokeSession_Success(t *testing.T) {
82+
s, tokenSvc := setupSessionServices(t)
83+
userID := uuid.New().String()
84+
tok := createTestToken(t, s, userID, uuid.New().String())
85+
86+
handler := NewSessionHandler(tokenSvc, nil)
87+
r := newSessionRouter(handler, userID)
88+
89+
w := httptest.NewRecorder()
90+
req, _ := http.NewRequest(http.MethodPost, "/account/sessions/"+tok.ID+"/revoke", nil)
91+
r.ServeHTTP(w, req)
92+
93+
assert.Equal(t, http.StatusFound, w.Code)
94+
95+
// Verify token was revoked
96+
_, err := s.GetAccessTokenByID(tok.ID)
97+
assert.Error(t, err) // deleted
98+
}
99+
100+
func TestRevokeSession_NotOwned(t *testing.T) {
101+
s, tokenSvc := setupSessionServices(t)
102+
ownerID := uuid.New().String()
103+
attackerID := uuid.New().String()
104+
tok := createTestToken(t, s, ownerID, uuid.New().String())
105+
106+
handler := NewSessionHandler(tokenSvc, nil)
107+
r := newSessionRouter(handler, attackerID)
108+
109+
w := httptest.NewRecorder()
110+
req, _ := http.NewRequest(http.MethodPost, "/account/sessions/"+tok.ID+"/revoke", nil)
111+
r.ServeHTTP(w, req)
112+
113+
assert.Equal(t, http.StatusForbidden, w.Code)
114+
}
115+
116+
func TestRevokeSession_Unauthenticated(t *testing.T) {
117+
_, tokenSvc := setupSessionServices(t)
118+
handler := NewSessionHandler(tokenSvc, nil)
119+
r := newSessionRouter(handler, "") // no user_id
120+
121+
w := httptest.NewRecorder()
122+
req, _ := http.NewRequest(http.MethodPost, "/account/sessions/some-id/revoke", nil)
123+
r.ServeHTTP(w, req)
124+
125+
assert.Equal(t, http.StatusUnauthorized, w.Code)
126+
}
127+
128+
func TestDisableAndEnableSession(t *testing.T) {
129+
s, tokenSvc := setupSessionServices(t)
130+
userID := uuid.New().String()
131+
tok := createTestToken(t, s, userID, uuid.New().String())
132+
133+
handler := NewSessionHandler(tokenSvc, nil)
134+
r := newSessionRouter(handler, userID)
135+
136+
// Disable
137+
w := httptest.NewRecorder()
138+
req, _ := http.NewRequest(http.MethodPost, "/account/sessions/"+tok.ID+"/disable", nil)
139+
r.ServeHTTP(w, req)
140+
assert.Equal(t, http.StatusFound, w.Code)
141+
142+
disabled, err := s.GetAccessTokenByID(tok.ID)
143+
require.NoError(t, err)
144+
assert.Equal(t, models.TokenStatusDisabled, disabled.Status)
145+
146+
// Enable
147+
w = httptest.NewRecorder()
148+
req, _ = http.NewRequest(http.MethodPost, "/account/sessions/"+tok.ID+"/enable", nil)
149+
r.ServeHTTP(w, req)
150+
assert.Equal(t, http.StatusFound, w.Code)
151+
152+
enabled, err := s.GetAccessTokenByID(tok.ID)
153+
require.NoError(t, err)
154+
assert.Equal(t, models.TokenStatusActive, enabled.Status)
155+
}
156+
157+
func TestRevokeAllSessions(t *testing.T) {
158+
s, tokenSvc := setupSessionServices(t)
159+
userID := uuid.New().String()
160+
clientID := uuid.New().String()
161+
createTestToken(t, s, userID, clientID)
162+
createTestToken(t, s, userID, clientID)
163+
164+
handler := NewSessionHandler(tokenSvc, nil)
165+
r := newSessionRouter(handler, userID)
166+
167+
w := httptest.NewRecorder()
168+
req, _ := http.NewRequest(http.MethodPost, "/account/sessions/revoke-all", nil)
169+
r.ServeHTTP(w, req)
170+
171+
assert.Equal(t, http.StatusFound, w.Code)
172+
173+
tokens, err := s.GetTokensByUserID(userID)
174+
require.NoError(t, err)
175+
assert.Empty(t, tokens)
176+
}

0 commit comments

Comments
 (0)