Skip to content

Commit 8d2995d

Browse files
committed
adding transformer to cleanly handle special toolsets
1 parent 7cdf26f commit 8d2995d

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

internal/ghmcp/server.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"net/url"
1111
"os"
1212
"os/signal"
13+
"slices"
1314
"strings"
1415
"syscall"
1516
"time"
@@ -117,7 +118,7 @@ func NewMCPServer(cfg MCPServerConfig) (*server.MCPServer, error) {
117118
}
118119
}
119120

120-
enabledToolsets = transformDefault(enabledToolsets)
121+
enabledToolsets = transformSpecialToolsets(enabledToolsets)
121122

122123
// Generate instructions based on enabled toolsets
123124
instructions := github.GenerateInstructions(enabledToolsets)
@@ -473,19 +474,24 @@ func (t *bearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, erro
473474
return t.transport.RoundTrip(req)
474475
}
475476

476-
// transformDefault replaces "default" in the enabled toolsets with the actual default toolset IDs.
477-
// If "default" is present, it removes it and adds the default toolset IDs from GetDefaultToolsetIDs().
477+
// transformSpecialToolsets handles special toolset keywords in the enabled toolsets list:
478+
// - "all": Returns ["all"] immediately, ignoring all other toolsets
479+
// - "default": Replaces with the actual default toolset IDs from GetDefaultToolsetIDs()
478480
// Duplicates are removed from the final result.
479-
func transformDefault(enabledToolsets []string) []string {
480-
hasDefault := false
481-
result := make([]string, 0, len(enabledToolsets))
481+
func transformSpecialToolsets(enabledToolsets []string) []string {
482+
// Check if "all" is present - if so, return immediately
483+
if slices.Contains(enabledToolsets, github.ToolsetMetadataAll.ID) {
484+
return []string{github.ToolsetMetadataAll.ID}
485+
}
486+
487+
hasDefault := slices.Contains(enabledToolsets, github.ToolsetMetadataDefault.ID)
488+
482489
seen := make(map[string]bool)
490+
result := make([]string, 0, len(enabledToolsets))
483491

484-
// First pass: check if "default" exists and collect non-default toolsets
492+
// Add non-default toolsets, removing duplicates
485493
for _, toolset := range enabledToolsets {
486-
if toolset == github.ToolsetMetadataDefault.ID {
487-
hasDefault = true
488-
} else if !seen[toolset] {
494+
if toolset != github.ToolsetMetadataDefault.ID && !seen[toolset] {
489495
result = append(result, toolset)
490496
seen[toolset] = true
491497
}

internal/ghmcp/server_test.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"github.com/stretchr/testify/require"
99
)
1010

11-
func TestTransformDefault(t *testing.T) {
11+
func TestTransformSpecialToolsets(t *testing.T) {
1212
tests := []struct {
1313
name string
1414
input []string
@@ -19,6 +19,26 @@ func TestTransformDefault(t *testing.T) {
1919
input: []string{},
2020
expected: []string{},
2121
},
22+
{
23+
name: "all only",
24+
input: []string{"all"},
25+
expected: []string{"all"},
26+
},
27+
{
28+
name: "all with other toolsets",
29+
input: []string{"all", "actions", "gists"},
30+
expected: []string{"all"},
31+
},
32+
{
33+
name: "all at the end",
34+
input: []string{"actions", "gists", "all"},
35+
expected: []string{"all"},
36+
},
37+
{
38+
name: "all with default",
39+
input: []string{"default", "all", "actions"},
40+
expected: []string{"all"},
41+
},
2242
{
2343
name: "default only",
2444
input: []string{"default"},
@@ -104,7 +124,7 @@ func TestTransformDefault(t *testing.T) {
104124

105125
for _, tt := range tests {
106126
t.Run(tt.name, func(t *testing.T) {
107-
result := transformDefault(tt.input)
127+
result := transformSpecialToolsets(tt.input)
108128

109129
// Check that the result has the correct length
110130
require.Len(t, result, len(tt.expected), "result length should match expected length")
@@ -132,10 +152,10 @@ func TestTransformDefault(t *testing.T) {
132152
}
133153
}
134154

135-
func TestTransformDefaultWithActualDefaults(t *testing.T) {
155+
func TestTransformSpecialToolsetsWithActualDefaults(t *testing.T) {
136156
// This test verifies that the function uses the actual default toolsets from GetDefaultToolsetIDs()
137157
input := []string{"default"}
138-
result := transformDefault(input)
158+
result := transformSpecialToolsets(input)
139159

140160
defaultToolsets := github.GetDefaultToolsetIDs()
141161

0 commit comments

Comments
 (0)