Skip to content

Commit cc5f47a

Browse files
appleboyclaude
andcommitted
fix(main): return explicit error on oversized responses
Address Copilot review feedback: - Fix comment from "1 MB" to "1 MiB" (1<<20 is a mebibyte) - Read maxResponseSize+1 bytes and return a clear "response too large" error instead of silently truncating (which caused confusing JSON parse errors) - Extract readResponseBody helper to deduplicate the pattern - Add tests for readResponseBody (within limit, at limit, exceeds limit) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 55ac4f0 commit cc5f47a

2 files changed

Lines changed: 58 additions & 5 deletions

File tree

main.go

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ const (
5656
tokenExchangeTimeout = 10 * time.Second
5757
tokenVerificationTimeout = 10 * time.Second
5858
refreshTokenTimeout = 10 * time.Second
59-
maxResponseSize = 1 << 20 // 1 MB
59+
maxResponseSize = 1 << 20 // 1 MiB
6060
)
6161

6262
func init() {
@@ -269,6 +269,25 @@ type tokenResponse struct {
269269
Scope string `json:"scope"`
270270
}
271271

272+
// errResponseTooLarge is returned when a server response exceeds maxResponseSize.
273+
var errResponseTooLarge = fmt.Errorf(
274+
"response body exceeds maximum allowed size of %d bytes",
275+
maxResponseSize,
276+
)
277+
278+
// readResponseBody reads up to maxResponseSize bytes from r and returns an
279+
// explicit error when the response is too large (rather than silently truncating).
280+
func readResponseBody(r io.Reader) ([]byte, error) {
281+
body, err := io.ReadAll(io.LimitReader(r, maxResponseSize+1))
282+
if err != nil {
283+
return nil, err
284+
}
285+
if int64(len(body)) > maxResponseSize {
286+
return nil, errResponseTooLarge
287+
}
288+
return body, nil
289+
}
290+
272291
// parseOAuthError attempts to extract a structured OAuth error from a non-200
273292
// response body. Falls back to including the raw body in the error message.
274293
func parseOAuthError(statusCode int, body []byte, action string) error {
@@ -348,7 +367,7 @@ func exchangeCode(ctx context.Context, code, codeVerifier string) (*tui.TokenSto
348367
}
349368
defer resp.Body.Close()
350369

351-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
370+
body, err := readResponseBody(resp.Body)
352371
if err != nil {
353372
return nil, fmt.Errorf("failed to read response: %w", err)
354373
}
@@ -412,7 +431,7 @@ func refreshAccessToken(ctx context.Context, refreshToken string) (*tui.TokenSto
412431
}
413432
defer resp.Body.Close()
414433

415-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
434+
body, err := readResponseBody(resp.Body)
416435
if err != nil {
417436
return nil, fmt.Errorf("failed to read response: %w", err)
418437
}
@@ -479,7 +498,7 @@ func verifyToken(ctx context.Context, accessToken string) (string, error) {
479498
}
480499
defer resp.Body.Close()
481500

482-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
501+
body, err := readResponseBody(resp.Body)
483502
if err != nil {
484503
return "", fmt.Errorf("failed to read response: %w", err)
485504
}
@@ -538,7 +557,7 @@ func makeAPICallWithAutoRefresh(ctx context.Context, storage *tui.TokenStorage)
538557
defer resp.Body.Close()
539558
}
540559

541-
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseSize))
560+
body, err := readResponseBody(resp.Body)
542561
if err != nil {
543562
return fmt.Errorf("failed to read response: %w", err)
544563
}

main_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
package main
22

33
import (
4+
"bytes"
5+
"errors"
46
"path/filepath"
57
"testing"
68
"time"
@@ -276,6 +278,38 @@ func TestInitTokenStore_Invalid(t *testing.T) {
276278
}
277279
}
278280

281+
func TestReadResponseBody(t *testing.T) {
282+
t.Run("within limit", func(t *testing.T) {
283+
data := bytes.Repeat([]byte("a"), 100)
284+
body, err := readResponseBody(bytes.NewReader(data))
285+
if err != nil {
286+
t.Fatalf("unexpected error: %v", err)
287+
}
288+
if len(body) != 100 {
289+
t.Errorf("expected 100 bytes, got %d", len(body))
290+
}
291+
})
292+
293+
t.Run("exactly at limit", func(t *testing.T) {
294+
data := bytes.Repeat([]byte("a"), maxResponseSize)
295+
body, err := readResponseBody(bytes.NewReader(data))
296+
if err != nil {
297+
t.Fatalf("unexpected error: %v", err)
298+
}
299+
if len(body) != maxResponseSize {
300+
t.Errorf("expected %d bytes, got %d", maxResponseSize, len(body))
301+
}
302+
})
303+
304+
t.Run("exceeds limit", func(t *testing.T) {
305+
data := bytes.Repeat([]byte("a"), maxResponseSize+1)
306+
_, err := readResponseBody(bytes.NewReader(data))
307+
if !errors.Is(err, errResponseTooLarge) {
308+
t.Errorf("expected errResponseTooLarge, got: %v", err)
309+
}
310+
})
311+
}
312+
279313
// containsSubstring is a helper to avoid importing strings in tests.
280314
func containsSubstring(s, sub string) bool {
281315
return len(s) >= len(sub) && findSubstring(s, sub)

0 commit comments

Comments
 (0)