|
1 | 1 | package web |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "io" |
| 5 | + "net/http" |
| 6 | + "net/http/httptest" |
| 7 | + "regexp" |
| 8 | + "strings" |
4 | 9 | "testing" |
5 | 10 | "time" |
6 | 11 |
|
| 12 | + "github.com/gofiber/fiber/v2" |
| 13 | + "github.com/gofiber/fiber/v2/middleware/session" |
| 14 | + "github.com/gofiber/storage/memory/v2" |
7 | 15 | ldap "github.com/netresearch/simple-ldap-go" |
8 | 16 |
|
9 | 17 | "github.com/netresearch/ldap-manager/internal/options" |
@@ -177,3 +185,176 @@ func TestCSRFConfigurationAcceptsOpts(t *testing.T) { |
177 | 185 | // Handler created successfully - type is fiber.Handler (internal Fiber type) |
178 | 186 | t.Log("CSRF handler created successfully with opts parameter") |
179 | 187 | } |
| 188 | + |
| 189 | +// TestCSRFTokenValidation verifies that CSRF tokens are properly validated on POST requests. |
| 190 | +// This test ensures the CSRF expiration is set correctly (regression test for the 3600 nanoseconds bug). |
| 191 | +// |
| 192 | +//nolint:gocognit // Test function with multiple subtests has inherent complexity |
| 193 | +func TestCSRFTokenValidation(t *testing.T) { |
| 194 | + opts := &options.Opts{ |
| 195 | + LDAP: ldap.Config{ |
| 196 | + Server: "ldap://localhost:389", |
| 197 | + BaseDN: "dc=test,dc=local", |
| 198 | + IsActiveDirectory: false, |
| 199 | + }, |
| 200 | + ReadonlyUser: "cn=readonly,dc=test,dc=local", |
| 201 | + ReadonlyPassword: "password", |
| 202 | + CookieSecure: false, // HTTP for testing |
| 203 | + PersistSessions: false, |
| 204 | + SessionDuration: 30 * time.Minute, |
| 205 | + PoolMaxConnections: 10, |
| 206 | + PoolMinConnections: 2, |
| 207 | + PoolMaxIdleTime: 15 * time.Minute, |
| 208 | + PoolHealthCheckInterval: 30 * time.Second, |
| 209 | + PoolConnectionTimeout: 30 * time.Second, |
| 210 | + PoolAcquireTimeout: 10 * time.Second, |
| 211 | + } |
| 212 | + |
| 213 | + // Create a test Fiber app with CSRF middleware |
| 214 | + f := fiber.New() |
| 215 | + csrfHandler := createCSRFConfig(opts) |
| 216 | + sessionStore := session.New(session.Config{ |
| 217 | + Storage: memory.New(), |
| 218 | + }) |
| 219 | + |
| 220 | + // Test endpoint that returns CSRF token on GET and validates on POST |
| 221 | + f.All("/test-csrf", *csrfHandler, func(c *fiber.Ctx) error { |
| 222 | + sess, err := sessionStore.Get(c) |
| 223 | + if err != nil { |
| 224 | + return c.Status(fiber.StatusInternalServerError).SendString("Failed to get session") |
| 225 | + } |
| 226 | + defer func() { _ = sess.Save() }() |
| 227 | + |
| 228 | + if c.Method() == "GET" { |
| 229 | + token := c.Locals("token") |
| 230 | + if token == nil { |
| 231 | + return c.Status(fiber.StatusInternalServerError).SendString("No CSRF token generated") |
| 232 | + } |
| 233 | + |
| 234 | + tokenStr, ok := token.(string) |
| 235 | + if !ok { |
| 236 | + return c.Status(fiber.StatusInternalServerError).SendString("CSRF token is not a string") |
| 237 | + } |
| 238 | + |
| 239 | + return c.SendString("csrf_token:" + tokenStr) |
| 240 | + } |
| 241 | + // POST - if we get here, CSRF validation passed |
| 242 | + return c.SendString("CSRF validation passed") |
| 243 | + }) |
| 244 | + |
| 245 | + t.Run("GET request returns CSRF token", func(t *testing.T) { |
| 246 | + req := httptest.NewRequest("GET", "/test-csrf", nil) |
| 247 | + resp, err := f.Test(req) |
| 248 | + if err != nil { |
| 249 | + t.Fatalf("Request failed: %v", err) |
| 250 | + } |
| 251 | + defer func() { _ = resp.Body.Close() }() |
| 252 | + |
| 253 | + if resp.StatusCode != http.StatusOK { |
| 254 | + t.Errorf("Expected status %d, got %d", http.StatusOK, resp.StatusCode) |
| 255 | + } |
| 256 | + |
| 257 | + body, err := io.ReadAll(resp.Body) |
| 258 | + if err != nil { |
| 259 | + t.Fatalf("Failed to read response body: %v", err) |
| 260 | + } |
| 261 | + |
| 262 | + if !strings.HasPrefix(string(body), "csrf_token:") { |
| 263 | + t.Errorf("Expected CSRF token in response, got: %s", string(body)) |
| 264 | + } |
| 265 | + }) |
| 266 | + |
| 267 | + t.Run("POST without CSRF token returns 403 Forbidden", func(t *testing.T) { |
| 268 | + req := httptest.NewRequest("POST", "/test-csrf", strings.NewReader("data=test")) |
| 269 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 270 | + resp, err := f.Test(req) |
| 271 | + if err != nil { |
| 272 | + t.Fatalf("Request failed: %v", err) |
| 273 | + } |
| 274 | + defer func() { _ = resp.Body.Close() }() |
| 275 | + |
| 276 | + if resp.StatusCode != http.StatusForbidden { |
| 277 | + t.Errorf("Expected status %d for missing CSRF token, got %d", http.StatusForbidden, resp.StatusCode) |
| 278 | + } |
| 279 | + }) |
| 280 | + |
| 281 | + t.Run("POST with invalid CSRF token returns 403 Forbidden", func(t *testing.T) { |
| 282 | + req := httptest.NewRequest("POST", "/test-csrf", strings.NewReader("csrf_token=invalid-token&data=test")) |
| 283 | + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 284 | + resp, err := f.Test(req) |
| 285 | + if err != nil { |
| 286 | + t.Fatalf("Request failed: %v", err) |
| 287 | + } |
| 288 | + defer func() { _ = resp.Body.Close() }() |
| 289 | + |
| 290 | + if resp.StatusCode != http.StatusForbidden { |
| 291 | + t.Errorf("Expected status %d for invalid CSRF token, got %d", http.StatusForbidden, resp.StatusCode) |
| 292 | + } |
| 293 | + }) |
| 294 | + |
| 295 | + t.Run("POST with valid CSRF token succeeds", func(t *testing.T) { |
| 296 | + // Step 1: GET to obtain CSRF token and cookie |
| 297 | + getReq := httptest.NewRequest("GET", "/test-csrf", nil) |
| 298 | + getResp, err := f.Test(getReq) |
| 299 | + if err != nil { |
| 300 | + t.Fatalf("GET request failed: %v", err) |
| 301 | + } |
| 302 | + |
| 303 | + // Extract CSRF token from response body |
| 304 | + body, err := io.ReadAll(getResp.Body) |
| 305 | + if err != nil { |
| 306 | + t.Fatalf("Failed to read response body: %v", err) |
| 307 | + } |
| 308 | + _ = getResp.Body.Close() |
| 309 | + |
| 310 | + tokenMatch := regexp.MustCompile(`csrf_token:(.+)`).FindStringSubmatch(string(body)) |
| 311 | + if len(tokenMatch) < 2 { |
| 312 | + t.Fatalf("Could not extract CSRF token from response: %s", string(body)) |
| 313 | + } |
| 314 | + csrfToken := tokenMatch[1] |
| 315 | + |
| 316 | + // Extract CSRF cookie |
| 317 | + var csrfCookie *http.Cookie |
| 318 | + for _, cookie := range getResp.Cookies() { |
| 319 | + if strings.HasPrefix(cookie.Name, "csrf_") { |
| 320 | + csrfCookie = cookie |
| 321 | + |
| 322 | + break |
| 323 | + } |
| 324 | + } |
| 325 | + |
| 326 | + if csrfCookie == nil { |
| 327 | + t.Fatal("CSRF cookie not found in response") |
| 328 | + } |
| 329 | + |
| 330 | + // Step 2: POST with valid CSRF token and cookie |
| 331 | + postReq := httptest.NewRequest("POST", "/test-csrf", |
| 332 | + strings.NewReader("csrf_token="+csrfToken+"&data=test")) |
| 333 | + postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") |
| 334 | + postReq.AddCookie(csrfCookie) |
| 335 | + |
| 336 | + postResp, err := f.Test(postReq) |
| 337 | + if err != nil { |
| 338 | + t.Fatalf("POST request failed: %v", err) |
| 339 | + } |
| 340 | + defer func() { _ = postResp.Body.Close() }() |
| 341 | + |
| 342 | + // Read response body once for both assertions |
| 343 | + respBody, err := io.ReadAll(postResp.Body) |
| 344 | + if err != nil { |
| 345 | + t.Fatalf("Failed to read response body: %v", err) |
| 346 | + } |
| 347 | + |
| 348 | + // This is the critical test: with the bug (Expiration: 3600 nanoseconds), |
| 349 | + // the token would expire immediately and this would return 403. |
| 350 | + // With the fix (Expiration: time.Hour), this should return 200. |
| 351 | + if postResp.StatusCode != http.StatusOK { |
| 352 | + t.Errorf("Expected status %d for valid CSRF token, got %d. Response: %s", |
| 353 | + http.StatusOK, postResp.StatusCode, string(respBody)) |
| 354 | + } |
| 355 | + |
| 356 | + if string(respBody) != "CSRF validation passed" { |
| 357 | + t.Errorf("Expected 'CSRF validation passed', got: %s", string(respBody)) |
| 358 | + } |
| 359 | + }) |
| 360 | +} |
0 commit comments