Skip to content

Commit 51721b7

Browse files
committed
fix: lots of fixes and improvements
1 parent 5705acf commit 51721b7

File tree

17 files changed

+350
-189
lines changed

17 files changed

+350
-189
lines changed

contract/session.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ type Session interface {
5858
// HasRegenerated returns true if the session has been regenerated (i.e., the session ID
5959
// has changed). This is useful for sending updated session identifiers to the client.
6060
HasRegenerated() bool
61+
62+
// MarkAsUnchanged sets the session as if nothing has changed, therefore avoiding saving
63+
// the session when the request finishes.
64+
MarkAsUnchanged()
6165
}
6266

6367
// SessionDriver defines the interface for persisting and retrieving session data.

framework/README.md

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,7 @@ func handler(w http.ResponseWriter, r *http.Request) error {
5555
func main() {
5656
app := framework.New()
5757

58-
app.Use(middleware.ErrorHandler(middleware.ErrorHandlerOptions{
59-
Logger: slog.Default(),
60-
IsDev: true,
61-
}))
62-
58+
app.Use(middleware.Logger(slog.Default()))
6359
app.Use(middleware.Recover())
6460

6561
app.Get("/", handler)

framework/cache/redis/cache.go

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@ func New(options *Options) *Client {
2020
}
2121

2222
// Get retrieves a value by key or returns contract.ErrNotFound if missing.
23-
func (c *Client) Get(ctx context.Context, key string) (v any, e error) {
24-
encoded, err := (*redis.Client)(c).Get(ctx, key).Result()
23+
func (c *Client) Get(ctx context.Context, key string) (any, error) {
24+
v, err := (*redis.Client)(c).Get(ctx, key).Result()
2525

2626
if errors.Is(err, redis.Nil) {
2727
return nil, fmt.Errorf("%w: %s", contract.ErrCacheKeyNotFound, key)
@@ -31,22 +31,12 @@ func (c *Client) Get(ctx context.Context, key string) (v any, e error) {
3131
return nil, err
3232
}
3333

34-
if err := json.Unmarshal([]byte(encoded), &v); err != nil {
35-
return nil, err
36-
}
37-
3834
return v, nil
3935
}
4036

4137
// Put sets a key with value and TTL.
4238
func (c *Client) Put(ctx context.Context, key string, value any, ttl time.Duration) error {
43-
encoded, err := json.Marshal(value)
44-
45-
if err != nil {
46-
return err
47-
}
48-
49-
return (*redis.Client)(c).Set(ctx, key, encoded, ttl).Err()
39+
return (*redis.Client)(c).Set(ctx, key, value, ttl).Err()
5040
}
5141

5242
// Delete removes a key.

framework/go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ require (
2020
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
2121
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
2222
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
23+
github.com/stretchr/objx v0.5.2 // indirect
2324
golang.org/x/sys v0.38.0 // indirect
2425
gopkg.in/yaml.v3 v3.0.1 // indirect
2526
)

framework/go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI
2828
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
2929
github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4ViluI=
3030
github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370=
31+
github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY=
32+
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
3133
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
3234
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
3335
github.com/studiolambda/cosmos/contract v0.5.0 h1:wGJj6k3Yhv1jeZvmxcDFsuSUyYTxV88S6n/5Gh4njBQ=

framework/handler.go

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@ import (
55
"errors"
66
"net/http"
77
"net/http/httptest"
8+
9+
"github.com/studiolambda/cosmos/framework/hook"
10+
"github.com/studiolambda/cosmos/problem"
811
)
912

1013
// Handler defines the function signature for HTTP request handlers in Cosmos.
@@ -50,6 +53,28 @@ type HTTPStatus interface {
5053
HTTPStatus() int
5154
}
5255

56+
func handleError(w http.ResponseWriter, r *http.Request, err error) {
57+
status := http.StatusInternalServerError
58+
59+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
60+
status = 499 // A non-standard status code: 499 Client Closed Request
61+
}
62+
63+
if s, ok := err.(HTTPStatus); ok {
64+
status = s.HTTPStatus()
65+
}
66+
67+
// We can check if the error can be directly
68+
// handled by using a [http.Handler], in the case
69+
// we'll simply handle it using ServeHTTP.
70+
if h, ok := err.(http.Handler); ok {
71+
h.ServeHTTP(w, r)
72+
return
73+
}
74+
75+
problem.NewProblem(err, status).ServeHTTP(w, r)
76+
}
77+
5378
// ServeHTTP implements the http.Handler interface, allowing Cosmos handlers
5479
// to be used with the standard HTTP server. It bridges the gap between
5580
// Cosmos's error-returning handlers and Go's standard http.Handler interface.
@@ -59,18 +84,21 @@ type HTTPStatus interface {
5984
// response to the client. However, if the response has already been partially
6085
// written (e.g., during streaming), the error response may not be deliverable.
6186
func (handler Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
62-
if err := handler(w, r); err != nil {
63-
status := http.StatusInternalServerError
87+
hooks := hook.NewManager()
88+
wrapped := hook.NewResponseWriter(w, hooks)
89+
ctx := context.WithValue(r.Context(), hook.Key, hooks)
90+
err := handler(wrapped, r.WithContext(ctx))
6491

65-
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
66-
status = 499 // A non-standard status code: 499 Client Closed Request
67-
}
92+
if err != nil {
93+
handleError(w, r, err)
94+
}
6895

69-
if s, ok := err.(HTTPStatus); ok {
70-
status = s.HTTPStatus()
71-
}
96+
if !wrapped.WriteHeaderCalled() {
97+
wrapped.WriteHeader(http.StatusNoContent)
98+
}
7299

73-
http.Error(w, err.Error(), status)
100+
for _, callback := range hooks.AfterResponseFuncs() {
101+
callback(err)
74102
}
75103
}
76104

framework/hook/manager.go

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
package hook
2+
3+
import (
4+
"net/http"
5+
"slices"
6+
"sync"
7+
)
8+
9+
// Manager is a structure that can be used to
10+
// attach specific hooks to the request / response
11+
// lifecycle. It requires the use of the hooks middleware.
12+
//
13+
// It is safe for concurrent use.
14+
type Manager struct {
15+
mutex sync.Mutex
16+
afterResponseFuncs []AfterResponseFunc
17+
beforeWriteHeaderFuncs []BeforeWriteHeaderFunc
18+
beforeWriteFuncs []BeforeWriteFunc
19+
}
20+
21+
type AfterResponseFunc func(err error)
22+
type BeforeWriteHeaderFunc func(w http.ResponseWriter, status int)
23+
type BeforeWriteFunc func(w http.ResponseWriter, content []byte)
24+
type key struct{}
25+
26+
var Key = key{}
27+
28+
func NewManager() *Manager {
29+
return &Manager{
30+
mutex: sync.Mutex{},
31+
beforeWriteHeaderFuncs: []BeforeWriteHeaderFunc{},
32+
beforeWriteFuncs: []BeforeWriteFunc{},
33+
}
34+
}
35+
36+
func (h *Manager) BeforeWriteHeader(callbacks ...BeforeWriteHeaderFunc) {
37+
h.mutex.Lock()
38+
defer h.mutex.Unlock()
39+
40+
h.beforeWriteHeaderFuncs = append(h.beforeWriteHeaderFuncs, callbacks...)
41+
}
42+
43+
func (h *Manager) BeforeWriteHeaderFuncs() []BeforeWriteHeaderFunc {
44+
h.mutex.Lock()
45+
defer h.mutex.Unlock()
46+
47+
clone := slices.Clone(h.beforeWriteHeaderFuncs)
48+
slices.Reverse(clone)
49+
50+
return clone
51+
}
52+
53+
func (h *Manager) BeforeWrite(callbacks ...BeforeWriteFunc) {
54+
h.mutex.Lock()
55+
defer h.mutex.Unlock()
56+
57+
h.beforeWriteFuncs = append(h.beforeWriteFuncs, callbacks...)
58+
}
59+
60+
func (h *Manager) BeforeWriteFuncs() []BeforeWriteFunc {
61+
h.mutex.Lock()
62+
defer h.mutex.Unlock()
63+
64+
clone := slices.Clone(h.beforeWriteFuncs)
65+
slices.Reverse(clone)
66+
67+
return clone
68+
}
69+
70+
func (h *Manager) AfterResponse(callbacks ...AfterResponseFunc) {
71+
h.mutex.Lock()
72+
defer h.mutex.Unlock()
73+
74+
h.afterResponseFuncs = append(h.afterResponseFuncs, callbacks...)
75+
}
76+
77+
func (h *Manager) AfterResponseFuncs() []AfterResponseFunc {
78+
h.mutex.Lock()
79+
defer h.mutex.Unlock()
80+
81+
clone := slices.Clone(h.afterResponseFuncs)
82+
slices.Reverse(clone)
83+
84+
return clone
85+
}

framework/hook/writer.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package hook
2+
3+
import "net/http"
4+
5+
type ResponseWriter struct {
6+
http.ResponseWriter
7+
*Manager
8+
writeHeaderCalled bool
9+
}
10+
11+
type ResponseWriterFlusher struct {
12+
*ResponseWriter
13+
http.Flusher
14+
}
15+
16+
type WrappedResponseWriter interface {
17+
http.ResponseWriter
18+
WriteHeaderCalled() bool
19+
}
20+
21+
func NewResponseWriter(w http.ResponseWriter, m *Manager) WrappedResponseWriter {
22+
wrapped := &ResponseWriter{
23+
ResponseWriter: w,
24+
Manager: m,
25+
}
26+
27+
if f, ok := w.(http.Flusher); ok {
28+
return &ResponseWriterFlusher{
29+
ResponseWriter: wrapped,
30+
Flusher: f,
31+
}
32+
}
33+
34+
return wrapped
35+
}
36+
37+
func (w *ResponseWriter) WriteHeaderCalled() bool {
38+
return w.writeHeaderCalled
39+
}
40+
41+
func (w *ResponseWriter) WriteHeader(status int) {
42+
if w.WriteHeaderCalled() {
43+
return
44+
}
45+
46+
for _, hook := range w.Manager.BeforeWriteHeaderFuncs() {
47+
hook(w.ResponseWriter, status)
48+
}
49+
50+
w.ResponseWriter.WriteHeader(status)
51+
w.writeHeaderCalled = true
52+
}
53+
54+
func (w *ResponseWriter) Write(content []byte) (int, error) {
55+
if !w.WriteHeaderCalled() {
56+
// Same behaviour as the [http.ResponseWriter]
57+
w.WriteHeader(http.StatusOK)
58+
}
59+
60+
for _, hook := range w.Manager.BeforeWriteFuncs() {
61+
hook(w.ResponseWriter, content)
62+
}
63+
64+
return w.ResponseWriter.Write(content)
65+
}

framework/middleware/error_handler.go

Lines changed: 0 additions & 76 deletions
This file was deleted.

0 commit comments

Comments
 (0)