Skip to content

Commit 6522228

Browse files
committed
cache token
1 parent 256bce9 commit 6522228

File tree

2 files changed

+236
-6
lines changed

2 files changed

+236
-6
lines changed

internal/metadata/metadata.go

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ type instanceData struct {
3232

3333
const instanceDataCacheTTL = 5 * time.Minute
3434

35+
// tokenRefreshMargin is how long before expiry we refresh the token
36+
const tokenRefreshMargin = 1 * time.Hour
37+
38+
type cachedToken struct {
39+
token string
40+
expiresAt time.Time
41+
}
42+
3543
type Reader struct {
3644
cfg Config
3745
logger *slog.Logger
@@ -40,6 +48,9 @@ type Reader struct {
4048
mu sync.Mutex
4149
cachedInstance *instanceData
4250
cachedFetchedAt time.Time
51+
52+
tokenMu sync.Mutex
53+
cachedIAM *cachedToken
4354
}
4455

4556
func NewReader(cfg Config, logger *slog.Logger) *Reader {
@@ -80,16 +91,64 @@ func (r *Reader) GetInstanceId() (instanceId string, isFallback bool, err error)
8091

8192
func (r *Reader) GetIamToken() (string, error) {
8293
if r.cfg.UseMetadataService {
83-
tokenPath := fmt.Sprintf("/v1/iam/%s/token/access_token", r.cfg.MetadataTokenType)
84-
body, err := r.fetchFromMetadataService(tokenPath)
94+
token, err := r.getCachedIAMToken()
8595
if err == nil {
86-
return strings.TrimSpace(string(body)), nil
96+
return token, nil
8797
}
8898
r.logger.Warn("Failed to get IAM token from IMDS, falling back to file", "error", err)
8999
}
90100
return r.readAndTrimFile(r.cfg.Path + "/" + r.cfg.IamTokenFilename)
91101
}
92102

103+
func (r *Reader) getCachedIAMToken() (string, error) {
104+
r.tokenMu.Lock()
105+
defer r.tokenMu.Unlock()
106+
107+
if r.cachedIAM != nil && time.Until(r.cachedIAM.expiresAt) > tokenRefreshMargin {
108+
return r.cachedIAM.token, nil
109+
}
110+
111+
tokenPath := fmt.Sprintf("/v1/iam/%s/token/access_token", r.cfg.MetadataTokenType)
112+
body, err := r.fetchFromMetadataService(tokenPath)
113+
if err != nil {
114+
if r.cachedIAM != nil && time.Until(r.cachedIAM.expiresAt) > 0 {
115+
r.logger.Warn("Failed to refresh IAM token, using cached token until expiry", "error", err, "expires_at", r.cachedIAM.expiresAt)
116+
return r.cachedIAM.token, nil
117+
}
118+
return "", fmt.Errorf("failed to fetch IAM token from IMDS: %w", err)
119+
}
120+
token := strings.TrimSpace(string(body))
121+
122+
expiresAt, err := r.fetchTokenExpiresAt()
123+
if err != nil {
124+
r.logger.Warn("Failed to get token expiry from IMDS, using default TTL", "error", err)
125+
expiresAt = time.Now().Add(instanceDataCacheTTL)
126+
}
127+
128+
if time.Until(expiresAt) <= 0 {
129+
return "", fmt.Errorf("token from IMDS is already expired (expires_at: %s)", expiresAt.Format(time.RFC3339Nano))
130+
}
131+
132+
r.cachedIAM = &cachedToken{
133+
token: token,
134+
expiresAt: expiresAt,
135+
}
136+
return token, nil
137+
}
138+
139+
func (r *Reader) fetchTokenExpiresAt() (time.Time, error) {
140+
expiresAtPath := fmt.Sprintf("/v1/iam/%s/token/expires_at", r.cfg.MetadataTokenType)
141+
body, err := r.fetchFromMetadataService(expiresAtPath)
142+
if err != nil {
143+
return time.Time{}, fmt.Errorf("failed to fetch token expires_at: %w", err)
144+
}
145+
expiresAt, err := time.Parse(time.RFC3339Nano, strings.TrimSpace(string(body)))
146+
if err != nil {
147+
return time.Time{}, fmt.Errorf("failed to parse expires_at timestamp: %w", err)
148+
}
149+
return expiresAt, nil
150+
}
151+
93152
func (r *Reader) getInstanceData() (*instanceData, error) {
94153
r.mu.Lock()
95154
defer r.mu.Unlock()

internal/metadata/metadata_test.go

Lines changed: 174 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,49 @@ func TestGetInstanceId_IMDSFallbackURL(t *testing.T) {
113113
}
114114

115115
func TestGetIamToken_IMDS(t *testing.T) {
116+
expiresAt := time.Now().Add(12 * time.Hour).UTC().Format(time.RFC3339Nano)
116117
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
117118
assert.Equal(t, "true", r.Header.Get("Metadata"))
118-
if r.URL.Path == "/v1/iam/tsa/token/access_token" {
119+
switch r.URL.Path {
120+
case "/v1/iam/tsa/token/access_token":
119121
_, err := w.Write([]byte("my-iam-token"))
120122
assert.NoError(t, err)
121-
return
123+
case "/v1/iam/tsa/token/expires_at":
124+
_, err := w.Write([]byte(expiresAt))
125+
assert.NoError(t, err)
126+
default:
127+
http.NotFound(w, r)
128+
}
129+
}))
130+
defer server.Close()
131+
132+
reader := NewReader(Config{
133+
UseMetadataService: true,
134+
MetadataServiceURL: server.URL,
135+
MetadataServiceFallbackURL: server.URL,
136+
MetadataTokenType: "tsa",
137+
}, testLogger())
138+
139+
token, err := reader.GetIamToken()
140+
require.NoError(t, err)
141+
assert.Equal(t, "my-iam-token", token)
142+
}
143+
144+
func TestGetIamToken_Cached(t *testing.T) {
145+
tokenCallCount := 0
146+
expiresAt := time.Now().Add(12 * time.Hour).UTC().Format(time.RFC3339Nano)
147+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
148+
switch r.URL.Path {
149+
case "/v1/iam/tsa/token/access_token":
150+
tokenCallCount++
151+
_, err := w.Write([]byte("my-iam-token"))
152+
assert.NoError(t, err)
153+
case "/v1/iam/tsa/token/expires_at":
154+
_, err := w.Write([]byte(expiresAt))
155+
assert.NoError(t, err)
156+
default:
157+
http.NotFound(w, r)
122158
}
123-
http.NotFound(w, r)
124159
}))
125160
defer server.Close()
126161

