Skip to content

Commit 19adb40

Browse files
committed
internal/mcp: make parseEventID robust, and add a test
Address comments on CL 682555 about using sscanf for parseEventID, adding a test for parseEventID and formatEventID. Change-Id: Ie31c5431c33829eaf244a5927da9e07bf5c300b5 Reviewed-on: https://go-review.googlesource.com/c/tools/+/683336 LUCI-TryBot-Result: Go LUCI <[email protected]> Reviewed-by: Alan Donovan <[email protected]> Auto-Submit: Robert Findley <[email protected]>
1 parent b6ff505 commit 19adb40

File tree

2 files changed

+61
-6
lines changed

2 files changed

+61
-6
lines changed

internal/mcp/streamable.go

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"fmt"
1111
"io"
1212
"net/http"
13+
"strconv"
1314
"strings"
1415
"sync"
1516
"sync/atomic"
@@ -449,20 +450,28 @@ stream:
449450
// streamID and message index idx.
450451
//
451452
// See also [parseEventID].
452-
func formatEventID(id streamID, idx int) string {
453-
return fmt.Sprintf("%d_%d", id, idx)
453+
func formatEventID(sid streamID, idx int) string {
454+
return fmt.Sprintf("%d_%d", sid, idx)
454455
}
455456

456457
// parseEventID parses a Last-Event-ID value into a logical stream id and
457458
// index.
458459
//
459460
// See also [formatEventID].
460-
func parseEventID(eventID string) (conn streamID, idx int, ok bool) {
461-
_, err := fmt.Sscanf(eventID, "%d_%d", &conn, &idx)
462-
if err != nil || conn < 0 || idx < 0 {
461+
func parseEventID(eventID string) (sid streamID, idx int, ok bool) {
462+
parts := strings.Split(eventID, "_")
463+
if len(parts) != 2 {
463464
return 0, 0, false
464465
}
465-
return conn, idx, true
466+
stream, err := strconv.ParseInt(parts[0], 10, 64)
467+
if err != nil || stream < 0 {
468+
return 0, 0, false
469+
}
470+
idx, err = strconv.Atoi(parts[1])
471+
if err != nil || idx < 0 {
472+
return 0, 0, false
473+
}
474+
return streamID(stream), idx, true
466475
}
467476

468477
// Read implements the [Connection] interface.

internal/mcp/streamable_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,49 @@ func mustMarshal(t *testing.T, v any) json.RawMessage {
538538
}
539539
return data
540540
}
541+
542+
func TestEventID(t *testing.T) {
543+
tests := []struct {
544+
sid streamID
545+
idx int
546+
}{
547+
{0, 0},
548+
{0, 1},
549+
{1, 0},
550+
{1, 1},
551+
{1234, 5678},
552+
}
553+
554+
for _, test := range tests {
555+
t.Run(fmt.Sprintf("%d_%d", test.sid, test.idx), func(t *testing.T) {
556+
eventID := formatEventID(test.sid, test.idx)
557+
gotSID, gotIdx, ok := parseEventID(eventID)
558+
if !ok {
559+
t.Fatalf("parseEventID(%q) failed, want ok", eventID)
560+
}
561+
if gotSID != test.sid || gotIdx != test.idx {
562+
t.Errorf("parseEventID(%q) = %d, %d, want %d, %d", eventID, gotSID, gotIdx, test.sid, test.idx)
563+
}
564+
})
565+
}
566+
567+
invalid := []string{
568+
"",
569+
"_",
570+
"1_",
571+
"_1",
572+
"a_1",
573+
"1_a",
574+
"-1_1",
575+
"1_-1",
576+
}
577+
578+
for _, eventID := range invalid {
579+
t.Run(fmt.Sprintf("invalid_%q", eventID), func(t *testing.T) {
580+
if _, _, ok := parseEventID(eventID); ok {
581+
t.Errorf("parseEventID(%q) succeeded, want failure", eventID)
582+
}
583+
})
584+
}
585+
}
586+

0 commit comments

Comments
 (0)