Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ _book/
dist/
coverage.*
.bin
.claude/
10 changes: 9 additions & 1 deletion json.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ type ErrorContainer struct {
Error *DefaultError `json:"error"`
}

func (e *ErrorContainer) ID() string {
return e.Error.ID()
}

type ErrorReporter interface {
ReportError(r *http.Request, code int, err error, args ...interface{})
}
Expand Down Expand Up @@ -159,13 +163,15 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
}

w.Header().Set("Content-Type", "application/json")
w.WriteHeader(code)

// Enhancing must happen after logging or context will be lost.
var payload interface{} = err
if h.ErrorEnhancer != nil {
payload = h.ErrorEnhancer(r, err)
}
if id, ok := payload.(interface{ ID() string }); ok {
w.Header().Set("Ory-Error-Id", id.ID())
}
if de, ok := payload.(*DefaultError); ok && !h.EnableDebug {
de2 := *de
de2.DebugField = ""
Expand All @@ -179,6 +185,8 @@ func (h *JSONWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
payload = ec2
}

w.WriteHeader(code)

if err := json.NewEncoder(w).Encode(payload); err != nil {
// There was an error, but there's actually not a lot we can do except log that this happened.
h.Reporter.ReportError(r, code, errors.WithStack(err), "Could not write ErrorContainer to response writer")
Expand Down
48 changes: 48 additions & 0 deletions json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,3 +422,51 @@ func TestCanceledJSON(t *testing.T) {
assert.Contains(t, string(body), "some unrelated error")
assert.Equal(t, 499, resp.StatusCode)
}

func TestOryErrorIDHeader(t *testing.T) {
for k, tc := range []struct {
name string
err error
expectedHeader string
}{
{
name: "error with ID sets header",
err: &ErrMisconfiguration,
expectedHeader: "invalid_configuration",
},
{
name: "error without ID does not set header",
err: &ErrNotFound,
expectedHeader: "",
},
{
name: "custom error with ID sets header",
err: &DefaultError{
IDField: "custom_error_id",
CodeField: http.StatusBadRequest,
StatusField: http.StatusText(http.StatusBadRequest),
ErrorField: "custom error",
},
expectedHeader: "custom_error_id",
},
{
name: "upstream error sets header",
err: &ErrUpstreamError,
expectedHeader: "upstream_error",
},
} {
t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) {
h := NewJSONWriter(nil)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, tc.err)
}))
t.Cleanup(ts.Close)

resp, err := http.Get(ts.URL + "/do")
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id"))
})
}
}
3 changes: 3 additions & 0 deletions plain.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ func (h *TextWriter) WriteErrorCode(w http.ResponseWriter, r *http.Request, code
// All errors land here, so it's a really good idea to do the logging here as well!
h.Reporter.ReportError(r, code, err, "An error occurred while handling a request")

if id, ok := err.(interface{ ID() string }); ok {
w.Header().Set("Ory-Error-Id", id.ID())
}
w.Header().Set("Content-Type", h.contentType)
w.WriteHeader(code)
fmt.Fprintf(w, "%s", err)
Expand Down
62 changes: 62 additions & 0 deletions plain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Copyright © 2023 Ory Corp
// SPDX-License-Identifier: Apache-2.0

package herodot

import (
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestTextWriterOryErrorIDHeader(t *testing.T) {
for k, tc := range []struct {
name string
err error
expectedHeader string
}{
{
name: "error with ID sets header",
err: &ErrMisconfiguration,
expectedHeader: "invalid_configuration",
},
{
name: "error without ID does not set header",
err: &ErrNotFound,
expectedHeader: "",
},
{
name: "custom error with ID sets header",
err: &DefaultError{
IDField: "custom_text_error_id",
CodeField: http.StatusBadRequest,
StatusField: http.StatusText(http.StatusBadRequest),
ErrorField: "custom error",
},
expectedHeader: "custom_text_error_id",
},
{
name: "upstream error sets header",
err: &ErrUpstreamError,
expectedHeader: "upstream_error",
},
} {
t.Run(fmt.Sprintf("case=%d/%s", k, tc.name), func(t *testing.T) {
h := NewTextWriter(&stdReporter{}, "plain")
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.WriteError(w, r, tc.err)
}))
t.Cleanup(ts.Close)

resp, err := http.Get(ts.URL + "/do")
require.NoError(t, err)
defer resp.Body.Close()

assert.Equal(t, tc.expectedHeader, resp.Header.Get("Ory-Error-Id"))
})
}
}
Loading