@@ -131,9 +166,145 @@ func TestGetIamToken_IMDS(t *testing.T) {
131166
MetadataTokenType: "tsa",
132167
}, testLogger())
133168

169+
// First call fetches from IMDS
134170
token, err := reader.GetIamToken()
135171
require.NoError(t, err)
136172
assert.Equal(t, "my-iam-token", token)
173+
174+
// Second call should use cache
175+
token, err = reader.GetIamToken()
176+
require.NoError(t, err)
177+
assert.Equal(t, "my-iam-token", token)
178+
179+
assert.Equal(t, 1, tokenCallCount, "token should be fetched only once while cached")
180+
}
181+
182+
func TestGetIamToken_RefreshesWhenNearExpiry(t *testing.T) {
183+
tokenCallCount := 0
184+
expiresAt := time.Now().Add(12 * time.Hour).UTC().Format(time.RFC3339Nano)
185+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
186+
switch r.URL.Path {
187+
case "/v1/iam/tsa/token/access_token":
188+
tokenCallCount++
189+
_, err := fmt.Fprintf(w, "token-%d", tokenCallCount)
190+
assert.NoError(t, err)
191+
case "/v1/iam/tsa/token/expires_at":
192+
_, err := w.Write([]byte(expiresAt))
193+
assert.NoError(t, err)
194+
default:
195+
http.NotFound(w, r)
196+
}
197+
}))
198+
defer server.Close()
199+
200+
reader := NewReader(Config{
201+
UseMetadataService: true,
202+
MetadataServiceURL: server.URL,
203+
MetadataServiceFallbackURL: server.URL,
204+
MetadataTokenType: "tsa",
205+
}, testLogger())
206+
207+
// First fetch
208+
token, err := reader.GetIamToken()
209+
require.NoError(t, err)
210+
assert.Equal(t, "token-1", token)
211+
assert.Equal(t, 1, tokenCallCount)
212+
213+
// Simulate token about to expire (within refresh margin)
214+
reader.tokenMu.Lock()
215+
reader.cachedIAM.expiresAt = time.Now().Add(30 * time.Minute) // less than 1 hour margin
216+
reader.tokenMu.Unlock()
217+
218+
// Should re-fetch
219+
token, err = reader.GetIamToken()
220+
require.NoError(t, err)
221+
assert.Equal(t, "token-2", token)
222+
assert.Equal(t, 2, tokenCallCount)
223+
}
224+
225+
func TestGetIamToken_UsesStaleTokenOnRefreshFailure(t *testing.T) {
226+
tokenCallCount := 0
227+
expiresAt := time.Now().Add(12 * time.Hour).UTC().Format(time.RFC3339Nano)
228+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
229+
switch r.URL.Path {
230+
case "/v1/iam/tsa/token/access_token":
231+
tokenCallCount++
232+
if tokenCallCount > 1 {
233+
w.WriteHeader(http.StatusInternalServerError)
234+
return
235+
}
236+
_, err := w.Write([]byte("original-token"))
237+
assert.NoError(t, err)
238+
case "/v1/iam/tsa/token/expires_at":
239+
_, err := w.Write([]byte(expiresAt))
240+
assert.NoError(t, err)
241+
default:
242+
http.NotFound(w, r)
243+
}
244+
}))
245+
defer server.Close()
246+
247+
reader := NewReader(Config{
248+
UseMetadataService: true,
249+
MetadataServiceURL: server.URL,
250+
MetadataServiceFallbackURL: server.URL,
251+
MetadataTokenType: "tsa",
252+
}, testLogger())
253+
254+
// First fetch succeeds
255+
token, err := reader.GetIamToken()
256+
require.NoError(t, err)
257+
assert.Equal(t, "original-token", token)
258+
259+
// Simulate near expiry but not yet expired
260+
reader.tokenMu.Lock()
261+
reader.cachedIAM.expiresAt = time.Now().Add(30 * time.Minute) // needs refresh but not expired
262+
reader.tokenMu.Unlock()
263+
264+
// Refresh fails — should return stale token since it hasn't expired yet
265+
token, err = reader.GetIamToken()
266+
require.NoError(t, err)
267+
assert.Equal(t, "original-token", token)
268+
}
269+
270+
func TestGetIamToken_AlreadyExpired(t *testing.T) {
271+
expiredAt := time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339Nano)
272+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
273+
switch r.URL.Path {
274+
case "/v1/iam/tsa/token/access_token":
275+
_, err := w.Write([]byte("expired-token"))
276+
assert.NoError(t, err)
277+
case "/v1/iam/tsa/token/expires_at":
278+
_, err := w.Write([]byte(expiredAt))
279+
assert.NoError(t, err)
280+
default:
281+
http.NotFound(w, r)
282+
}
283+
}))
284+
defer server.Close()
285+
286+
tmpDir := t.TempDir()
287+
err := os.WriteFile(filepath.Join(tmpDir, "tsa-token"), []byte("file-token\n"), 0644)
288+
require.NoError(t, err)
289+
290+
reader := NewReader(Config{
291+
UseMetadataService: true,
292+
MetadataServiceURL: server.URL,
293+
MetadataServiceFallbackURL: server.URL,
294+
MetadataTokenType: "tsa",
295+
Path: tmpDir,
296+
IamTokenFilename: "tsa-token",
297+
}, testLogger())
298+
299+
// Token from IMDS is expired — should error from getCachedIAMToken and fall back to file
300+
token, err := reader.GetIamToken()
301+
require.NoError(t, err)
302+
assert.Equal(t, "file-token", token)
303+
304+
// Verify token was not cached
305+
reader.tokenMu.Lock()
306+
assert.Nil(t, reader.cachedIAM, "expired token should not be cached")
307+
reader.tokenMu.Unlock()
137308
}
138309

139310
func TestGetIamToken_FileFallback(t *testing.T) {

0 commit comments

Comments
 (0)