diff --git a/internal/oauthex/auth_meta.go b/internal/oauthex/auth_meta.go index 1f075f8a..5bbbb412 100644 --- a/internal/oauthex/auth_meta.go +++ b/internal/oauthex/auth_meta.go @@ -8,10 +8,14 @@ package oauthex import ( + "bytes" "context" + "encoding/json" "errors" "fmt" + "io" "net/http" + "time" ) // AuthServerMeta represents the metadata for an OAuth 2.0 authorization server, @@ -109,6 +113,153 @@ type AuthServerMeta struct { CodeChallengeMethodsSupported []string `json:"code_challenge_methods_supported,omitempty"` } +// ClientRegistrationMetadata represents the client metadata fields for the DCR POST request (RFC 7591). +type ClientRegistrationMetadata struct { + // RedirectURIs is a REQUIRED JSON array of redirection URI strings for use in + // redirect-based flows (such as the authorization code grant). + RedirectURIs []string `json:"redirect_uris"` + + // TokenEndpointAuthMethod is an OPTIONAL string indicator of the requested + // authentication method for the token endpoint. + // If omitted, the default is "client_secret_basic". + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method,omitempty"` + + // GrantTypes is an OPTIONAL JSON array of OAuth 2.0 grant type strings + // that the client will restrict itself to using. + // If omitted, the default is ["authorization_code"]. + GrantTypes []string `json:"grant_types,omitempty"` + + // ResponseTypes is an OPTIONAL JSON array of OAuth 2.0 response type strings + // that the client will restrict itself to using. + // If omitted, the default is ["code"]. + ResponseTypes []string `json:"response_types,omitempty"` + + // ClientName is a RECOMMENDED human-readable name of the client to be presented + // to the end-user. + ClientName string `json:"client_name,omitempty"` + + // ClientURI is a RECOMMENDED URL of a web page providing information about the client. + ClientURI string `json:"client_uri,omitempty"` + + // LogoURI is an OPTIONAL URL of a logo for the client, which may be displayed + // to the end-user. + LogoURI string `json:"logo_uri,omitempty"` + + // Scope is an OPTIONAL string containing a space-separated list of scope values + // that the client will restrict itself to using. + Scope string `json:"scope,omitempty"` + + // Contacts is an OPTIONAL JSON array of strings representing ways to contact + // people responsible for this client (e.g., email addresses). + Contacts []string `json:"contacts,omitempty"` + + // TOSURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's terms of service. + TOSURI string `json:"tos_uri,omitempty"` + + // PolicyURI is an OPTIONAL URL that the client provides to the end-user + // to read about the client's privacy policy. + PolicyURI string `json:"policy_uri,omitempty"` + + // JWKSURI is an OPTIONAL URL for the client's JSON Web Key Set [JWK] document. + // This is preferred over the 'jwks' parameter. + JWKSURI string `json:"jwks_uri,omitempty"` + + // JWKS is an OPTIONAL client's JSON Web Key Set [JWK] document, passed by value. + // This is an alternative to providing a JWKSURI. + JWKS string `json:"jwks,omitempty"` + + // SoftwareID is an OPTIONAL unique identifier string for the client software, + // constant across all instances and versions. + SoftwareID string `json:"software_id,omitempty"` + + // SoftwareVersion is an OPTIONAL version identifier string for the client software. + SoftwareVersion string `json:"software_version,omitempty"` + + // SoftwareStatement is an OPTIONAL JWT that asserts client metadata values. + // Values in the software statement take precedence over other metadata values. + SoftwareStatement string `json:"software_statement,omitempty"` +} + +// ClientRegistrationResponse represents the fields returned by the Authorization Server +// (RFC 7591, Section 3.2.1 and 3.2.2). +type ClientRegistrationResponse struct { + // ClientRegistrationMetadata contains all registered client metadata, returned by the + // server on success, potentially with modified or defaulted values. + ClientRegistrationMetadata + + // ClientID is the REQUIRED newly issued OAuth 2.0 client identifier. + ClientID string `json:"client_id"` + + // ClientSecret is an OPTIONAL client secret string. + ClientSecret string `json:"client_secret,omitempty"` + + // ClientIDIssuedAt is an OPTIONAL Unix timestamp when the ClientID was issued. + ClientIDIssuedAt time.Time `json:"client_id_issued_at,omitempty"` + + // ClientSecretExpiresAt is the REQUIRED (if client_secret is issued) Unix + // timestamp when the secret expires, or 0 if it never expires. + ClientSecretExpiresAt time.Time `json:"client_secret_expires_at,omitempty"` +} + +func (r *ClientRegistrationResponse) MarshalJSON() ([]byte, error) { + type alias ClientRegistrationResponse + var clientIDIssuedAt int64 + var clientSecretExpiresAt int64 + + if !r.ClientIDIssuedAt.IsZero() { + clientIDIssuedAt = r.ClientIDIssuedAt.Unix() + } + if !r.ClientSecretExpiresAt.IsZero() { + clientSecretExpiresAt = r.ClientSecretExpiresAt.Unix() + } + + return json.Marshal(&struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + ClientIDIssuedAt: clientIDIssuedAt, + ClientSecretExpiresAt: clientSecretExpiresAt, + alias: (*alias)(r), + }) +} + +func (r *ClientRegistrationResponse) UnmarshalJSON(data []byte) error { + type alias ClientRegistrationResponse + aux := &struct { + ClientIDIssuedAt int64 `json:"client_id_issued_at,omitempty"` + ClientSecretExpiresAt int64 `json:"client_secret_expires_at,omitempty"` + *alias + }{ + alias: (*alias)(r), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + if aux.ClientIDIssuedAt != 0 { + r.ClientIDIssuedAt = time.Unix(aux.ClientIDIssuedAt, 0) + } + if aux.ClientSecretExpiresAt != 0 { + r.ClientSecretExpiresAt = time.Unix(aux.ClientSecretExpiresAt, 0) + } + return nil +} + +// ClientRegistrationError is the error response from the Authorization Server +// for a failed registration attempt (RFC 7591, Section 3.2.2). +type ClientRegistrationError struct { + // ErrorCode is the REQUIRED error code if registration failed (RFC 7591, 3.2.2). + ErrorCode string `json:"error"` + + // ErrorDescription is an OPTIONAL human-readable error message. + ErrorDescription string `json:"error_description,omitempty"` +} + +func (e *ClientRegistrationError) Error() string { + return fmt.Sprintf("registration failed: %s (%s)", e.ErrorCode, e.ErrorDescription) +} + var wellKnownPaths = []string{ "/.well-known/oauth-authorization-server", "/.well-known/openid-configuration", @@ -143,3 +294,59 @@ func GetAuthServerMeta(ctx context.Context, issuerURL string, c *http.Client) (* } return nil, fmt.Errorf("failed to get auth server metadata from %q: %w", issuerURL, errors.Join(errs...)) } + +// RegisterClient performs Dynamic Client Registration according to RFC 7591. +func RegisterClient(ctx context.Context, registrationEndpoint string, clientMeta *ClientRegistrationMetadata, c *http.Client) (*ClientRegistrationResponse, error) { + if registrationEndpoint == "" { + return nil, fmt.Errorf("registration_endpoint is required") + } + + if c == nil { + c = http.DefaultClient + } + + payload, err := json.Marshal(clientMeta) + if err != nil { + return nil, fmt.Errorf("failed to marshal client metadata: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, "POST", registrationEndpoint, bytes.NewBuffer(payload)) + if err != nil { + return nil, fmt.Errorf("failed to create registration request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := c.Do(req) + if err != nil { + return nil, fmt.Errorf("registration request failed: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read registration response body: %w", err) + } + + if resp.StatusCode == http.StatusCreated { + var regResponse ClientRegistrationResponse + if err := json.Unmarshal(body, ®Response); err != nil { + return nil, fmt.Errorf("failed to decode successful registration response: %w (%s)", err, string(body)) + } + if regResponse.ClientID == "" { + return nil, fmt.Errorf("registration response is missing required 'client_id' field") + } + return ®Response, nil + } + + if resp.StatusCode == http.StatusBadRequest { + var regError ClientRegistrationError + if err := json.Unmarshal(body, ®Error); err != nil { + return nil, fmt.Errorf("failed to decode registration error response: %w (%s)", err, string(body)) + } + return nil, ®Error + } + + return nil, fmt.Errorf("registration failed with status %s: %s", resp.Status, string(body)) +} diff --git a/internal/oauthex/auth_meta_test.go b/internal/oauthex/auth_meta_test.go index b83402f2..6ff9f3dd 100644 --- a/internal/oauthex/auth_meta_test.go +++ b/internal/oauthex/auth_meta_test.go @@ -5,10 +5,18 @@ package oauthex import ( + "context" "encoding/json" + "io" + "net/http" + "net/http/httptest" "os" "path/filepath" + "strings" "testing" + "time" + + "github.com/google/go-cmp/cmp" ) func TestAuthMetaParse(t *testing.T) { @@ -26,3 +34,198 @@ func TestAuthMetaParse(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +func TestClientRegistrationMetadataParse(t *testing.T) { + // Verify that we can parse a typical client metadata JSON. + data, err := os.ReadFile(filepath.FromSlash("testdata/client-auth-meta.json")) + if err != nil { + t.Fatal(err) + } + var a ClientRegistrationMetadata + if err := json.Unmarshal(data, &a); err != nil { + t.Fatal(err) + } + // Spot check + if g, w := a.ClientName, "My Test App"; g != w { + t.Errorf("got ClientName %q, want %q", g, w) + } + if g, w := len(a.RedirectURIs), 2; g != w { + t.Errorf("got %d RedirectURIs, want %d", g, w) + } +} + +func TestRegisterClient(t *testing.T) { + testCases := []struct { + name string + handler http.HandlerFunc + clientMeta *ClientRegistrationMetadata + wantClientID string + wantErr string + }{ + { + name: "Success", + handler: func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Errorf("Expected POST, got %s", r.Method) + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatal(err) + } + var receivedMeta ClientRegistrationMetadata + if err := json.Unmarshal(body, &receivedMeta); err != nil { + t.Fatalf("Failed to unmarshal request body: %v", err) + } + if receivedMeta.ClientName != "Test App" { + t.Errorf("Expected ClientName 'Test App', got '%s'", receivedMeta.ClientName) + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"client_id":"test-client-id","client_secret":"test-client-secret","client_name":"Test App"}`)) + }, + clientMeta: &ClientRegistrationMetadata{ClientName: "Test App", RedirectURIs: []string{"http://localhost/cb"}}, + wantClientID: "test-client-id", + }, + { + name: "Missing ClientID in Response", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"client_secret":"test-client-secret"}`)) // No client_id + }, + clientMeta: &ClientRegistrationMetadata{RedirectURIs: []string{"http://localhost/cb"}}, + wantErr: "registration response is missing required 'client_id' field", + }, + { + name: "Standard OAuth Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + w.Write([]byte(`{"error":"invalid_redirect_uri","error_description":"Redirect URI is not valid."}`)) + }, + clientMeta: &ClientRegistrationMetadata{RedirectURIs: []string{"http://invalid/cb"}}, + wantErr: "registration failed: invalid_redirect_uri (Redirect URI is not valid.)", + }, + { + name: "Non-JSON Server Error", + handler: func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + w.Write([]byte("Internal Server Error")) + }, + clientMeta: &ClientRegistrationMetadata{RedirectURIs: []string{"http://localhost/cb"}}, + wantErr: "registration failed with status 500 Internal Server Error", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := httptest.NewServer(tc.handler) + defer server.Close() + + info, err := RegisterClient(context.Background(), server.URL, tc.clientMeta, server.Client()) + + if tc.wantErr != "" { + if err == nil { + t.Fatalf("Expected an error containing '%s', but got nil", tc.wantErr) + } + if !strings.Contains(err.Error(), tc.wantErr) { + t.Errorf("Expected error to contain '%s', got '%v'", tc.wantErr, err) + } + return + } + + if err != nil { + t.Fatalf("Expected no error, but got: %v", err) + } + if info.ClientID != tc.wantClientID { + t.Errorf("Expected client_id '%s', got '%s'", tc.wantClientID, info.ClientID) + } + }) + } + + t.Run("No Endpoint", func(t *testing.T) { + _, err := RegisterClient(context.Background(), "", &ClientRegistrationMetadata{}, nil) + if err == nil { + t.Fatal("Expected an error for missing registration endpoint, got nil") + } + expectedErr := "registration_endpoint is required" + if err.Error() != expectedErr { + t.Errorf("Expected error '%s', got '%v'", expectedErr, err) + } + }) +} + +func TestClientRegistrationResponseJSON(t *testing.T) { + testCases := []struct { + name string + in ClientRegistrationResponse + wantJSON string + }{ + { + name: "full response", + in: ClientRegistrationResponse{ + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + ClientIDIssuedAt: time.Unix(1758840047, 0), + ClientSecretExpiresAt: time.Unix(1790376047, 0), + }, + wantJSON: `{"client_id":"test-client-id","client_secret":"test-client-secret","client_id_issued_at":1758840047,"client_secret_expires_at":1790376047, "redirect_uris": null}`, + }, + { + name: "minimal response with only required fields", + in: ClientRegistrationResponse{ + ClientID: "test-client-id-minimal", + }, + wantJSON: `{"client_id":"test-client-id-minimal", "redirect_uris":null}`, + }, + { + name: "response with a secret that does not expire", + in: ClientRegistrationResponse{ + ClientID: "test-client-id-no-expiry", + ClientSecret: "test-secret-no-expiry", + }, + wantJSON: `{"client_id":"test-client-id-no-expiry","client_secret":"test-secret-no-expiry", "redirect_uris":null}`, + }, + { + name: "unmarshal with zero timestamp", + in: ClientRegistrationResponse{ClientID: "client-id-zero"}, + wantJSON: `{"client_id":"client-id-zero", "redirect_uris":null}`, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Test MarshalJSON + t.Run("marshal", func(t *testing.T) { + b, err := json.Marshal(&tc.in) + if err != nil { + t.Fatalf("Marshal() error = %v", err) + } + + var gotMap, wantMap map[string]any + if err := json.Unmarshal(b, &gotMap); err != nil { + t.Fatalf("failed to unmarshal actual result: %v", err) + } + if err := json.Unmarshal([]byte(tc.wantJSON), &wantMap); err != nil { + t.Fatalf("failed to unmarshal expected result: %v", err) + } + + if diff := cmp.Diff(wantMap, gotMap); diff != "" { + t.Errorf("Marshal() mismatch (-want +got):\n%s", diff) + } + }) + + // Test UnmarshalJSON + t.Run("unmarshal", func(t *testing.T) { + var got ClientRegistrationResponse + if err := json.Unmarshal([]byte(tc.wantJSON), &got); err != nil { + t.Fatalf("Unmarshal() error = %v", err) + } + + if diff := cmp.Diff(tc.in, got); diff != "" { + t.Errorf("Unmarshal() mismatch (-want +got):\n%s", diff) + } + }) + }) + } +} diff --git a/internal/oauthex/testdata/client-auth-meta.json b/internal/oauthex/testdata/client-auth-meta.json new file mode 100644 index 00000000..c07f5be1 --- /dev/null +++ b/internal/oauthex/testdata/client-auth-meta.json @@ -0,0 +1,8 @@ +{ + "client_name": "My Test App", + "redirect_uris": [ + "https://client.example.org/callback", + "https://client.example.org/callback2" + ], + "scope": "read write" +}