diff --git a/README.md b/README.md index 2b77659..88f064c 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,68 @@ sudo sysctl -w net.ipv4.ip_forward=1 sudo sysctl -w net.ipv4.conf.all.rp_filter=2 ``` +### Performance Tuning (Optional) + +For maximum throughput, especially on high-bandwidth servers with many peers, apply these additional sysctl settings. Create a file `/etc/sysctl.d/99-vprox-performance.conf` or apply them temporarily: + +```bash +# UDP/Socket Buffer Sizes (WireGuard uses UDP) +# Increase max buffer sizes to 25MB for high-throughput scenarios +sudo sysctl -w net.core.rmem_max=26214400 +sudo sysctl -w net.core.wmem_max=26214400 +sudo sysctl -w net.core.rmem_default=1048576 +sudo sysctl -w net.core.wmem_default=1048576 + +# Network device backlog (for high packet rates) +# Increase backlog to handle traffic bursts +sudo sysctl -w net.core.netdev_max_backlog=50000 +sudo sysctl -w net.core.netdev_budget=600 + +# TCP tuning (for traffic inside the tunnel) +# Format: min default max +sudo sysctl -w net.ipv4.tcp_rmem="4096 1048576 26214400" +sudo sysctl -w net.ipv4.tcp_wmem="4096 1048576 26214400" + +# Use BBR congestion control (better than cubic for most scenarios) +sudo sysctl -w net.ipv4.tcp_congestion_control=bbr + +# Enable TCP Fast Open for reduced latency on reconnects +sudo sysctl -w net.ipv4.tcp_fastopen=3 + +# Connection tracking limits (critical for NAT with many peers) +# Increase max tracked connections to 1M +sudo sysctl -w net.netfilter.nf_conntrack_max=1048576 + +# Optional: Busy polling for lower latency (increases CPU usage) +# sudo sysctl -w net.core.busy_poll=50 +# sudo sysctl -w net.core.busy_read=50 +``` + +To make these settings persistent across reboots, add them to `/etc/sysctl.d/99-vprox-performance.conf` without the `sudo sysctl -w` prefix: + +``` +# /etc/sysctl.d/99-vprox-performance.conf +net.core.rmem_max=26214400 +net.core.wmem_max=26214400 +net.core.rmem_default=1048576 +net.core.wmem_default=1048576 +net.core.netdev_max_backlog=50000 +net.core.netdev_budget=600 +net.ipv4.tcp_rmem=4096 1048576 26214400 +net.ipv4.tcp_wmem=4096 1048576 26214400 +net.ipv4.tcp_congestion_control=bbr +net.ipv4.tcp_fastopen=3 +net.netfilter.nf_conntrack_max=1048576 +``` + +Then apply with `sudo sysctl --system`. + To set up `vprox`, you'll need the private IPv4 address of the server connected to an Internet gateway (use the `ip addr` command), as well as a block of IPs to allocate to the WireGuard subnet between server and client. This has no particular meaning and can be arbitrarily chosen to not overlap with other subnets. +#### Password Authentication (default) + +The default authentication mode uses a shared password via the `VPROX_PASSWORD` environment variable: + ```bash # [Machine A: public IP 1.2.3.4, private IP 172.31.64.125] VPROX_PASSWORD=my-password vprox server --ip 172.31.64.125 --wg-block 240.1.0.0/16 @@ -38,6 +98,39 @@ curl ifconfig.me # => 5.6.7.8 curl --interface vprox0 ifconfig.me # => 1.2.3.4 ``` +#### OIDC Authentication (Modal) + +vprox supports OIDC token-based authentication, designed for use with [Modal's OIDC integration](https://modal.com/docs/guide/oidc-integration). In this mode, the server verifies JWT identity tokens signed by Modal instead of using a shared password. + +**Server setup:** + +```bash +# Start the server in oidc-modal mode, restricting access to a specific Modal workspace +VPROX_AUTH_MODE=oidc-modal \ +VPROX_OIDC_ISSUER=https://oidc.modal.com \ +VPROX_OIDC_ALLOWED_WORKSPACE_IDS=ws-abc123 \ + vprox server --ip 172.31.64.125 --wg-block 240.1.0.0/16 +``` + +**Client setup (inside a Modal container):** + +```bash +# Modal sets MODAL_IDENTITY_TOKEN automatically; pass it to vprox via VPROX_OIDC_TOKEN +VPROX_AUTH_MODE=oidc-modal \ +VPROX_OIDC_TOKEN="$MODAL_IDENTITY_TOKEN" \ + vprox connect 1.2.3.4 --interface vprox0 +``` + +**OIDC environment variables:** + +| Variable | Description | Default | +|---|---|---| +| `VPROX_AUTH_MODE` | Auth mode: `password` or `oidc-modal` | `password` | +| `VPROX_OIDC_TOKEN` | OIDC identity token (JWT) for client authentication | — | +| `VPROX_OIDC_ISSUER` | OIDC issuer URL | `https://oidc.modal.com` | +| `VPROX_OIDC_AUDIENCE` | Expected `aud` claim (skip check if empty) | _(empty)_ | +| `VPROX_OIDC_ALLOWED_WORKSPACE_IDS` | Comma-separated list of allowed Modal workspace IDs. Set to `*` to explicitly allow all workspaces (**testing only**). | _(any)_ | + Note that Machine B must be able to send UDP packets to port 50227 on Machine A, and TCP to port 443. All outbound network traffic seen by `vprox0` will automatically be forwarded through the WireGuard tunnel. The VPN server masquerades the source IP address. @@ -68,6 +161,9 @@ On AWS in particular, the `--cloud aws` option allows you to automatically disco - Automatic discovery of IPs using instance metadata endpoints (AWS) - Only one vprox server may be running on a host - Control traffic is encrypted with TLS (Warning: does not verify server certificate) +- Optimized for throughput with automatic MTU, MSS, GSO/GRO, and multi-queue configuration +- Connection tracking bypass (NOTRACK) for reduced CPU overhead on WireGuard UDP flows +- OIDC authentication for passwordless auth from Modal containers (`oidc-modal`) ## Authors diff --git a/cmd/connect.go b/cmd/connect.go index 0cf7f12..4f22ebe 100644 --- a/cmd/connect.go +++ b/cmd/connect.go @@ -87,7 +87,7 @@ func runConnect(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to load server key: %v", err) } - password, err := lib.GetVproxPassword() + token, err := lib.GetClientToken() if err != nil { return err } @@ -101,7 +101,7 @@ func runConnect(cmd *cobra.Command, args []string) error { Key: key, Ifname: connectCmdArgs.ifname, ServerIp: serverIp, - Password: password, + Token: token, WgClient: wgClient, Http: &http.Client{ Timeout: 5 * time.Second, diff --git a/cmd/server.go b/cmd/server.go index 47c366b..5b48877 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -93,14 +93,14 @@ func runServer(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to load server key: %v", err) } - password, err := lib.GetVproxPassword() + auth, err := lib.GetAuthenticator() if err != nil { return err } ctx, done := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) - sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, password, serverCmdArgs.takeover) + sm, err := lib.NewServerManager(wgBlock, wgBlockPerIp, ctx, key, auth, serverCmdArgs.takeover) if err != nil { done() return err diff --git a/lib/auth.go b/lib/auth.go new file mode 100644 index 0000000..d17493e --- /dev/null +++ b/lib/auth.go @@ -0,0 +1,424 @@ +package lib + +import ( + "crypto" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "math/big" + "net/http" + "strings" + "sync" + "time" +) + +// AuthMode determines how the server authenticates incoming requests. +type AuthMode string + +const ( + AuthModePassword AuthMode = "password" + AuthModeOIDCModal AuthMode = "oidc-modal" +) + +// OIDCConfig holds configuration for OIDC-based authentication. +type OIDCConfig struct { + // IssuerURL is the Modal OIDC issuer URL (e.g. "https://oidc.modal.com"). + IssuerURL string + + // Audience is the expected "aud" claim in the token. If empty, audience is not checked. + Audience string + + // AllowedWorkspaceIDs is a list of Modal workspace IDs that are allowed to authenticate. + // If empty, any workspace is allowed (only issuer/signature are checked). + AllowedWorkspaceIDs []string +} + +// Authenticator provides request authentication for the vprox server. +type Authenticator struct { + mode AuthMode + password string + oidc *OIDCConfig + jwks *JWKSCache +} + +// NewPasswordAuthenticator creates an Authenticator that uses password-based auth. +func NewPasswordAuthenticator(password string) *Authenticator { + return &Authenticator{ + mode: AuthModePassword, + password: password, + } +} + +// NewOIDCModalAuthenticator creates an Authenticator that validates Modal OIDC tokens. +func NewOIDCModalAuthenticator(config *OIDCConfig) (*Authenticator, error) { + if config.IssuerURL == "" { + return nil, errors.New("OIDC issuer URL is required") + } + // Strip trailing slash for consistency. + config.IssuerURL = strings.TrimRight(config.IssuerURL, "/") + + jwksURL, err := discoverJWKSURL(config.IssuerURL) + if err != nil { + return nil, fmt.Errorf("failed to discover JWKS URL from issuer %s: %v", config.IssuerURL, err) + } + + return &Authenticator{ + mode: AuthModeOIDCModal, + oidc: config, + jwks: NewJWKSCache(jwksURL), + }, nil +} + +// Authenticate checks the Authorization header of an HTTP request. +// Returns nil on success, or an error describing the failure. +func (a *Authenticator) Authenticate(r *http.Request) error { + auth := r.Header.Get("Authorization") + if !strings.HasPrefix(auth, "Bearer ") { + return errors.New("missing or malformed Authorization header") + } + token := strings.TrimPrefix(auth, "Bearer ") + + switch a.mode { + case AuthModePassword: + if token != a.password { + return errors.New("invalid password") + } + return nil + + case AuthModeOIDCModal: + return a.verifyOIDCToken(token) + + default: + return fmt.Errorf("unknown auth mode: %s", a.mode) + } +} + +// Mode returns the authentication mode. +func (a *Authenticator) Mode() AuthMode { + return a.mode +} + +// verifyOIDCToken verifies a JWT token against the OIDC provider's JWKS. +func (a *Authenticator) verifyOIDCToken(tokenStr string) error { + // Parse the JWT without verification first to get the header. + parts := strings.Split(tokenStr, ".") + if len(parts) != 3 { + return errors.New("invalid JWT: expected 3 parts") + } + + // Decode the header to get the key ID. + headerBytes, err := base64URLDecode(parts[0]) + if err != nil { + return fmt.Errorf("invalid JWT header encoding: %v", err) + } + + var header jwtHeader + if err := json.Unmarshal(headerBytes, &header); err != nil { + return fmt.Errorf("invalid JWT header: %v", err) + } + + if header.Alg != "RS256" { + return fmt.Errorf("unsupported JWT algorithm: %s", header.Alg) + } + + // Decode the payload. + payloadBytes, err := base64URLDecode(parts[1]) + if err != nil { + return fmt.Errorf("invalid JWT payload encoding: %v", err) + } + + var claims ModalClaims + if err := json.Unmarshal(payloadBytes, &claims); err != nil { + return fmt.Errorf("invalid JWT payload: %v", err) + } + + // Verify the signature using the JWKS. + sigBytes, err := base64URLDecode(parts[2]) + if err != nil { + return fmt.Errorf("invalid JWT signature encoding: %v", err) + } + + signedContent := parts[0] + "." + parts[1] + if err := a.jwks.VerifyRS256(header.Kid, []byte(signedContent), sigBytes); err != nil { + return fmt.Errorf("JWT signature verification failed: %v", err) + } + + // Verify standard claims. + now := time.Now().Unix() + + if claims.Exp != 0 && now > claims.Exp { + return fmt.Errorf("token expired at %d, current time is %d", claims.Exp, now) + } + + // Allow 60 seconds of clock skew for iat. + if claims.Iat != 0 && now < claims.Iat-60 { + return fmt.Errorf("token issued in the future: iat=%d, now=%d", claims.Iat, now) + } + + if claims.Iss != a.oidc.IssuerURL { + return fmt.Errorf("issuer mismatch: got %q, expected %q", claims.Iss, a.oidc.IssuerURL) + } + + if a.oidc.Audience != "" && claims.Aud != a.oidc.Audience { + return fmt.Errorf("audience mismatch: got %q, expected %q", claims.Aud, a.oidc.Audience) + } + + // Verify Modal workspace claim. + if len(a.oidc.AllowedWorkspaceIDs) > 0 { + if !stringInSlice(claims.WorkspaceID, a.oidc.AllowedWorkspaceIDs) { + return fmt.Errorf("workspace %q is not in the allowed list", claims.WorkspaceID) + } + } + + return nil +} + +// ModalClaims represents the claims in a Modal OIDC identity token. +type ModalClaims struct { + // Standard OIDC claims + Sub string `json:"sub"` + Aud string `json:"aud"` + Exp int64 `json:"exp"` + Iat int64 `json:"iat"` + Iss string `json:"iss"` + Jti string `json:"jti"` + + // Modal-specific claims + WorkspaceID string `json:"workspace_id"` + EnvironmentID string `json:"environment_id"` + EnvironmentName string `json:"environment_name"` + AppID string `json:"app_id"` + AppName string `json:"app_name"` + FunctionID string `json:"function_id"` + FunctionName string `json:"function_name"` + ContainerID string `json:"container_id"` +} + +type jwtHeader struct { + Alg string `json:"alg"` + Kid string `json:"kid"` + Typ string `json:"typ"` +} + +// --- JWKS Cache --- + +// JWKSCache fetches and caches JWKS keys from a remote endpoint. +type JWKSCache struct { + url string + mu sync.RWMutex + keys map[string]*rsa.PublicKey + lastFetch time.Time + httpClient *http.Client +} + +const jwksCacheDuration = 5 * time.Minute + +// NewJWKSCache creates a new JWKS cache for the given URL. +func NewJWKSCache(url string) *JWKSCache { + return &JWKSCache{ + url: url, + keys: make(map[string]*rsa.PublicKey), + httpClient: &http.Client{ + Timeout: 10 * time.Second, + }, + } +} + +// VerifyRS256 verifies an RS256 signature using the cached JWKS keys. +func (c *JWKSCache) VerifyRS256(kid string, message, signature []byte) error { + key, err := c.getKey(kid) + if err != nil { + return err + } + + return verifyRS256Signature(key, message, signature) +} + +// getKey returns the RSA public key for the given key ID, fetching from the +// JWKS endpoint if necessary. +func (c *JWKSCache) getKey(kid string) (*rsa.PublicKey, error) { + // Try to find the key in the cache first. + c.mu.RLock() + key, ok := c.keys[kid] + cacheValid := time.Since(c.lastFetch) < jwksCacheDuration + c.mu.RUnlock() + + if ok && cacheValid { + return key, nil + } + + // If the key is not found or the cache is stale, refresh. + if err := c.refresh(); err != nil { + return nil, fmt.Errorf("failed to refresh JWKS: %v", err) + } + + c.mu.RLock() + key, ok = c.keys[kid] + c.mu.RUnlock() + + if !ok { + return nil, fmt.Errorf("key %q not found in JWKS", kid) + } + + return key, nil +} + +// refresh fetches the JWKS from the remote endpoint and updates the cache. +func (c *JWKSCache) refresh() error { + resp, err := c.httpClient.Get(c.url) + if err != nil { + return fmt.Errorf("failed to fetch JWKS from %s: %v", c.url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("JWKS endpoint returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read JWKS response: %v", err) + } + + var jwks jwksResponse + if err := json.Unmarshal(body, &jwks); err != nil { + return fmt.Errorf("failed to parse JWKS: %v", err) + } + + keys := make(map[string]*rsa.PublicKey) + for _, jwk := range jwks.Keys { + if jwk.Kty != "RSA" { + continue + } + if jwk.Use != "" && jwk.Use != "sig" { + continue + } + + pubKey, err := jwkToRSAPublicKey(jwk) + if err != nil { + continue // skip malformed keys + } + + keys[jwk.Kid] = pubKey + } + + c.mu.Lock() + c.keys = keys + c.lastFetch = time.Now() + c.mu.Unlock() + + return nil +} + +type jwksResponse struct { + Keys []jwkKey `json:"keys"` +} + +type jwkKey struct { + Kty string `json:"kty"` + Use string `json:"use"` + Kid string `json:"kid"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` +} + +// --- OIDC Discovery --- + +type oidcDiscovery struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` +} + +// discoverJWKSURL fetches the OIDC discovery document and returns the JWKS URL. +func discoverJWKSURL(issuerURL string) (string, error) { + discoveryURL := issuerURL + "/.well-known/openid-configuration" + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Get(discoveryURL) + if err != nil { + return "", fmt.Errorf("failed to fetch OIDC discovery from %s: %v", discoveryURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("OIDC discovery endpoint returned status %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read OIDC discovery response: %v", err) + } + + var discovery oidcDiscovery + if err := json.Unmarshal(body, &discovery); err != nil { + return "", fmt.Errorf("failed to parse OIDC discovery document: %v", err) + } + + if discovery.JWKSURI == "" { + return "", errors.New("OIDC discovery document missing jwks_uri") + } + + return discovery.JWKSURI, nil +} + +// --- Crypto helpers --- + +// jwkToRSAPublicKey converts a JWK to an RSA public key. +func jwkToRSAPublicKey(jwk jwkKey) (*rsa.PublicKey, error) { + nBytes, err := base64URLDecode(jwk.N) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK modulus: %v", err) + } + + eBytes, err := base64URLDecode(jwk.E) + if err != nil { + return nil, fmt.Errorf("failed to decode JWK exponent: %v", err) + } + + n := new(big.Int).SetBytes(nBytes) + e := new(big.Int).SetBytes(eBytes) + + if !e.IsInt64() { + return nil, errors.New("JWK exponent too large") + } + + return &rsa.PublicKey{ + N: n, + E: int(e.Int64()), + }, nil +} + +// verifyRS256Signature verifies an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) signature. +func verifyRS256Signature(pubKey *rsa.PublicKey, message, signature []byte) error { + // RS256 = RSASSA-PKCS1-v1_5 using SHA-256 + h := sha256.Sum256(message) + return rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, h[:], signature) +} + +// base64URLDecode decodes a base64url-encoded string (with or without padding). +func base64URLDecode(s string) ([]byte, error) { + // Add padding if needed. + switch len(s) % 4 { + case 2: + s += "==" + case 3: + s += "=" + } + return base64.URLEncoding.DecodeString(s) +} + +// --- Utility --- + +func stringInSlice(s string, slice []string) bool { + for _, v := range slice { + if v == s { + return true + } + } + return false +} diff --git a/lib/auth_test.go b/lib/auth_test.go new file mode 100644 index 0000000..52fd6c4 --- /dev/null +++ b/lib/auth_test.go @@ -0,0 +1,862 @@ +package lib + +import ( + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "math/big" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// --- helpers --- + +// testKeyPair generates a fresh RSA key pair for testing. +func testKeyPair(t *testing.T) *rsa.PrivateKey { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return key +} + +// b64url encodes bytes as unpadded base64url. +func b64url(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} + +// signJWT creates a signed RS256 JWT from the given header, claims, and private key. +func signJWT(t *testing.T, header jwtHeader, claims ModalClaims, key *rsa.PrivateKey) string { + t.Helper() + + hdrJSON, err := json.Marshal(header) + require.NoError(t, err) + claimsJSON, err := json.Marshal(claims) + require.NoError(t, err) + + payload := b64url(hdrJSON) + "." + b64url(claimsJSON) + h := sha256.Sum256([]byte(payload)) + sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, h[:]) + require.NoError(t, err) + + return payload + "." + b64url(sig) +} + +// serveJWKS starts a test HTTP server that serves a JWKS endpoint and an OIDC +// discovery endpoint for the given public keys. Returns the server and the +// issuer URL to use. +func serveJWKS(t *testing.T, keys map[string]*rsa.PublicKey) (*httptest.Server, string) { + t.Helper() + + jwksKeys := make([]jwkKey, 0, len(keys)) + for kid, pub := range keys { + jwksKeys = append(jwksKeys, jwkKey{ + Kty: "RSA", + Use: "sig", + Kid: kid, + Alg: "RS256", + N: b64url(pub.N.Bytes()), + E: b64url(big.NewInt(int64(pub.E)).Bytes()), + }) + } + + jwksResp := jwksResponse{Keys: jwksKeys} + jwksBytes, err := json.Marshal(jwksResp) + require.NoError(t, err) + + // We need a mux so we can serve both discovery and JWKS. + mux := http.NewServeMux() + + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mux.ServeHTTP(w, r) + })) + + issuer := srv.URL + + discovery := oidcDiscovery{ + Issuer: issuer, + JWKSURI: issuer + "/jwks", + } + discoveryBytes, err := json.Marshal(discovery) + require.NoError(t, err) + + mux.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(discoveryBytes) + }) + mux.HandleFunc("/jwks", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(jwksBytes) + }) + + t.Cleanup(srv.Close) + return srv, issuer +} + +// buildOIDCAuthenticator sets up a full OIDC authenticator backed by a test +// JWKS server. Returns the authenticator, the issuer URL, and the private key. +func buildOIDCAuthenticator(t *testing.T, configFn func(cfg *OIDCConfig)) (*Authenticator, string, *rsa.PrivateKey) { + t.Helper() + + priv := testKeyPair(t) + pub := &priv.PublicKey + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"test-key-1": pub}) + + cfg := &OIDCConfig{ + IssuerURL: issuer, + } + if configFn != nil { + configFn(cfg) + } + + auth, err := NewOIDCModalAuthenticator(cfg) + require.NoError(t, err) + return auth, issuer, priv +} + +func defaultHeader() jwtHeader { + return jwtHeader{Alg: "RS256", Kid: "test-key-1", Typ: "JWT"} +} + +func defaultClaims(issuer string) ModalClaims { + now := time.Now().Unix() + return ModalClaims{ + Sub: "ws-abc123:main:my-app:my-func", + Aud: "", + Exp: now + 3600, + Iat: now, + Iss: issuer, + Jti: "jti-random", + WorkspaceID: "ws-abc123", + EnvironmentID: "env-def456", + EnvironmentName: "main", + AppID: "app-ghi789", + AppName: "my-app", + FunctionID: "fn-jkl012", + FunctionName: "my-func", + ContainerID: "ctr-mno345", + } +} + +// --- Password auth tests --- + +func TestPasswordAuth_Success(t *testing.T) { + auth := NewPasswordAuthenticator("secret-pass") + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer secret-pass") + + err := auth.Authenticate(req) + assert.NoError(t, err) +} + +func TestPasswordAuth_WrongPassword(t *testing.T) { + auth := NewPasswordAuthenticator("secret-pass") + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer wrong-pass") + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid password") +} + +func TestPasswordAuth_MissingHeader(t *testing.T) { + auth := NewPasswordAuthenticator("secret-pass") + + req := httptest.NewRequest("GET", "/connect", nil) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing or malformed") +} + +func TestPasswordAuth_BasicScheme(t *testing.T) { + auth := NewPasswordAuthenticator("secret-pass") + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Basic dXNlcjpwYXNz") + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing or malformed") +} + +func TestPasswordAuth_Mode(t *testing.T) { + auth := NewPasswordAuthenticator("pw") + assert.Equal(t, AuthModePassword, auth.Mode()) +} + +// --- OIDC auth tests --- + +func TestOIDCAuth_ValidToken(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + token := signJWT(t, defaultHeader(), defaultClaims(issuer), priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.NoError(t, err) +} + +func TestOIDCAuth_ExpiredToken(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + claims := defaultClaims(issuer) + claims.Exp = time.Now().Unix() - 100 // expired 100 seconds ago + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token expired") +} + +func TestOIDCAuth_FutureIat(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + claims := defaultClaims(issuer) + claims.Iat = time.Now().Unix() + 3600 // issued 1 hour in the future + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "issued in the future") +} + +func TestOIDCAuth_WrongIssuer(t *testing.T) { + auth, _, priv := buildOIDCAuthenticator(t, nil) + + claims := defaultClaims("https://evil.example.com") + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "issuer mismatch") +} + +func TestOIDCAuth_AudienceCheck(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, func(cfg *OIDCConfig) { + cfg.Audience = "my-service" + }) + + // Token without audience should fail. + claims := defaultClaims(issuer) + claims.Aud = "" + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "audience mismatch") + + // Token with correct audience should succeed. + claims.Aud = "my-service" + token = signJWT(t, defaultHeader(), claims, priv) + req = httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err = auth.Authenticate(req) + assert.NoError(t, err) + + // Token with wrong audience should fail. + claims.Aud = "other-service" + token = signJWT(t, defaultHeader(), claims, priv) + req = httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err = auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "audience mismatch") +} + +func TestOIDCAuth_AllowedWorkspaceIDs(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, func(cfg *OIDCConfig) { + cfg.AllowedWorkspaceIDs = []string{"ws-allowed1", "ws-allowed2"} + }) + + // Allowed workspace. + claims := defaultClaims(issuer) + claims.WorkspaceID = "ws-allowed1" + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + assert.NoError(t, auth.Authenticate(req)) + + // Also allowed. + claims.WorkspaceID = "ws-allowed2" + token = signJWT(t, defaultHeader(), claims, priv) + req = httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + assert.NoError(t, auth.Authenticate(req)) + + // Disallowed workspace. + claims.WorkspaceID = "ws-evil" + token = signJWT(t, defaultHeader(), claims, priv) + req = httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "workspace") + assert.Contains(t, err.Error(), "not in the allowed list") +} + +func TestOIDCAuth_WrongSignature(t *testing.T) { + auth, issuer, _ := buildOIDCAuthenticator(t, nil) + + // Sign with a different key that the JWKS server doesn't know about. + otherKey := testKeyPair(t) + token := signJWT(t, defaultHeader(), defaultClaims(issuer), otherKey) + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "signature verification failed") +} + +func TestOIDCAuth_UnknownKid(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + header := defaultHeader() + header.Kid = "unknown-key-id" + + token := signJWT(t, header, defaultClaims(issuer), priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found in JWKS") +} + +func TestOIDCAuth_UnsupportedAlgorithm(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + header := defaultHeader() + header.Alg = "HS256" + + token := signJWT(t, header, defaultClaims(issuer), priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported JWT algorithm") +} + +func TestOIDCAuth_MalformedToken(t *testing.T) { + auth, _, _ := buildOIDCAuthenticator(t, nil) + + tests := []struct { + name string + token string + }{ + {"empty", ""}, + {"one part", "abc"}, + {"two parts", "abc.def"}, + {"four parts", "a.b.c.d"}, + {"garbage", "not-a-jwt-at-all!!!"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+tc.token) + err := auth.Authenticate(req) + assert.Error(t, err) + }) + } +} + +func TestOIDCAuth_MissingHeader(t *testing.T) { + auth, _, _ := buildOIDCAuthenticator(t, nil) + + req := httptest.NewRequest("GET", "/connect", nil) + err := auth.Authenticate(req) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing or malformed") +} + +func TestOIDCAuth_Mode(t *testing.T) { + auth, _, _ := buildOIDCAuthenticator(t, nil) + assert.Equal(t, AuthModeOIDCModal, auth.Mode()) +} + +func TestOIDCAuth_NoAudienceCheck_WhenNotConfigured(t *testing.T) { + // When audience is empty in config, any audience should be accepted. + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + claims := defaultClaims(issuer) + claims.Aud = "anything-goes" + token := signJWT(t, defaultHeader(), claims, priv) + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + assert.NoError(t, auth.Authenticate(req)) +} + +func TestOIDCAuth_NoWorkspaceCheck_WhenNotConfigured(t *testing.T) { + // When no workspace IDs are configured, any workspace should be accepted. + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + claims := defaultClaims(issuer) + claims.WorkspaceID = "ws-any-workspace" + token := signJWT(t, defaultHeader(), claims, priv) + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + assert.NoError(t, auth.Authenticate(req)) +} + +// --- base64URLDecode tests --- + +func TestBase64URLDecode(t *testing.T) { + tests := []struct { + input string + expected string + }{ + // No padding needed (len % 4 == 0) + {"aGVsbG8gd29ybGQh", "hello world!"}, + // 2 chars padding needed (len % 4 == 2) + {"YQ", "a"}, + // 1 char padding needed (len % 4 == 3) + {"YWI", "ab"}, + // Already padded + {"YQ==", "a"}, + {"YWI=", "ab"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + result, err := base64URLDecode(tc.input) + require.NoError(t, err) + assert.Equal(t, tc.expected, string(result)) + }) + } +} + +// --- stringInSlice tests --- + +func TestStringInSlice(t *testing.T) { + assert.True(t, stringInSlice("a", []string{"a", "b", "c"})) + assert.True(t, stringInSlice("c", []string{"a", "b", "c"})) + assert.False(t, stringInSlice("d", []string{"a", "b", "c"})) + assert.False(t, stringInSlice("a", []string{})) + assert.False(t, stringInSlice("", []string{})) + assert.True(t, stringInSlice("", []string{""})) +} + +// --- JWK conversion tests --- + +func TestJWKToRSAPublicKey(t *testing.T) { + priv := testKeyPair(t) + pub := &priv.PublicKey + + jwk := jwkKey{ + Kty: "RSA", + N: b64url(pub.N.Bytes()), + E: b64url(big.NewInt(int64(pub.E)).Bytes()), + } + + result, err := jwkToRSAPublicKey(jwk) + require.NoError(t, err) + assert.Equal(t, pub.N.Cmp(result.N), 0) + assert.Equal(t, pub.E, result.E) +} + +func TestJWKToRSAPublicKey_BadModulus(t *testing.T) { + _, err := jwkToRSAPublicKey(jwkKey{ + Kty: "RSA", + N: "!!!invalid-base64!!!", + E: b64url(big.NewInt(65537).Bytes()), + }) + assert.Error(t, err) +} + +func TestJWKToRSAPublicKey_BadExponent(t *testing.T) { + priv := testKeyPair(t) + _, err := jwkToRSAPublicKey(jwkKey{ + Kty: "RSA", + N: b64url(priv.PublicKey.N.Bytes()), + E: "!!!invalid!!!", + }) + assert.Error(t, err) +} + +// --- JWKS Cache tests --- + +func TestJWKSCache_FetchesKeys(t *testing.T) { + priv := testKeyPair(t) + pub := &priv.PublicKey + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"k1": pub}) + + cache := NewJWKSCache(issuer + "/jwks") + + // Sign some data and verify. + msg := []byte("hello, world") + h := sha256.Sum256(msg) + sig, err := rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, h[:]) + require.NoError(t, err) + + err = cache.VerifyRS256("k1", msg, sig) + assert.NoError(t, err) +} + +func TestJWKSCache_UnknownKid(t *testing.T) { + priv := testKeyPair(t) + pub := &priv.PublicKey + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"k1": pub}) + + cache := NewJWKSCache(issuer + "/jwks") + + msg := []byte("hello") + h := sha256.Sum256(msg) + sig, err := rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, h[:]) + require.NoError(t, err) + + err = cache.VerifyRS256("nonexistent", msg, sig) + assert.Error(t, err) + assert.Contains(t, err.Error(), "not found in JWKS") +} + +func TestJWKSCache_InvalidSignature(t *testing.T) { + priv := testKeyPair(t) + pub := &priv.PublicKey + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"k1": pub}) + + cache := NewJWKSCache(issuer + "/jwks") + + err := cache.VerifyRS256("k1", []byte("hello"), []byte("bad-signature")) + assert.Error(t, err) +} + +func TestJWKSCache_MultipleKeys(t *testing.T) { + priv1 := testKeyPair(t) + priv2 := testKeyPair(t) + + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{ + "key-a": &priv1.PublicKey, + "key-b": &priv2.PublicKey, + }) + + cache := NewJWKSCache(issuer + "/jwks") + + msg := []byte("test message") + h := sha256.Sum256(msg) + + sig1, err := rsa.SignPKCS1v15(rand.Reader, priv1, crypto.SHA256, h[:]) + require.NoError(t, err) + sig2, err := rsa.SignPKCS1v15(rand.Reader, priv2, crypto.SHA256, h[:]) + require.NoError(t, err) + + assert.NoError(t, cache.VerifyRS256("key-a", msg, sig1)) + assert.NoError(t, cache.VerifyRS256("key-b", msg, sig2)) + + // Cross-verification should fail. + assert.Error(t, cache.VerifyRS256("key-a", msg, sig2)) + assert.Error(t, cache.VerifyRS256("key-b", msg, sig1)) +} + +// --- OIDC Discovery tests --- + +func TestDiscoverJWKSURL(t *testing.T) { + expected := "https://oidc.example.com/jwks" + discovery := oidcDiscovery{ + Issuer: "https://oidc.example.com", + JWKSURI: expected, + } + body, _ := json.Marshal(discovery) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/.well-known/openid-configuration" { + w.Header().Set("Content-Type", "application/json") + w.Write(body) + } else { + http.NotFound(w, r) + } + })) + defer srv.Close() + + url, err := discoverJWKSURL(srv.URL) + require.NoError(t, err) + assert.Equal(t, expected, url) +} + +func TestDiscoverJWKSURL_MissingJWKSURI(t *testing.T) { + discovery := oidcDiscovery{Issuer: "https://oidc.example.com"} + body, _ := json.Marshal(discovery) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write(body) + })) + defer srv.Close() + + _, err := discoverJWKSURL(srv.URL) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing jwks_uri") +} + +func TestDiscoverJWKSURL_ServerDown(t *testing.T) { + _, err := discoverJWKSURL("http://127.0.0.1:1") // nothing listening + assert.Error(t, err) +} + +// --- NewOIDCModalAuthenticator error cases --- + +func TestNewOIDCModalAuthenticator_EmptyIssuer(t *testing.T) { + _, err := NewOIDCModalAuthenticator(&OIDCConfig{IssuerURL: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "issuer URL is required") +} + +func TestNewOIDCModalAuthenticator_BadIssuer(t *testing.T) { + _, err := NewOIDCModalAuthenticator(&OIDCConfig{IssuerURL: "http://127.0.0.1:1"}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to discover JWKS") +} + +func TestNewOIDCModalAuthenticator_TrailingSlashNormalized(t *testing.T) { + priv := testKeyPair(t) + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"k": &priv.PublicKey}) + + // Pass issuer with trailing slash — it should be trimmed. + auth, err := NewOIDCModalAuthenticator(&OIDCConfig{IssuerURL: issuer + "/"}) + require.NoError(t, err) + + claims := defaultClaims(issuer) // claims use the issuer without trailing slash + token := signJWT(t, defaultHeader(), claims, priv) + + // Hack: the JWKS cache was set up via the server, and the kid we use + // is "test-key-1" in our default header, but the JWKS server above only + // has kid "k". So use kid "k" here. + header := jwtHeader{Alg: "RS256", Kid: "k", Typ: "JWT"} + token = signJWT(t, header, claims, priv) + + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err = auth.Authenticate(req) + assert.NoError(t, err) +} + +// --- Integration-style: full round-trip with all claim checks --- + +func TestOIDCAuth_FullClaimValidation(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, func(cfg *OIDCConfig) { + cfg.Audience = "vprox-server" + cfg.AllowedWorkspaceIDs = []string{"ws-prod"} + }) + + claims := defaultClaims(issuer) + claims.Aud = "vprox-server" + claims.WorkspaceID = "ws-prod" + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("POST", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.NoError(t, err) +} + +func TestOIDCAuth_FullClaimValidation_FailEach(t *testing.T) { + makeAuth := func(t *testing.T) (*Authenticator, string, *rsa.PrivateKey) { + return buildOIDCAuthenticator(t, func(cfg *OIDCConfig) { + cfg.Audience = "vprox-server" + cfg.AllowedWorkspaceIDs = []string{"ws-prod"} + }) + } + + tests := []struct { + name string + mutate func(c *ModalClaims) + errPart string + }{ + {"wrong audience", func(c *ModalClaims) { c.Aud = "other" }, "audience mismatch"}, + {"wrong workspace", func(c *ModalClaims) { c.WorkspaceID = "ws-other" }, "workspace"}, + {"expired", func(c *ModalClaims) { c.Exp = time.Now().Unix() - 10 }, "expired"}, + {"wrong issuer", func(c *ModalClaims) { c.Iss = "https://evil.example.com" }, "issuer mismatch"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + auth, issuer, priv := makeAuth(t) + + claims := defaultClaims(issuer) + claims.Aud = "vprox-server" + claims.WorkspaceID = "ws-prod" + tc.mutate(&claims) + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("POST", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.Error(t, err, "expected error for case: %s", tc.name) + assert.Contains(t, err.Error(), tc.errPart) + }) + } +} + +// --- splitCSV tests (from env.go) --- + +func TestSplitCSV(t *testing.T) { + tests := []struct { + input string + expected []string + }{ + {"", nil}, + {"a,b,c", []string{"a", "b", "c"}}, + {" a , b , c ", []string{"a", "b", "c"}}, + {"single", []string{"single"}}, + {"a,,b", []string{"a", "b"}}, + {",,,", nil}, + } + + for _, tc := range tests { + t.Run(fmt.Sprintf("%q", tc.input), func(t *testing.T) { + result := splitCSV(tc.input) + assert.Equal(t, tc.expected, result) + }) + } +} + +// --- Clock skew tolerance --- + +func TestOIDCAuth_ClockSkewTolerance(t *testing.T) { + auth, issuer, priv := buildOIDCAuthenticator(t, nil) + + // Token issued 30 seconds in the future should still be accepted + // (within the 60-second skew tolerance). + claims := defaultClaims(issuer) + claims.Iat = time.Now().Unix() + 30 + + token := signJWT(t, defaultHeader(), claims, priv) + req := httptest.NewRequest("GET", "/connect", nil) + req.Header.Set("Authorization", "Bearer "+token) + + err := auth.Authenticate(req) + assert.NoError(t, err) +} + +// --- Verify RS256 directly --- + +func TestVerifyRS256Signature(t *testing.T) { + priv := testKeyPair(t) + msg := []byte("test message for signing") + + h := sha256.Sum256(msg) + sig, err := rsa.SignPKCS1v15(rand.Reader, priv, crypto.SHA256, h[:]) + require.NoError(t, err) + + // Correct signature. + err = verifyRS256Signature(&priv.PublicKey, msg, sig) + assert.NoError(t, err) + + // Tampered message. + err = verifyRS256Signature(&priv.PublicKey, []byte("different message"), sig) + assert.Error(t, err) + + // Tampered signature. + badSig := make([]byte, len(sig)) + copy(badSig, sig) + badSig[0] ^= 0xFF + err = verifyRS256Signature(&priv.PublicKey, msg, badSig) + assert.Error(t, err) +} + +// --- Modal claims parsing --- + +func TestModalClaims_AllFieldsParsed(t *testing.T) { + raw := `{ + "sub": "ws-123:main:app:fn", + "aud": "my-aud", + "exp": 1700000000, + "iat": 1699999000, + "iss": "https://oidc.modal.com", + "jti": "some-jti", + "workspace_id": "ws-123", + "environment_id": "env-456", + "environment_name": "main", + "app_id": "app-789", + "app_name": "my-app", + "function_id": "fn-012", + "function_name": "my-func", + "container_id": "ctr-345" + }` + + var claims ModalClaims + err := json.Unmarshal([]byte(raw), &claims) + require.NoError(t, err) + + assert.Equal(t, "ws-123:main:app:fn", claims.Sub) + assert.Equal(t, "my-aud", claims.Aud) + assert.Equal(t, int64(1700000000), claims.Exp) + assert.Equal(t, int64(1699999000), claims.Iat) + assert.Equal(t, "https://oidc.modal.com", claims.Iss) + assert.Equal(t, "some-jti", claims.Jti) + assert.Equal(t, "ws-123", claims.WorkspaceID) + assert.Equal(t, "env-456", claims.EnvironmentID) + assert.Equal(t, "main", claims.EnvironmentName) + assert.Equal(t, "app-789", claims.AppID) + assert.Equal(t, "my-app", claims.AppName) + assert.Equal(t, "fn-012", claims.FunctionID) + assert.Equal(t, "my-func", claims.FunctionName) + assert.Equal(t, "ctr-345", claims.ContainerID) +} + +// --- Edge case: token with extra whitespace in Bearer prefix --- + +func TestAuth_BearerPrefixVariants(t *testing.T) { + auth := NewPasswordAuthenticator("pw") + + // "Bearer pw" (double space) should fail - we require exact "Bearer " prefix. + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "Bearer pw") + err := auth.Authenticate(req) + assert.Error(t, err, "double space after Bearer should be treated as part of the token") + + // Lowercase "bearer" should fail. + req = httptest.NewRequest("GET", "/", nil) + req.Header.Set("Authorization", "bearer pw") + err = auth.Authenticate(req) + assert.Error(t, err, "lowercase bearer should be rejected") +} + +// Verify that the JWT we build in tests really has 3 dot-separated parts. +func TestSignJWT_Format(t *testing.T) { + priv := testKeyPair(t) + _, issuer := serveJWKS(t, map[string]*rsa.PublicKey{"k": &priv.PublicKey}) + + token := signJWT(t, defaultHeader(), defaultClaims(issuer), priv) + parts := strings.Split(token, ".") + assert.Len(t, parts, 3, "JWT should have exactly 3 parts") + + for i, part := range parts { + assert.NotEmpty(t, part, "JWT part %d should not be empty", i) + } +} diff --git a/lib/client.go b/lib/client.go index 65f58e4..1b9a11b 100644 --- a/lib/client.go +++ b/lib/client.go @@ -51,8 +51,10 @@ type Client struct { // ServerIp is the public IPv4 address of the server. ServerIp netip.Addr - // Password authenticates the client connection. - Password string + // Token is the bearer token used to authenticate with the server. + // In password mode, this is the VPROX_PASSWORD value. + // In oidc-modal mode, this is the VPROX_OIDC_TOKEN value. + Token string // WgClient is a shared client for interacting with the WireGuard kernel module. WgClient *wgctrl.Client @@ -152,7 +154,7 @@ func (c *Client) sendConnectionRequest() (connectResponse, error) { Method: http.MethodPost, URL: connectUrl, Header: http.Header{ - "Authorization": []string{"Bearer " + c.Password}, + "Authorization": []string{"Bearer " + c.Token}, }, Body: io.NopCloser(bytes.NewBuffer(buf)), } @@ -231,7 +233,7 @@ func (c *Client) Disconnect() error { Method: http.MethodPost, URL: disconnectUrl, Header: http.Header{ - "Authorization": []string{"Bearer " + c.Password}, + "Authorization": []string{"Bearer " + c.Token}, }, Body: io.NopCloser(bytes.NewBuffer(buf)), } diff --git a/lib/env.go b/lib/env.go index 739ab39..a4c3338 100644 --- a/lib/env.go +++ b/lib/env.go @@ -2,7 +2,10 @@ package lib import ( "errors" + "fmt" + "log" "os" + "strings" ) func GetVproxPassword() (string, error) { @@ -12,3 +15,119 @@ func GetVproxPassword() (string, error) { } return password, nil } + +// GetAuthMode returns the configured auth mode from the VPROX_AUTH_MODE +// environment variable. Defaults to "password" if not set. +func GetAuthMode() AuthMode { + mode := os.Getenv("VPROX_AUTH_MODE") + switch strings.ToLower(mode) { + case "oidc-modal": + return AuthModeOIDCModal + case "password", "": + return AuthModePassword + default: + // Fall back to password mode for unknown values. + return AuthModePassword + } +} + +// GetOIDCIssuerURL returns the OIDC issuer URL from the VPROX_OIDC_ISSUER +// environment variable. Defaults to "https://oidc.modal.com" if not set. +func GetOIDCIssuerURL() string { + issuer := os.Getenv("VPROX_OIDC_ISSUER") + if issuer == "" { + return "https://oidc.modal.com" + } + return issuer +} + +// GetOIDCAudience returns the expected OIDC audience from the +// VPROX_OIDC_AUDIENCE environment variable. If empty, audience is not checked. +func GetOIDCAudience() string { + return os.Getenv("VPROX_OIDC_AUDIENCE") +} + +// GetOIDCAllowedWorkspaceIDs returns the list of allowed Modal workspace IDs +// from the VPROX_OIDC_ALLOWED_WORKSPACE_IDS environment variable (comma-separated). +// If empty, any workspace is allowed. If set to "*", all workspaces are explicitly +// allowed (returns nil) with a warning logged at startup. +func GetOIDCAllowedWorkspaceIDs() []string { + raw := os.Getenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS") + if strings.TrimSpace(raw) == "*" { + log.Println("WARNING: VPROX_OIDC_ALLOWED_WORKSPACE_IDS is set to '*', allowing ALL workspaces. This should only be used for testing!") + return nil + } + return splitCSV(raw) +} + +// GetOIDCToken returns the OIDC identity token from the VPROX_OIDC_TOKEN +// environment variable. +func GetOIDCToken() (string, error) { + token := os.Getenv("VPROX_OIDC_TOKEN") + if token == "" { + return "", errors.New("VPROX_OIDC_TOKEN environment variable is not set") + } + return token, nil +} + +// GetAuthenticator creates the appropriate Authenticator based on environment +// configuration. This is used by the server. +func GetAuthenticator() (*Authenticator, error) { + mode := GetAuthMode() + + switch mode { + case AuthModeOIDCModal: + config := &OIDCConfig{ + IssuerURL: GetOIDCIssuerURL(), + Audience: GetOIDCAudience(), + AllowedWorkspaceIDs: GetOIDCAllowedWorkspaceIDs(), + } + auth, err := NewOIDCModalAuthenticator(config) + if err != nil { + return nil, fmt.Errorf("failed to initialize OIDC authenticator: %v", err) + } + return auth, nil + + case AuthModePassword: + password, err := GetVproxPassword() + if err != nil { + return nil, err + } + return NewPasswordAuthenticator(password), nil + + default: + return nil, fmt.Errorf("unknown auth mode: %s", mode) + } +} + +// GetClientToken returns the bearer token the client should send to the server, +// based on the current auth mode. +func GetClientToken() (string, error) { + mode := GetAuthMode() + + switch mode { + case AuthModeOIDCModal: + return GetOIDCToken() + case AuthModePassword: + return GetVproxPassword() + default: + return "", fmt.Errorf("unknown auth mode: %s", mode) + } +} + +// splitCSV splits a comma-separated string into a slice, trimming whitespace +// and filtering out empty strings. +func splitCSV(s string) []string { + if s == "" { + return nil + } + parts := strings.Split(s, ",") + var result []string + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + result = append(result, p) + } + } + return result +} diff --git a/lib/env_test.go b/lib/env_test.go new file mode 100644 index 0000000..515cd50 --- /dev/null +++ b/lib/env_test.go @@ -0,0 +1,120 @@ +package lib + +import ( + "os" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGetOIDCAllowedWorkspaceIDs_Wildcard(t *testing.T) { + t.Setenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS", "*") + result := GetOIDCAllowedWorkspaceIDs() + assert.Nil(t, result, "wildcard '*' should return nil (allow all)") +} + +func TestGetOIDCAllowedWorkspaceIDs_WildcardWithWhitespace(t *testing.T) { + t.Setenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS", " * ") + result := GetOIDCAllowedWorkspaceIDs() + assert.Nil(t, result, "wildcard ' * ' with whitespace should return nil (allow all)") +} + +func TestGetOIDCAllowedWorkspaceIDs_CommaSeparated(t *testing.T) { + t.Setenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS", "ws-abc,ws-def") + result := GetOIDCAllowedWorkspaceIDs() + assert.Equal(t, []string{"ws-abc", "ws-def"}, result) +} + +func TestGetOIDCAllowedWorkspaceIDs_Empty(t *testing.T) { + t.Setenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS", "") + result := GetOIDCAllowedWorkspaceIDs() + assert.Nil(t, result, "empty string should return nil") +} + +func TestGetOIDCAllowedWorkspaceIDs_Unset(t *testing.T) { + os.Unsetenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS") + result := GetOIDCAllowedWorkspaceIDs() + assert.Nil(t, result, "unset env should return nil") +} + +func TestGetOIDCAllowedWorkspaceIDs_WildcardIsNotSpecialInList(t *testing.T) { + // A "*" mixed with other values is NOT treated as a wildcard; + // it is treated as a literal workspace ID entry. + t.Setenv("VPROX_OIDC_ALLOWED_WORKSPACE_IDS", "ws-abc,*,ws-def") + result := GetOIDCAllowedWorkspaceIDs() + assert.Equal(t, []string{"ws-abc", "*", "ws-def"}, result) +} + +func TestGetAuthMode_Defaults(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "") + assert.Equal(t, AuthModePassword, GetAuthMode()) +} + +func TestGetAuthMode_OIDCModal(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "oidc-modal") + assert.Equal(t, AuthModeOIDCModal, GetAuthMode()) +} + +func TestGetAuthMode_OIDCModalCaseInsensitive(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "OIDC-MODAL") + assert.Equal(t, AuthModeOIDCModal, GetAuthMode()) +} + +func TestGetAuthMode_UnknownFallsBackToPassword(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "bogus") + assert.Equal(t, AuthModePassword, GetAuthMode()) +} + +func TestGetAuthMode_PlainOIDCFallsBackToPassword(t *testing.T) { + // "oidc" alone is not a valid mode; must be "oidc-modal". + t.Setenv("VPROX_AUTH_MODE", "oidc") + assert.Equal(t, AuthModePassword, GetAuthMode()) +} + +func TestGetOIDCIssuerURL_Default(t *testing.T) { + t.Setenv("VPROX_OIDC_ISSUER", "") + assert.Equal(t, "https://oidc.modal.com", GetOIDCIssuerURL()) +} + +func TestGetOIDCIssuerURL_Custom(t *testing.T) { + t.Setenv("VPROX_OIDC_ISSUER", "https://custom.issuer.example.com") + assert.Equal(t, "https://custom.issuer.example.com", GetOIDCIssuerURL()) +} + +func TestGetOIDCToken_Set(t *testing.T) { + t.Setenv("VPROX_OIDC_TOKEN", "eyJhbGciOiJSUzI1NiJ9.test.sig") + token, err := GetOIDCToken() + assert.NoError(t, err) + assert.Equal(t, "eyJhbGciOiJSUzI1NiJ9.test.sig", token) +} + +func TestGetOIDCToken_Unset(t *testing.T) { + os.Unsetenv("VPROX_OIDC_TOKEN") + _, err := GetOIDCToken() + assert.Error(t, err) + assert.Contains(t, err.Error(), "VPROX_OIDC_TOKEN") +} + +func TestGetClientToken_Password(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "password") + t.Setenv("VPROX_PASSWORD", "s3cret") + token, err := GetClientToken() + assert.NoError(t, err) + assert.Equal(t, "s3cret", token) +} + +func TestGetClientToken_OIDCModal(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "oidc-modal") + t.Setenv("VPROX_OIDC_TOKEN", "eyJhbGciOiJSUzI1NiJ9.test.sig") + token, err := GetClientToken() + assert.NoError(t, err) + assert.Equal(t, "eyJhbGciOiJSUzI1NiJ9.test.sig", token) +} + +func TestGetClientToken_OIDCModal_MissingToken(t *testing.T) { + t.Setenv("VPROX_AUTH_MODE", "oidc-modal") + os.Unsetenv("VPROX_OIDC_TOKEN") + _, err := GetClientToken() + assert.Error(t, err) + assert.Contains(t, err.Error(), "VPROX_OIDC_TOKEN") +} diff --git a/lib/server.go b/lib/server.go index 0d1ff5a..d8fef7c 100644 --- a/lib/server.go +++ b/lib/server.go @@ -62,8 +62,8 @@ type Server struct { // Currently only setting this to the default interface is supported. BindIface netlink.Link - // Password is needed to authenticate connection requests. - Password string + // Auth is the authenticator used to verify incoming requests. + Auth *Authenticator // Index is a unique server index for firewall marks and other uses. It starts at 0. Index uint16 @@ -190,8 +190,7 @@ func (srv *Server) connectHandler(w http.ResponseWriter, r *http.Request) { return } - auth := r.Header.Get("Authorization") - if auth != "Bearer "+srv.Password { + if err := srv.Auth.Authenticate(r); err != nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return } @@ -294,8 +293,7 @@ func (srv *Server) disconnectHandler(w http.ResponseWriter, r *http.Request) { return } - auth := r.Header.Get("Authorization") - if auth != "Bearer "+srv.Password { + if err := srv.Auth.Authenticate(r); err != nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return } @@ -357,8 +355,7 @@ func (srv *Server) relinquishHandler(w http.ResponseWriter, r *http.Request) { return } - auth := r.Header.Get("Authorization") - if auth != "Bearer "+srv.Password { + if err := srv.Auth.Authenticate(r); err != nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return } @@ -395,8 +392,7 @@ func (srv *Server) versionHandler(w http.ResponseWriter, r *http.Request) { return } - auth := r.Header.Get("Authorization") - if auth != "Bearer "+srv.Password { + if err := srv.Auth.Authenticate(r); err != nil { http.Error(w, "unauthorized", http.StatusUnauthorized) return } diff --git a/lib/server_manager.go b/lib/server_manager.go index 632541f..6f07927 100644 --- a/lib/server_manager.go +++ b/lib/server_manager.go @@ -24,7 +24,7 @@ type ServerManager struct { wgClient *wgctrl.Client ipt *iptables.IPTables key wgtypes.Key - password string + auth *Authenticator ctx context.Context waitGroup *sync.WaitGroup wgBlock netip.Prefix @@ -41,7 +41,7 @@ type ServerManager struct { } // NewServerManager creates a new server manager -func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, password string, takeover bool) (*ServerManager, error) { +func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Context, key wgtypes.Key, auth *Authenticator, takeover bool) (*ServerManager, error) { // Make a shared WireGuard client. wgClient, err := wgctrl.New() if err != nil { @@ -62,7 +62,7 @@ func NewServerManager(wgBlock netip.Prefix, wgBlockPerIp uint, ctx context.Conte sm.wgClient = wgClient sm.ipt = ipt sm.key = key - sm.password = password + sm.auth = auth sm.ctx = ctx sm.waitGroup = new(sync.WaitGroup) sm.wgBlock = wgBlock.Masked() @@ -111,7 +111,7 @@ func (sm *ServerManager) Start(ip netip.Addr) error { srv := &Server{ Key: sm.key, BindAddr: ip, - Password: sm.password, + Auth: sm.auth, Index: i, Ipt: sm.ipt, WgClient: sm.wgClient,