Skip to content

Commit b6b7771

Browse files
authored
routing/http/server: add cache control (#584)
1 parent 97e347e commit b6b7771

File tree

4 files changed

+200
-38
lines changed

4 files changed

+200
-38
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ The following emojis are used to highlight certain changes:
1616

1717
### Added
1818

19+
* `routing/http/server` now adds `Cache-Control` HTTP header to GET requests: 15 seconds for empty responses, or 5 minutes for responses with providers.
20+
1921
### Changed
2022

2123
### Removed

routing/http/server/server.go

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"io"
1010
"mime"
1111
"net/http"
12-
"strconv"
1312
"strings"
1413
"time"
1514

@@ -402,15 +401,26 @@ func (s *server) GetIPNS(w http.ResponseWriter, r *http.Request) {
402401
return
403402
}
404403

404+
var remainingValidity int
405+
// Include 'Expires' header with time when signature expiration happens
406+
if validityType, err := record.ValidityType(); err == nil && validityType == ipns.ValidityEOL {
407+
if validity, err := record.Validity(); err == nil {
408+
w.Header().Set("Expires", validity.UTC().Format(http.TimeFormat))
409+
remainingValidity = int(time.Until(validity).Seconds())
410+
}
411+
} else {
412+
remainingValidity = int(ipns.DefaultRecordLifetime.Seconds())
413+
}
405414
if ttl, err := record.TTL(); err == nil {
406-
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d", int(ttl.Seconds())))
415+
setCacheControl(w, int(ttl.Seconds()), remainingValidity)
407416
} else {
408-
w.Header().Set("Cache-Control", "max-age=60")
417+
setCacheControl(w, int(ipns.DefaultRecordTTL.Seconds()), remainingValidity)
409418
}
419+
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
410420

411-
recordEtag := strconv.FormatUint(xxhash.Sum64(rawRecord), 32)
412-
w.Header().Set("Etag", recordEtag)
421+
w.Header().Set("Etag", fmt.Sprintf(`"%x"`, xxhash.Sum64(rawRecord)))
413422
w.Header().Set("Content-Type", mediaTypeIPNSRecord)
423+
w.Header().Add("Vary", "Accept")
414424
w.Write(rawRecord)
415425
}
416426

@@ -462,8 +472,30 @@ func (s *server) PutIPNS(w http.ResponseWriter, r *http.Request) {
462472
w.WriteHeader(http.StatusOK)
463473
}
464474

