|
| 1 | +package apierrors |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "errors" |
| 6 | + "net/http" |
| 7 | + "strings" |
| 8 | + |
| 9 | + api "github.com/openmeterio/openmeter/api/v3" |
| 10 | + "github.com/openmeterio/openmeter/pkg/errorsx" |
| 11 | + "github.com/openmeterio/openmeter/pkg/models" |
| 12 | +) |
| 13 | + |
| 14 | +const httpStatusCodeErrorAttribute = "openmeter.http.status_code" |
| 15 | + |
| 16 | +type v3ErrorMapping struct { |
| 17 | + match func(err error) bool |
| 18 | + build func(ctx context.Context, err error) *BaseAPIError |
| 19 | +} |
| 20 | + |
| 21 | +// NewV3ErrorHandlerFunc returns an oapi-codegen ChiServerOptions.ErrorHandlerFunc implementation. |
| 22 | +// |
| 23 | +// It is invoked when the generated router fails request binding (query/path/header parsing). |
| 24 | +// The main purpose is to ensure we always write a response (otherwise net/http defaults to 200 with |
| 25 | +// an empty body), and to keep error-to-status mapping consistent with our model error types. |
| 26 | +func NewV3ErrorHandlerFunc(logger errorsx.Handler) func(w http.ResponseWriter, r *http.Request, err error) { |
| 27 | + // Mirrors pkg/framework/commonhttp/encoder.go GenericErrorEncoder ordering (after the status-code attribute mapping). |
| 28 | + mappings := []v3ErrorMapping{ |
| 29 | + { |
| 30 | + match: models.IsGenericConflictError, |
| 31 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewConflictError(ctx, err, err.Error()) }, |
| 32 | + }, |
| 33 | + { |
| 34 | + match: models.IsGenericForbiddenError, |
| 35 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewForbiddenError(ctx, err) }, |
| 36 | + }, |
| 37 | + { |
| 38 | + match: models.IsGenericNotImplementedError, |
| 39 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewNotImplementedError(ctx, err) }, |
| 40 | + }, |
| 41 | + { |
| 42 | + match: models.IsGenericValidationError, |
| 43 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewBadRequestError(ctx, err, nil) }, |
| 44 | + }, |
| 45 | + { |
| 46 | + match: models.IsGenericNotFoundError, |
| 47 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewNotFoundError(ctx, err, "") }, |
| 48 | + }, |
| 49 | + { |
| 50 | + match: models.IsGenericUnauthorizedError, |
| 51 | + build: func(ctx context.Context, err error) *BaseAPIError { return NewUnauthenticatedError(ctx, err) }, |
| 52 | + }, |
| 53 | + { |
| 54 | + match: models.IsGenericPreConditionFailedError, |
| 55 | + build: func(ctx context.Context, err error) *BaseAPIError { |
| 56 | + return NewPreconditionFailedError(ctx, err.Error()) |
| 57 | + }, |
| 58 | + }, |
| 59 | + } |
| 60 | + |
| 61 | + return func(w http.ResponseWriter, r *http.Request, err error) { |
| 62 | + if err == nil { |
| 63 | + return |
| 64 | + } |
| 65 | + |
| 66 | + // If it's already a v3 API error, just render it. |
| 67 | + var apiErr *BaseAPIError |
| 68 | + if errors.As(err, &apiErr) { |
| 69 | + apiErr.HandleAPIError(w, r) |
| 70 | + return |
| 71 | + } |
| 72 | + |
| 73 | + ctx := r.Context() |
| 74 | + |
| 75 | + // Request binding errors produced by the generated v3 router. |
| 76 | + // Convert them into v3 InvalidParameters so the response is actionable for clients. |
| 77 | + if invalidParams, ok := invalidParametersFromGeneratedRouterError(err); ok { |
| 78 | + logger.HandleContext(ctx, err) |
| 79 | + NewBadRequestError(ctx, err, invalidParams).HandleAPIError(w, r) |
| 80 | + return |
| 81 | + } |
| 82 | + |
| 83 | + // Mirror commonhttp.GenericErrorEncoder's ordering, but render using v3 apierrors. |
| 84 | + if status, ok := singularHTTPStatusFromValidationIssues(err); ok { |
| 85 | + if mapped := apiErrorFromHTTPStatus(ctx, status, err); mapped != nil { |
| 86 | + logger.HandleContext(ctx, err) |
| 87 | + mapped.HandleAPIError(w, r) |
| 88 | + return |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + for _, m := range mappings { |
| 93 | + if m.match(err) { |
| 94 | + logger.HandleContext(ctx, err) |
| 95 | + m.build(ctx, err).HandleAPIError(w, r) |
| 96 | + return |
| 97 | + } |
| 98 | + } |
| 99 | + |
| 100 | + // Default: classify as validation error (400) for request binding failures. |
| 101 | + validationErr := models.NewGenericValidationError(err) |
| 102 | + logger.HandleContext(r.Context(), validationErr) |
| 103 | + NewBadRequestError(r.Context(), validationErr, nil).HandleAPIError(w, r) |
| 104 | + } |
| 105 | +} |
| 106 | + |
| 107 | +func invalidParametersFromGeneratedRouterError(err error) (InvalidParameters, bool) { |
| 108 | + // These types are defined in api/v3/api.gen.go. |
| 109 | + // |
| 110 | + // Note: those errors do not carry the parameter location (query/path/header) except for the |
| 111 | + // dedicated "required header" variant, so we default to "query" where ambiguous. This is still |
| 112 | + // a major improvement over returning an empty error response. |
| 113 | + var invalidFormat *api.InvalidParamFormatError |
| 114 | + if errors.As(err, &invalidFormat) { |
| 115 | + field := enrichFieldFromBindError(invalidFormat.ParamName, invalidFormat.Err.Error()) |
| 116 | + return InvalidParameters{ |
| 117 | + { |
| 118 | + Field: field, |
| 119 | + Rule: "format", |
| 120 | + Reason: invalidFormat.Err.Error(), |
| 121 | + Source: InvalidParamSourceQuery, |
| 122 | + }, |
| 123 | + }, true |
| 124 | + } |
| 125 | + |
| 126 | + var requiredParam *api.RequiredParamError |
| 127 | + if errors.As(err, &requiredParam) { |
| 128 | + return InvalidParameters{ |
| 129 | + { |
| 130 | + Field: requiredParam.ParamName, |
| 131 | + Rule: "required", |
| 132 | + Reason: "is required", |
| 133 | + Source: InvalidParamSourceQuery, |
| 134 | + }, |
| 135 | + }, true |
| 136 | + } |
| 137 | + |
| 138 | + var requiredHeader *api.RequiredHeaderError |
| 139 | + if errors.As(err, &requiredHeader) { |
| 140 | + return InvalidParameters{ |
| 141 | + { |
| 142 | + Field: requiredHeader.ParamName, |
| 143 | + Rule: "required", |
| 144 | + Reason: "is required", |
| 145 | + Source: InvalidParamSourceHeader, |
| 146 | + }, |
| 147 | + }, true |
| 148 | + } |
| 149 | + |
| 150 | + var tooMany *api.TooManyValuesForParamError |
| 151 | + if errors.As(err, &tooMany) { |
| 152 | + return InvalidParameters{ |
| 153 | + { |
| 154 | + Field: tooMany.ParamName, |
| 155 | + Rule: "too_many_values", |
| 156 | + Reason: tooMany.Error(), |
| 157 | + Source: InvalidParamSourceQuery, |
| 158 | + }, |
| 159 | + }, true |
| 160 | + } |
| 161 | + |
| 162 | + var unmarshal *api.UnmarshalingParamError |
| 163 | + if errors.As(err, &unmarshal) { |
| 164 | + return InvalidParameters{ |
| 165 | + { |
| 166 | + Field: unmarshal.ParamName, |
| 167 | + Rule: "unmarshal", |
| 168 | + Reason: unmarshal.Err.Error(), |
| 169 | + Source: InvalidParamSourceQuery, |
| 170 | + }, |
| 171 | + }, true |
| 172 | + } |
| 173 | + |
| 174 | + var unescapedCookie *api.UnescapedCookieParamError |
| 175 | + if errors.As(err, &unescapedCookie) { |
| 176 | + return InvalidParameters{ |
| 177 | + { |
| 178 | + Field: unescapedCookie.ParamName, |
| 179 | + Rule: "unescape", |
| 180 | + Reason: unescapedCookie.Error(), |
| 181 | + Source: InvalidParamSourceHeader, |
| 182 | + }, |
| 183 | + }, true |
| 184 | + } |
| 185 | + |
| 186 | + return nil, false |
| 187 | +} |
| 188 | + |
| 189 | +func enrichFieldFromBindError(paramName string, bindErrMsg string) string { |
| 190 | + // oapi-codegen deepObject binding errors (runtime.BindQueryParameter) can be more specific than |
| 191 | + // just the outer parameter name, e.g.: |
| 192 | + // "error assigning value to destination: field [sizee] is not present in destination object". |
| 193 | + // |
| 194 | + // For nicer AIP errors, return "page.sizee" instead of just "page". |
| 195 | + if paramName == "" || bindErrMsg == "" { |
| 196 | + return paramName |
| 197 | + } |
| 198 | + if strings.Contains(paramName, "[") { |
| 199 | + // Already specific (e.g. "page[size]") - keep as-is. |
| 200 | + return paramName |
| 201 | + } |
| 202 | + const needle = "field [" |
| 203 | + i := strings.Index(bindErrMsg, needle) |
| 204 | + if i == -1 { |
| 205 | + return paramName |
| 206 | + } |
| 207 | + rest := bindErrMsg[i+len(needle):] |
| 208 | + j := strings.Index(rest, "]") |
| 209 | + if j == -1 { |
| 210 | + return paramName |
| 211 | + } |
| 212 | + field := rest[:j] |
| 213 | + if field == "" { |
| 214 | + return paramName |
| 215 | + } |
| 216 | + return paramName + "." + field |
| 217 | +} |
| 218 | + |
| 219 | +func singularHTTPStatusFromValidationIssues(err error) (int, bool) { |
| 220 | + issues, _ := models.AsValidationIssues(err) |
| 221 | + if len(issues) == 0 { |
| 222 | + return 0, false |
| 223 | + } |
| 224 | + |
| 225 | + // We intentionally mirror commonhttp.HandleIssueIfHTTPStatusKnown's "singular" behavior: |
| 226 | + // if multiple status codes are present, we don't map. |
| 227 | + codes := make(map[int]struct{}, 1) |
| 228 | + for _, issue := range issues { |
| 229 | + raw, ok := issue.Attributes()[httpStatusCodeErrorAttribute] |
| 230 | + if !ok { |
| 231 | + continue |
| 232 | + } |
| 233 | + c, ok := raw.(int) |
| 234 | + if !ok { |
| 235 | + continue |
| 236 | + } |
| 237 | + codes[c] = struct{}{} |
| 238 | + } |
| 239 | + |
| 240 | + if len(codes) != 1 { |
| 241 | + return 0, false |
| 242 | + } |
| 243 | + |
| 244 | + for c := range codes { |
| 245 | + return c, true |
| 246 | + } |
| 247 | + return 0, false |
| 248 | +} |
| 249 | + |
| 250 | +func apiErrorFromHTTPStatus(ctx context.Context, status int, err error) *BaseAPIError { |
| 251 | + switch status { |
| 252 | + case http.StatusBadRequest: |
| 253 | + return NewBadRequestError(ctx, err, nil) |
| 254 | + case http.StatusUnauthorized: |
| 255 | + return NewUnauthenticatedError(ctx, err) |
| 256 | + case http.StatusForbidden: |
| 257 | + return NewForbiddenError(ctx, err) |
| 258 | + case http.StatusNotFound: |
| 259 | + return NewNotFoundError(ctx, err, "") |
| 260 | + case http.StatusConflict: |
| 261 | + return NewConflictError(ctx, err, err.Error()) |
| 262 | + case http.StatusPreconditionFailed: |
| 263 | + return NewPreconditionFailedError(ctx, err.Error()) |
| 264 | + case http.StatusNotImplemented: |
| 265 | + return NewNotImplementedError(ctx, err) |
| 266 | + default: |
| 267 | + return nil |
| 268 | + } |
| 269 | +} |
0 commit comments