Skip to content

Commit 2ead202

Browse files
nullfuncKevyVo
andauthored
mcp prompt tests (#1437)
Co-authored-by: Kevin Vo <[email protected]>
1 parent 9f4e33f commit 2ead202

File tree

10 files changed

+480
-31
lines changed

10 files changed

+480
-31
lines changed

src/pkg/mcp/prompts/awsBYOC.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ import (
55
"errors"
66
"os"
77

8-
"github.com/DefangLabs/defang/src/pkg/cli"
98
"github.com/DefangLabs/defang/src/pkg/cli/client"
10-
"github.com/DefangLabs/defang/src/pkg/mcp/tools"
119
"github.com/mark3labs/mcp-go/mcp"
1210
"github.com/mark3labs/mcp-go/server"
1311
)
1412

15-
func setupAWSBYOPrompt(s *server.MCPServer, cluster string, providerId *client.ProviderID) {
13+
func setupAwsByocPrompt(s *server.MCPServer, cluster string, providerId *client.ProviderID) {
1614
awsBYOCPrompt := mcp.NewPrompt("AWS Setup",
1715
mcp.WithPromptDescription("Setup for AWS"),
1816

@@ -30,7 +28,12 @@ func setupAWSBYOPrompt(s *server.MCPServer, cluster string, providerId *client.P
3028
),
3129
)
3230

33-
s.AddPrompt(awsBYOCPrompt, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
31+
s.AddPrompt(awsBYOCPrompt, awsByocPromptHandler(cluster, providerId))
32+
}
33+
34+
// awsByocPromptHandler is extracted for testability
35+
func awsByocPromptHandler(cluster string, providerId *client.ProviderID) func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
36+
return func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
3437
// Can never be nil or empty due to RequiredArgument
3538
awsID := req.Params.Arguments["AWS Credential"]
3639
if isValidAWSKey(awsID) {
@@ -74,12 +77,12 @@ func setupAWSBYOPrompt(s *server.MCPServer, cluster string, providerId *client.P
7477
}
7578
}
7679

77-
fabric, err := cli.Connect(ctx, cluster)
80+
fabric, err := Connect(ctx, cluster)
7881
if err != nil {
7982
return nil, err
8083
}
8184

82-
_, err = tools.CheckProviderConfigured(ctx, fabric, client.ProviderAWS, "", 0)
85+
_, err = CheckProviderConfigured(ctx, fabric, client.ProviderAWS, "", 0)
8386
if err != nil {
8487
return nil, err
8588
}
@@ -101,7 +104,7 @@ func setupAWSBYOPrompt(s *server.MCPServer, cluster string, providerId *client.P
101104
},
102105
},
103106
}, nil
104-
})
107+
}
105108
}
106109