465-
func writeJSONResult(w http.ResponseWriter, method string, val any) {
475+
var (
476+
// Rule-of-thumb Cache-Control policy is to work well with caching proxies and load balancers.
477+
// If there are any results, cache on the client for longer, and hint any in-between caches to
478+
// serve cached result and upddate cache in background as long we have
479+
// result that is within Amino DHT expiration window
480+
maxAgeWithResults = int((5 * time.Minute).Seconds()) // cache >0 results for longer
481+
maxAgeWithoutResults = int((15 * time.Second).Seconds()) // cache no results briefly
482+
maxStale = int((48 * time.Hour).Seconds()) // allow stale results as long within Amino DHT Expiration window
483+
)
484+
485+
func setCacheControl(w http.ResponseWriter, maxAge int, stale int) {
486+
w.Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d, stale-while-revalidate=%d, stale-if-error=%d", maxAge, stale, stale))
487+
}
488+
489+
func writeJSONResult(w http.ResponseWriter, method string, val interface{ Length() int }) {
466490
w.Header().Add("Content-Type", mediaTypeJSON)
491+
w.Header().Add("Vary", "Accept")
492+
493+
if val.Length() > 0 {
494+
setCacheControl(w, maxAgeWithResults, maxStale)
495+
} else {
496+
setCacheControl(w, maxAgeWithoutResults, maxStale)
497+
}
498+
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
467499

468500
// keep the marshaling separate from the writing, so we can distinguish bugs (which surface as 500)
469501
// from transient network issues (which surface as transport errors)
@@ -500,21 +532,30 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result
500532
defer resultIter.Close()
501533

502534
w.Header().Set("Content-Type", mediaTypeNDJSON)
503-
w.WriteHeader(http.StatusOK)
535+
w.Header().Add("Vary", "Accept")
536+
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
504537

538+
hasResults := false
505539
for resultIter.Next() {
506540
res := resultIter.Val()
507541
if res.Err != nil {
508542
logger.Errorw("ndjson iterator error", "Error", res.Err)
509543
return
510544
}
545+
511546
// don't use an encoder because we can't easily differentiate writer errors from encoding errors
512547
b, err := drjson.MarshalJSONBytes(res.Val)
513548
if err != nil {
514549
logger.Errorw("ndjson marshal error", "Error", err)
515550
return
516551
}
517552

553+
if !hasResults {
554+
hasResults = true
555+
// There's results, cache useful result for longer
556+
setCacheControl(w, maxAgeWithResults, maxStale)
557+
}
558+
518559
_, err = w.Write(b)
519560
if err != nil {
520561
logger.Warn("ndjson write error", "Error", err)
@@ -531,4 +572,9 @@ func writeResultsIterNDJSON[T any](w http.ResponseWriter, resultIter iter.Result
531572
f.Flush()
532573
}
533574
}
575+
576+
if !hasResults {
577+
// There weren't results, cache for shorter
578+
setCacheControl(w, maxAgeWithoutResults, maxStale)
579+
}
534580
}

routing/http/server/server_test.go

Lines changed: 133 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ import (
77
"io"
88
"net/http"
99
"net/http/httptest"
10+
"regexp"
11+
"strconv"
1012
"testing"
1113
"time"
1214

@@ -47,6 +49,7 @@ func TestHeaders(t *testing.T) {
4749
require.Equal(t, 200, resp.StatusCode)
4850
header := resp.Header.Get("Content-Type")
4951
require.Equal(t, mediaTypeJSON, header)
52+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
5053

5154
resp, err = http.Get(serverAddr + "/routing/v1/providers/" + "BAD_CID")
5255
require.NoError(t, err)
@@ -66,6 +69,13 @@ func makePeerID(t *testing.T) (crypto.PrivKey, peer.ID) {
6669
return sk, pid
6770
}
6871

72+
func requireCloseToNow(t *testing.T, lastModified string) {
73+
// inspecting fields like 'Last-Modified' is prone to one-off errors, we test with 1m buffer
74+
lastModifiedTime, err := time.Parse(http.TimeFormat, lastModified)
75+
require.NoError(t, err)
76+
require.WithinDuration(t, time.Now(), lastModifiedTime, 1*time.Minute)
77+
}
78+
6979
func TestProviders(t *testing.T) {
7080
pidStr := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn"
7181
pid2Str := "12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"
@@ -79,25 +89,31 @@ func TestProviders(t *testing.T) {
7989
cid, err := cid.Decode(cidStr)
8090
require.NoError(t, err)
8191

82-
runTest := func(t *testing.T, contentType string, expectedStream bool, expectedBody string) {
92+
runTest := func(t *testing.T, contentType string, empty bool, expectedStream bool, expectedBody string) {
8393
t.Parallel()
8494

85-
results := iter.FromSlice([]iter.Result[types.Record]{
86-
{Val: &types.PeerRecord{
87-
Schema: types.SchemaPeer,
88-
ID: &pid,
89-
Protocols: []string{"transport-bitswap"},
90-
Addrs: []types.Multiaddr{},
91-
}},
92-
//lint:ignore SA1019 // ignore staticcheck
93-
{Val: &types.BitswapRecord{
95+
var results *iter.SliceIter[iter.Result[types.Record]]
96+
97+
if empty {
98+
results = iter.FromSlice([]iter.Result[types.Record]{})
99+
} else {
100+
results = iter.FromSlice([]iter.Result[types.Record]{
101+
{Val: &types.PeerRecord{
102+
Schema: types.SchemaPeer,
103+
ID: &pid,
104+
Protocols: []string{"transport-bitswap"},
105+
Addrs: []types.Multiaddr{},
106+
}},
94107
//lint:ignore SA1019 // ignore staticcheck
95-
Schema: types.SchemaBitswap,
96-
ID: &pid2,
97-
Protocol: "transport-bitswap",
98-
Addrs: []types.Multiaddr{},
99-
}}},
100-
)
108+
{Val: &types.BitswapRecord{
109+
//lint:ignore SA1019 // ignore staticcheck
110+
Schema: types.SchemaBitswap,
111+
ID: &pid2,
112+
Protocol: "transport-bitswap",
113+
Addrs: []types.Multiaddr{},
114+
}}},
115+
)
116+
}
101117

102118
router := &mockContentRouter{}
103119
server := httptest.NewServer(Handler(router))
@@ -117,8 +133,16 @@ func TestProviders(t *testing.T) {
117133
resp, err := http.DefaultClient.Do(req)
118134
require.NoError(t, err)
119135
require.Equal(t, 200, resp.StatusCode)
120-
header := resp.Header.Get("Content-Type")
121-
require.Equal(t, contentType, header)
136+
137+
require.Equal(t, contentType, resp.Header.Get("Content-Type"))
138+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
139+
140+
if empty {
141+
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
142+
} else {
143+
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
144+
}
145+
requireCloseToNow(t, resp.Header.Get("Last-Modified"))
122146

123147
body, err := io.ReadAll(resp.Body)
124148
require.NoError(t, err)
@@ -127,11 +151,19 @@ func TestProviders(t *testing.T) {
127151
}
128152

129153
t.Run("JSON Response", func(t *testing.T) {
130-
runTest(t, mediaTypeJSON, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`)
154+
runTest(t, mediaTypeJSON, false, false, `{"Providers":[{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"},{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}]}`)
155+
})
156+
157+
t.Run("Empty JSON Response", func(t *testing.T) {
158+
runTest(t, mediaTypeJSON, true, false, `{"Providers":null}`)
131159
})
132160

133161
t.Run("NDJSON Response", func(t *testing.T) {
134-
runTest(t, mediaTypeNDJSON, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n")
162+
runTest(t, mediaTypeNDJSON, false, true, `{"Addrs":[],"ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vn","Protocols":["transport-bitswap"],"Schema":"peer"}`+"\n"+`{"Schema":"bitswap","Protocol":"transport-bitswap","ID":"12D3KooWM8sovaEGU1bmiWGWAzvs47DEcXKZZTuJnpQyVTkRs2Vz"}`+"\n")
163+
})
164+
165+
t.Run("Empty NDJSON Response", func(t *testing.T) {
166+
runTest(t, mediaTypeNDJSON, true, true, "")
135167
})
136168
}
137169

@@ -155,7 +187,26 @@ func TestPeers(t *testing.T) {
155187
require.Equal(t, 400, resp.StatusCode)
156188
})
157189

158-
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (JSON)", func(t *testing.T) {
190+
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, JSON)", func(t *testing.T) {
191+
t.Parallel()
192+
193+
_, pid := makePeerID(t)
194+
results := iter.FromSlice([]iter.Result[*types.PeerRecord]{})
195+
196+
router := &mockContentRouter{}
197+
router.On("FindPeers", mock.Anything, pid, 20).Return(results, nil)
198+
199+
resp := makeRequest(t, router, mediaTypeJSON, peer.ToCid(pid).String())
200+
require.Equal(t, 200, resp.StatusCode)
201+
202+
require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type"))
203+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
204+
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
205+
206+
requireCloseToNow(t, resp.Header.Get("Last-Modified"))
207+
})
208+
209+
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (JSON)", func(t *testing.T) {
159210
t.Parallel()
160211

161212
_, pid := makePeerID(t)
@@ -181,8 +232,11 @@ func TestPeers(t *testing.T) {
181232
resp := makeRequest(t, router, mediaTypeJSON, libp2pKeyCID)
182233
require.Equal(t, 200, resp.StatusCode)
183234

184-
header := resp.Header.Get("Content-Type")
185-
require.Equal(t, mediaTypeJSON, header)
235+
require.Equal(t, mediaTypeJSON, resp.Header.Get("Content-Type"))
236+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
237+
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
238+
239+
requireCloseToNow(t, resp.Header.Get("Last-Modified"))
186240

187241
body, err := io.ReadAll(resp.Body)
188242
require.NoError(t, err)
@@ -191,7 +245,26 @@ func TestPeers(t *testing.T) {
191245
require.Equal(t, expectedBody, string(body))
192246
})
193247

194-
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body (NDJSON)", func(t *testing.T) {
248+
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (No Results, NDJSON)", func(t *testing.T) {
249+
t.Parallel()
250+
251+
_, pid := makePeerID(t)
252+
results := iter.FromSlice([]iter.Result[*types.PeerRecord]{})
253+
254+
router := &mockContentRouter{}
255+
router.On("FindPeers", mock.Anything, pid, 0).Return(results, nil)
256+
257+
resp := makeRequest(t, router, mediaTypeNDJSON, peer.ToCid(pid).String())
258+
require.Equal(t, 200, resp.StatusCode)
259+
260+
require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type"))
261+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
262+
require.Equal(t, "public, max-age=15, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
263+
264+
requireCloseToNow(t, resp.Header.Get("Last-Modified"))
265+
})
266+
267+
t.Run("GET /routing/v1/peers/{cid-libp2p-key-peer-id} returns 200 with correct body and headers (NDJSON)", func(t *testing.T) {
195268
t.Parallel()
196269

197270
_, pid := makePeerID(t)
@@ -217,8 +290,9 @@ func TestPeers(t *testing.T) {
217290
resp := makeRequest(t, router, mediaTypeNDJSON, libp2pKeyCID)
218291
require.Equal(t, 200, resp.StatusCode)
219292

220-
header := resp.Header.Get("Content-Type")
221-
require.Equal(t, mediaTypeNDJSON, header)
293+
require.Equal(t, mediaTypeNDJSON, resp.Header.Get("Content-Type"))
294+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
295+
require.Equal(t, "public, max-age=300, stale-while-revalidate=172800, stale-if-error=172800", resp.Header.Get("Cache-Control"))
222296

223297
body, err := io.ReadAll(resp.Body)
224298
require.NoError(t, err)
@@ -254,6 +328,7 @@ func TestPeers(t *testing.T) {
254328
require.Equal(t, 200, resp.StatusCode)
255329

256330
header := resp.Header.Get("Content-Type")
331+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
257332
require.Equal(t, mediaTypeJSON, header)
258333

259334
body, err := io.ReadAll(resp.Body)
@@ -290,6 +365,7 @@ func TestPeers(t *testing.T) {
290365
require.Equal(t, 200, resp.StatusCode)
291366

292367
header := resp.Header.Get("Content-Type")
368+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
293369
require.Equal(t, mediaTypeNDJSON, header)
294370

295371
body, err := io.ReadAll(resp.Body)
@@ -306,10 +382,8 @@ func makeName(t *testing.T) (crypto.PrivKey, ipns.Name) {
306382
return sk, ipns.NameFromPeer(pid)
307383
}
308384

309-
func makeIPNSRecord(t *testing.T, cid cid.Cid, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) {
385+
func makeIPNSRecord(t *testing.T, cid cid.Cid, eol time.Time, ttl time.Duration, sk crypto.PrivKey, opts ...ipns.Option) (*ipns.Record, []byte) {
310386
path := path.FromCid(cid)
311-
eol := time.Now().Add(time.Hour * 48)
312-
ttl := time.Second * 20
313387

314388
record, err := ipns.NewRecord(sk, path, 1, eol, ttl, opts...)
315389
require.NoError(t, err)
@@ -339,7 +413,18 @@ func TestIPNS(t *testing.T) {
339413

340414
runWithRecordOptions := func(t *testing.T, opts ...ipns.Option) {
341415
sk, name1 := makeName(t)
342-
record1, rawRecord1 := makeIPNSRecord(t, cid1, sk)
416+
now := time.Now()
417+
eol := now.Add(24 * time.Hour * 7) // record valid for a week
418+
ttl := 42 * time.Second // distinct TTL
419+
record1, rawRecord1 := makeIPNSRecord(t, cid1, eol, ttl, sk)
420+
421+
stringToDuration := func(s string) time.Duration {
422+
seconds, err := strconv.Atoi(s)
423+
if err != nil {
424+
return 0
425+
}
426+
return time.Duration(seconds) * time.Second
427+
}
343428

344429
_, name2 := makeName(t)
345430

@@ -355,8 +440,25 @@ func TestIPNS(t *testing.T) {
355440
resp := makeRequest(t, router, "/routing/v1/ipns/"+name1.String())
356441
require.Equal(t, 200, resp.StatusCode)
357442
require.Equal(t, mediaTypeIPNSRecord, resp.Header.Get("Content-Type"))
443+
require.Equal(t, "Accept", resp.Header.Get("Vary"))
358444
require.NotEmpty(t, resp.Header.Get("Etag"))
359-
require.Equal(t, "max-age=20", resp.Header.Get("Cache-Control"))
445+
446+
requireCloseToNow(t, resp.Header.Get("Last-Modified"))
447+
448+
require.Contains(t, resp.Header.Get("Cache-Control"), "public, max-age=42")
449+
450+
// expected "stale" values are int(time.Until(eol).Seconds())
451+
// but running test on slow machine may be off by a few seconds
452+
// and we need to assert with some room for drift (1 minute just to not break any CI)
453+
re := regexp.MustCompile(`(?:^|,\s*)(max-age|stale-while-revalidate|stale-if-error)=(\d+)`)
454+
matches := re.FindAllStringSubmatch(resp.Header.Get("Cache-Control"), -1)
455+
staleWhileRevalidate := stringToDuration(matches[1][2])
456+
staleWhileError := stringToDuration(matches[2][2])
457+
require.WithinDuration(t, eol, time.Now().Add(staleWhileRevalidate), 1*time.Minute)
458+
require.WithinDuration(t, eol, time.Now().Add(staleWhileError), 1*time.Minute)
459+
460+
// 'Expires' on IPNS result is expected to match EOL of IPNS Record with ValidityType=0
461+
require.Equal(t, eol.UTC().Format(http.TimeFormat), resp.Header.Get("Expires"))
360462

361463
body, err := io.ReadAll(resp.Body)
362464
require.NoError(t, err)

0 commit comments

Comments
 (0)