Skip to content

Commit 9554e21

Browse files
Merge pull request #8 from form3tech-oss/add-date-header
feat: add date header if not present
2 parents 753a2e8 + d21efd5 commit 9554e21

File tree

2 files changed

+57
-12
lines changed

2 files changed

+57
-12
lines changed

proxy/handler.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,12 @@ func (h *handler) ForwardRequest(c *gin.Context) {
3535
req.Host = h.proxy.TargetHost
3636
req.Header.Set("Host", h.proxy.TargetHost)
3737

38+
// Add Date header since some clients don't automatically add it
39+
date := req.Header.Get("Date")
40+
if date == "" {
41+
req.Header.Set("Date", time.Now().Format(http.TimeFormat))
42+
}
43+
3844
start := time.Now()
3945
signedReq, err := h.reqSigner.SignRequest(req)
4046
singingDuration := time.Since(start)

proxy/handler_test.go

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"net/http"
55
"net/http/httptest"
66
"testing"
7+
"time"
78

89
"github.com/form3tech-oss/http-message-signing-proxy/test"
910
"github.com/gin-gonic/gin"
@@ -12,6 +13,30 @@ import (
1213
)
1314

1415
func TestHandler(t *testing.T) {
16+
tests := []struct {
17+
name string
18+
inputHeaders map[string]string
19+
headerTestFn func(responseHeader http.Header)
20+
}{
21+
{
22+
"automatic date header injection",
23+
nil,
24+
func(h http.Header) {
25+
_, err := time.Parse(http.TimeFormat, h.Get("Date"))
26+
require.NoError(t, err)
27+
},
28+
},
29+
{
30+
"date header present",
31+
map[string]string{
32+
"Date": time.Date(1998, time.May, 1, 1, 2, 3, 4, time.UTC).Format(http.TimeFormat),
33+
},
34+
func(h http.Header) {
35+
require.Equal(t, "Fri, 01 May 1998 01:02:03 GMT", h.Get("Date"))
36+
},
37+
},
38+
}
39+
1540
expectedRespBody := "OK"
1641
mockURL := "mock"
1742
mockCtrl := gomock.NewController(t)
@@ -29,44 +54,58 @@ func TestHandler(t *testing.T) {
2954
require.NoError(t, err)
3055

3156
// Test handler
57+
var w *test.TestResponseRecorder
3258
h := NewHandler(rs, mockReqSigner, mockMetricPublisher)
33-
w := test.NewTestResponseRecorder()
3459
_, e := gin.CreateTestContext(w)
3560
e.NoRoute(
3661
RecoverMiddleware(mockMetricPublisher),
3762
LogAndMetricsMiddleware(mockMetricPublisher),
3863
h.ForwardRequest,
3964
)
4065

41-
req, err := http.NewRequest(http.MethodGet, mockURL, nil)
42-
require.NoError(t, err)
66+
for _, tt := range tests {
67+
t.Run(tt.name, func(t *testing.T) {
68+
w = test.NewTestResponseRecorder()
69+
70+
// Test request
71+
req, err := http.NewRequest(http.MethodGet, mockURL, nil)
72+
require.NoError(t, err)
73+
74+
for k, v := range tt.inputHeaders {
75+
req.Header.Set(k, v)
76+
}
4377

44-
e.ServeHTTP(w, req)
78+
e.ServeHTTP(w, req)
4579

46-
require.Equal(t, expectedRespBody, w.Body.String())
47-
require.Equal(t, http.StatusOK, w.Code)
80+
require.Equal(t, expectedRespBody, w.Body.String())
81+
require.Equal(t, http.StatusOK, w.Code)
82+
tt.headerTestFn(w.Header())
83+
})
84+
}
4885
}
4986

5087
func mockReqSigner(mockCtrl *gomock.Controller) *MockRequestSigner {
5188
mockReqSigner := NewMockRequestSigner(mockCtrl)
5289
mockReqSigner.EXPECT().SignRequest(gomock.Any()).DoAndReturn(func(r *http.Request) (*http.Request, error) {
53-
// We don't test the signer here so we return the request as-is
90+
// We don't test the signer here, so we return the request as-is
5491
return r, nil
55-
})
92+
}).AnyTimes()
5693
return mockReqSigner
5794
}
5895

5996
func mockMetricPublisher(mockCtrl *gomock.Controller, mockURL string) *MockMetricPublisher {
6097
mockMetricPublisher := NewMockMetricPublisher(mockCtrl)
61-
mockMetricPublisher.EXPECT().IncrementTotalRequestCount(http.MethodGet, mockURL)
62-
mockMetricPublisher.EXPECT().MeasureSigningDuration(http.MethodGet, mockURL, gomock.Any())
63-
mockMetricPublisher.EXPECT().IncrementSignedRequestCount(http.MethodGet, mockURL)
64-
mockMetricPublisher.EXPECT().MeasureTotalDuration(http.MethodGet, mockURL, gomock.Any())
98+
mockMetricPublisher.EXPECT().IncrementTotalRequestCount(http.MethodGet, mockURL).AnyTimes()
99+
mockMetricPublisher.EXPECT().MeasureSigningDuration(http.MethodGet, mockURL, gomock.Any()).AnyTimes()
100+
mockMetricPublisher.EXPECT().IncrementSignedRequestCount(http.MethodGet, mockURL).AnyTimes()
101+
mockMetricPublisher.EXPECT().MeasureTotalDuration(http.MethodGet, mockURL, gomock.Any()).AnyTimes()
65102
return mockMetricPublisher
66103
}
67104

68105
func testTargetServer(expectedBody string) *httptest.Server {
69106
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
107+
w.Header().Set("Date", r.Header.Get("Date"))
108+
w.Header().Set("haha", "haha")
70109
w.WriteHeader(http.StatusOK)
71110
_, _ = w.Write([]byte(expectedBody))
72111
}))

0 commit comments

Comments
 (0)