Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,8 +508,17 @@ func (c Client) retry(maxRetries int, req *http.Request) (*http.Response, error)
if err == nil && !contains(retrylist, resp.StatusCode) {
return resp, nil
}
// For IMDS, use exponential backoff based on attempt number
var waitTime time.Duration
if c.source == DefaultToIMDS {
// Exponential backoff with base of 1 second: 1s, 2s, 4s, 8s, etc.
waitTime = time.Second * time.Duration(1<<uint(attempt))
} else {
// For non-IMDS sources, use the fixed 1 second delay
waitTime = time.Second
}
select {
case <-time.After(time.Second):
case <-time.After(waitTime):
case <-req.Context().Done():
err = req.Context().Err()
return resp, err
Expand Down
73 changes: 68 additions & 5 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ func TestRetryFunction(t *testing.T) {
expectedBody string
maxRetries int
source Source
expectedDelays []time.Duration // Expected delays for IMDS exponential backoff
}{
{
name: "Successful Request",
Expand Down Expand Up @@ -228,27 +229,78 @@ func TestRetryFunction(t *testing.T) {
maxRetries: 2,
source: DefaultToIMDS,
},
{
name: "Successful Request IMDS with Exponential Backoff",
mockResponses: []struct {
body string
statusCode int
}{
{"Failed", http.StatusInternalServerError},
{"Failed", http.StatusInternalServerError},
{"Failed", http.StatusInternalServerError},
{"Success", http.StatusOK},
},
expectedStatus: http.StatusOK,
expectedBody: "Success",
maxRetries: 4,
source: DefaultToIMDS,
expectedDelays: []time.Duration{1 * time.Second, 2 * time.Second, 4 * time.Second},
},
{
name: "Successful Request Non-IMDS with Fixed Delay",
mockResponses: []struct {
body string
statusCode int
}{
{"Failed", http.StatusInternalServerError},
{"Success", http.StatusOK},
},
expectedStatus: http.StatusOK,
expectedBody: "Success",
maxRetries: 3,
source: AzureArc, // Non-IMDS source
expectedDelays: []time.Duration{1 * time.Second},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockClient := mock.NewClient()
for _, resp := range tt.mockResponses {
var actualDelays []time.Duration
var lastRequestTime time.Time

for i, resp := range tt.mockResponses {
body := bytes.NewBufferString(resp.body)
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode))
callback := func(r *http.Request) {
if !lastRequestTime.IsZero() {
actualDelays = append(actualDelays, time.Since(lastRequestTime))
}
lastRequestTime = time.Now()
}
// Apply callback only to retryable responses
if i < len(tt.mockResponses)-1 {
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode), mock.WithCallback(callback))
} else {
mockClient.AppendResponse(mock.WithBody(body.Bytes()), mock.WithHTTPStatusCode(resp.statusCode), mock.WithCallback(callback))
}
}
client, err := New(SystemAssigned(), WithHTTPClient(mockClient), WithRetryPolicyDisabled())
client, err := New(SystemAssigned(), WithHTTPClient(mockClient))
if err != nil {
t.Fatal(err)
}
// Manually set the source for testing purposes
client.source = tt.source

reqBody := bytes.NewBufferString("Test Body")
req, err := http.NewRequest("POST", "https://example.com", reqBody)
if err != nil {
t.Fatal(err)
}
finalResp, err := client.retry(tt.maxRetries, req)
if err != nil {
t.Fatal(err)
if tt.expectedStatus != finalResp.StatusCode {
t.Fatal(err)
}
}
if finalResp.StatusCode != tt.expectedStatus {
t.Fatalf("Expected status code %d, got %d", tt.expectedStatus, finalResp.StatusCode)
Expand All @@ -261,6 +313,17 @@ func TestRetryFunction(t *testing.T) {
if string(bodyBytes) != tt.expectedBody {
t.Fatalf("Expected body %q, got %q", tt.expectedBody, bodyBytes)
}

if len(tt.expectedDelays) > 0 {
if len(actualDelays) != len(tt.expectedDelays) {
t.Fatalf("Expected %d delays, got %d. Actual delays: %v", len(tt.expectedDelays), len(actualDelays), actualDelays)
}
for i, expectedDelay := range tt.expectedDelays {
if actualDelays[i] < expectedDelay-500*time.Millisecond || actualDelays[i] > expectedDelay+500*time.Millisecond {
t.Fatalf("Expected delay %v at attempt %d, got %v", expectedDelay, i, actualDelays[i])
}
}
}
})
}
}
Expand Down Expand Up @@ -964,7 +1027,7 @@ func TestAzureArcErrors(t *testing.T) {
},
{
name: "Invalid file path",
headerValue: "Basic realm=" + filepath.Join("path", "to", secretKey),
headerValue: basicRealm + filepath.Join("path", "to", secretKey),
expectedError: "invalid file path, expected " + testCaseFilePath + ", got " + filepath.Join("path", "to"),
},
{
Expand Down