107110
// Check if the provided AWS access key ID is valid
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package prompts
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestIsValidAWSKey_ValidKeys(t *testing.T) {
8+
validKeys := []string{
9+
"AKIA12345678901234",
10+
"AIDA12345678901234",
11+
"ASIA12345678901234",
12+
"APKA12345678901234",
13+
"AROA12345678901234",
14+
"ASCA12345678901234",
15+
}
16+
for _, key := range validKeys {
17+
if !isValidAWSKey(key) {
18+
t.Errorf("expected key %q to be valid", key)
19+
}
20+
}
21+
}
22+
23+
func TestIsValidAWSKey_InvalidKeys(t *testing.T) {
24+
invalidKeys := []string{
25+
"", // empty
26+
"AKIA1234", // too short
27+
"AKIA", // too short
28+
"AKIA1234567890", // too short
29+
"AKI12345678901234", // prefix too short
30+
"ZZZZ12345678901234", // invalid prefix
31+
}
32+
for _, key := range invalidKeys {
33+
if isValidAWSKey(key) {
34+
t.Errorf("expected key %q to be invalid", key)
35+
}
36+
}
37+
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package prompts
2+
3+
import (
4+
"context"
5+
"os"
6+
"testing"
7+
8+
"github.com/DefangLabs/defang/src/pkg/cli/client"
9+
"github.com/mark3labs/mcp-go/mcp"
10+
"github.com/stretchr/testify/require"
11+
)
12+
13+
func TestAwsByocPromptHandler_Success_AccessKey(t *testing.T) {
14+
origConnect := Connect
15+
origCheck := CheckProviderConfigured
16+
Connect = func(ctx context.Context, cluster string) (*client.GrpcClient, error) { return nil, nil }
17+
CheckProviderConfigured = func(ctx context.Context, fabric client.FabricClient, providerId client.ProviderID, s string, i int) (client.Provider, error) {
18+
return &MockProvider{}, nil
19+
}
20+
defer func() {
21+
Connect = origConnect
22+
CheckProviderConfigured = origCheck
23+
}()
24+
25+
providerId := client.ProviderID("")
26+
handler := awsByocPromptHandler("test-cluster", &providerId)
27+
28+
req := mcp.GetPromptRequest{
29+
Params: mcp.GetPromptParams{
30+
Arguments: map[string]string{
31+
"AWS Credential": "AKIAEXAMPLEKEY1234",
32+
"AWS_SECRET_ACCESS_KEY": "secret",
33+
"AWS_REGION": "us-west-2",
34+
},
35+
},
36+
}
37+
38+
// make sure these env do not exist before the test
39+
t.Setenv("AWS_ACCESS_KEY_ID", "")
40+
t.Setenv("AWS_SECRET_ACCESS_KEY", "")
41+
t.Setenv("AWS_REGION", "")
42+
t.Setenv("DEFANG_PROVIDER", "")
43+
44+
res, err := handler(t.Context(), req)
45+
require.NoError(t, err)
46+
require.NotNil(t, res)
47+
require.Equal(t, client.ProviderAWS, providerId)
48+
require.Equal(t, "AKIAEXAMPLEKEY1234", os.Getenv("AWS_ACCESS_KEY_ID"))
49+
require.Equal(t, "secret", os.Getenv("AWS_SECRET_ACCESS_KEY"))
50+
require.Equal(t, "us-west-2", os.Getenv("AWS_REGION"))
51+
require.Equal(t, "aws", os.Getenv("DEFANG_PROVIDER"))
52+
}
53+
54+
func TestAwsByocPromptHandler_Success_Profile(t *testing.T) {
55+
origConnect := Connect
56+
origCheck := CheckProviderConfigured
57+
Connect = func(ctx context.Context, cluster string) (*client.GrpcClient, error) { return nil, nil }
58+
CheckProviderConfigured = func(ctx context.Context, fabric client.FabricClient, providerId client.ProviderID, s string, i int) (client.Provider, error) {
59+
return &MockProvider{}, nil
60+
}
61+
defer func() {
62+
Connect = origConnect
63+
CheckProviderConfigured = origCheck
64+
}()
65+
66+
providerId := client.ProviderID("")
67+
handler := awsByocPromptHandler("test-cluster", &providerId)
68+
69+
req := mcp.GetPromptRequest{
70+
Params: mcp.GetPromptParams{
71+
Arguments: map[string]string{
72+
"AWS Credential": "my-profile",
73+
"AWS_REGION": "us-east-1",
74+
},
75+
},
76+
}
77+
78+
// make sure these env do not exist before the test
79+
t.Setenv("AWS_PROFILE", "")
80+
t.Setenv("AWS_REGION", "")
81+
t.Setenv("DEFANG_PROVIDER", "")
82+
83+
res, err := handler(t.Context(), req)
84+
require.NoError(t, err)
85+
require.NotNil(t, res)
86+
require.Equal(t, client.ProviderAWS, providerId)
87+
require.Equal(t, "my-profile", os.Getenv("AWS_PROFILE"))
88+
require.Equal(t, "us-east-1", os.Getenv("AWS_REGION"))
89+
require.Equal(t, "aws", os.Getenv("DEFANG_PROVIDER"))
90+
}
91+
92+
func TestAwsByocPromptHandler_MissingSecret(t *testing.T) {
93+
providerId := client.ProviderID("")
94+
handler := awsByocPromptHandler("test-cluster", &providerId)
95+
96+
req := mcp.GetPromptRequest{
97+
Params: mcp.GetPromptParams{
98+
Arguments: map[string]string{
99+
"AWS Credential": "AKIAEXAMPLEKEY1234",
100+
"AWS_REGION": "us-west-2",
101+
},
102+
},
103+
}
104+
105+
res, err := handler(t.Context(), req)
106+
require.ErrorContains(t, err, "AWS_SECRET_ACCESS_KEY is required")
107+
require.Nil(t, res)
108+
}
109+
110+
func TestAwsByocPromptHandler_MissingRegion_AccessKey(t *testing.T) {
111+
providerId := client.ProviderID("")
112+
handler := awsByocPromptHandler("test-cluster", &providerId)
113+
114+
req := mcp.GetPromptRequest{
115+
Params: mcp.GetPromptParams{
116+
Arguments: map[string]string{
117+
"AWS Credential": "AKIAEXAMPLEKEY1234",
118+
"AWS_SECRET_ACCESS_KEY": "secret",
119+
},
120+
},
121+
}
122+
123+
res, err := handler(t.Context(), req)
124+
require.ErrorContains(t, err, "AWS_REGION is required")
125+
require.Nil(t, res)
126+
}
127+
128+
func TestAwsByocPromptHandler_ConnectError(t *testing.T) {
129+
origConnect := Connect
130+
Connect = func(ctx context.Context, cluster string) (*client.GrpcClient, error) { return nil, os.ErrNotExist }
131+
defer func() { Connect = origConnect }()
132+
133+
providerId := client.ProviderID("")
134+
handler := awsByocPromptHandler("test-cluster", &providerId)
135+
136+
req := mcp.GetPromptRequest{
137+
Params: mcp.GetPromptParams{
138+
Arguments: map[string]string{
139+
"AWS Credential": "AKIAEXAMPLEKEY1234",
140+
"AWS_SECRET_ACCESS_KEY": "secret",
141+
"AWS_REGION": "us-west-2",
142+
},
143+
},
144+
}
145+
146+
res, err := handler(t.Context(), req)
147+
require.Error(t, err)
148+
require.Nil(t, res)
149+
}
150+
151+
func TestAwsByocPromptHandler_CheckProviderConfiguredError(t *testing.T) {
152+
origConnect := Connect
153+
origCheck := CheckProviderConfigured
154+
Connect = func(ctx context.Context, cluster string) (*client.GrpcClient, error) { return nil, nil }
155+
CheckProviderConfigured = func(ctx context.Context, fabric client.FabricClient, providerId client.ProviderID, s string, i int) (client.Provider, error) {
156+
return nil, os.ErrPermission
157+
}
158+
defer func() {
159+
Connect = origConnect
160+
CheckProviderConfigured = origCheck
161+
}()
162+
163+
providerId := client.ProviderID("")
164+
handler := awsByocPromptHandler("test-cluster", &providerId)
165+
166+
req := mcp.GetPromptRequest{
167+
Params: mcp.GetPromptParams{
168+
Arguments: map[string]string{
169+
"AWS Credential": "AKIAEXAMPLEKEY1234",
170+
"AWS_SECRET_ACCESS_KEY": "secret",
171+
"AWS_REGION": "us-west-2",
172+
},
173+
},
174+
}
175+
176+
res, err := handler(t.Context(), req)
177+
require.Error(t, err)
178+
require.Nil(t, res)
179+
}

src/pkg/mcp/prompts/gcpBYOC.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,35 @@ import (
66

77
"github.com/DefangLabs/defang/src/pkg/cli"
88
"github.com/DefangLabs/defang/src/pkg/cli/client"
9+
cliClient "github.com/DefangLabs/defang/src/pkg/cli/client"
910
"github.com/DefangLabs/defang/src/pkg/mcp/tools"
1011
"github.com/mark3labs/mcp-go/mcp"
1112
"github.com/mark3labs/mcp-go/server"
1213
)
1314

14-
func setupGCPBYOPrompt(s *server.MCPServer, cluster string, providerId *client.ProviderID) {
15-
gcpBYOPrompt := mcp.NewPrompt("GCP Setup",
16-
mcp.WithPromptDescription("Setup for GCP"),
15+
// Patch points for testability
16+
var (
17+
Connect = cli.Connect
18+
CheckProviderConfigured = func(ctx context.Context, client cliClient.FabricClient, providerId cliClient.ProviderID, projectName string, serviceCount int) (cliClient.Provider, error) {
19+
return tools.CheckProviderConfigured(ctx, client, providerId, projectName, serviceCount)
20+
}
21+
)
1722

23+
func setupGcpByocPrompt(s *server.MCPServer, cluster string, providerId *client.ProviderID) {
24+
gcpBYOCPrompt := mcp.NewPrompt("GCP Setup",
25+
mcp.WithPromptDescription("Setup for GCP"),
1826
mcp.WithArgument("GCP_PROJECT_ID",
1927
mcp.ArgumentDescription("Your GCP Project ID"),
2028
mcp.RequiredArgument(),
2129
),
2230
)
2331

24-
s.AddPrompt(gcpBYOPrompt, func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
32+
s.AddPrompt(gcpBYOCPrompt, gcpByocPromptHandler(cluster, providerId))
33+
}
34+
35+
// gcpByocPromptHandler is extracted for testability
36+
func gcpByocPromptHandler(cluster string, providerId *client.ProviderID) func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
37+
return func(ctx context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
2538
// Can never be nil or empty due to RequiredArgument
2639
projectID := req.Params.Arguments["GCP_PROJECT_ID"]
2740

@@ -30,12 +43,12 @@ func setupGCPBYOPrompt(s *server.MCPServer, cluster string, providerId *client.P
3043
return nil, err
3144
}
3245

33-
fabric, err := cli.Connect(ctx, cluster)
46+
fabric, err := Connect(ctx, cluster)
3447
if err != nil {
3548
return nil, err
3649
}
3750

38-
_, err = tools.CheckProviderConfigured(ctx, fabric, client.ProviderGCP, "", 0)
51+
_, err = CheckProviderConfigured(ctx, fabric, client.ProviderGCP, "", 0)
3952
if err != nil {
4053
return nil, err
4154
}
@@ -57,5 +70,5 @@ func setupGCPBYOPrompt(s *server.MCPServer, cluster string, providerId *client.P
5770
},
5871
},
5972
}, nil
60-
})
73+
}
6174
}

0 commit comments

Comments
 (0)