Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions cmd/firebolt-mcp-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ func main() {
Usage: "SSE transport listen address (used only if transport is set to sse)",
Sources: cli.EnvVars("FIREBOLT_MCP_TRANSPORT_SSE_LISTEN_ADDRESS"),
},
&cli.BoolFlag{
Name: "disable-resources",
Category: "MCP Transport",
Value: false,
Usage: "Return text content instead of embedded resources (for clients that do not support resources)",
Sources: cli.EnvVars("FIREBOLT_MCP_DISABLE_RESOURCES"),
},
&cli.StringFlag{
Name: "client-id",
Category: "Firebolt Authentication",
Expand Down Expand Up @@ -123,6 +130,7 @@ func run(ctx context.Context, cmd *cli.Command) error {

// Initialize MCP server
docsProof := generateRandomSecret()
disableResources := cmd.Bool("disable-resources")
resourceDocs := resources.NewDocs(fireboltdocs.FS, docsProof)
resourceAccounts := resources.NewAccounts(discoveryClient)
resourceDatabases := resources.NewDatabases(dbPool)
Expand All @@ -133,8 +141,8 @@ func run(ctx context.Context, cmd *cli.Command) error {
cmd.String("transport"),
cmd.String("transport-sse-listen-address"),
[]server.Tool{
tools.NewConnect(resourceAccounts, resourceDatabases, resourceEngines, docsProof),
tools.NewDocs(resourceDocs),
tools.NewConnect(resourceAccounts, resourceDatabases, resourceEngines, docsProof, disableResources),
tools.NewDocs(resourceDocs, disableResources),
tools.NewQuery(dbPool),
},
[]server.Prompt{
Expand Down
15 changes: 6 additions & 9 deletions pkg/clients/database/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func NewPoolWithConnectionFactory(
}

type poolImpl struct {
sync.RWMutex
sync.Mutex
isClosed bool
logger *slog.Logger
closers []func()
Expand All @@ -75,6 +75,9 @@ type poolImpl struct {

func (p *poolImpl) GetConnection(params PoolParams) (Connection, error) {

p.Lock()
defer p.Unlock()

connectionParams := ConnectionParams{
ClientID: p.clientID,
ClientSecret: p.clientSecret,
Expand All @@ -84,29 +87,23 @@ func (p *poolImpl) GetConnection(params PoolParams) (Connection, error) {
}
hash := connectionParams.Hash()

// First, try to get an existing connection with a read lock
p.RLock()
// First, try to get an existing connection
if p.isClosed {
p.RUnlock()
return nil, ErrPoolClosed
}
if conn, ok := p.connections[hash]; ok {
p.RUnlock()
return conn, nil
}
p.RUnlock()

// Create a new connection if one doesn't exist
conn, closer, err := p.newConnectionFunc(p.logger, connectionParams)
if err != nil {
return nil, fmt.Errorf("failed to create connection: %w", err)
}

// Store the new connection in the pool with a write lock
p.Lock()
// Store the new connection in the pool
p.connections[hash] = conn
p.closers = append(p.closers, closer)
p.Unlock()

return conn, nil
}
Expand Down
5 changes: 4 additions & 1 deletion pkg/tools/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type Connect struct {
databasesFetcher DatabaseResourcesFetcher // Fetches database resources
enginesFetcher EngineResourcesFetcher // Fetches engine resources
docsProof string // Shared with the docs resources
disableResources bool // Return text content instead of embedded resources
}

// NewConnect creates a new instance of the Connect tool with the provided resource fetchers.
Expand All @@ -53,12 +54,14 @@ func NewConnect(
databasesFetcher DatabaseResourcesFetcher,
enginesFetcher EngineResourcesFetcher,
docsProof string,
disableResources bool,
) *Connect {
return &Connect{
accountsFetcher: accountsFetcher,
databasesFetcher: databasesFetcher,
enginesFetcher: enginesFetcher,
docsProof: docsProof,
disableResources: disableResources,
}
}

Expand Down Expand Up @@ -178,7 +181,7 @@ func (t *Connect) Handler(ctx context.Context, request mcp.CallToolRequest) (*mc
return &mcp.CallToolResult{
Result: mcp.Result{},
Content: itertools.Map(results, func(i mcp.ResourceContents) mcp.Content {
return mcp.NewEmbeddedResource(i)
return textOrResourceContent(t.disableResources, i)
}),
IsError: false,
}, nil
Expand Down
86 changes: 78 additions & 8 deletions pkg/tools/connect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,13 @@ func createEngineResource(accountName, engineName string) mcp.ResourceContents {

func TestNewConnect(t *testing.T) {
mock := &MockResourceFetcher{}
connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
assert.NotNil(t, connectTool)
}

func TestConnect_Tool(t *testing.T) {
mock := &MockResourceFetcher{}
connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)

tool := connectTool.Tool()
assert.Equal(t, "firebolt_connect", tool.Name)
Expand Down Expand Up @@ -131,7 +131,7 @@ func TestConnect_Handler_Success(t *testing.T) {
}

// Create the tool
connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)

// Execute the handler
request := mcp.CallToolRequest{}
Expand Down Expand Up @@ -199,7 +199,7 @@ func TestConnect_Handler_AccountFetchFailure(t *testing.T) {
},
}

connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
Expand All @@ -225,7 +225,7 @@ func TestConnect_Handler_InvalidAccountResource(t *testing.T) {
},
}

connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
Expand All @@ -251,7 +251,7 @@ func TestConnect_Handler_InvalidAccountJSON(t *testing.T) {
},
}

connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
Expand All @@ -273,7 +273,7 @@ func TestConnect_Handler_DatabasesFetchFailure(t *testing.T) {
},
}

connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
Expand All @@ -298,7 +298,7 @@ func TestConnect_Handler_EnginesFetchFailure(t *testing.T) {
},
}

connectTool := tools.NewConnect(mock, mock, mock, validProof)
connectTool := tools.NewConnect(mock, mock, mock, validProof, false)
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
Expand All @@ -309,3 +309,73 @@ func TestConnect_Handler_EnginesFetchFailure(t *testing.T) {
assert.Contains(t, err.Error(), "failed to discover engine resources")
assert.Nil(t, result)
}

func TestConnect_Handler_DisableResources(t *testing.T) {
// Create test data
accounts := []string{"account1"}
databases := map[string][]string{
"account1": {"db1"},
}
engines := map[string][]string{
"account1": {"engine1"},
}

// Create mock fetcher
mock := &MockResourceFetcher{
AccountsFunc: func(ctx context.Context, accountName string) ([]mcp.ResourceContents, error) {
var resources []mcp.ResourceContents
for _, acc := range accounts {
resources = append(resources, createAccountResource(acc))
}
return resources, nil
},
DatabasesFunc: func(ctx context.Context, accountName, databaseName string) ([]mcp.ResourceContents, error) {
var resources []mcp.ResourceContents
for _, db := range databases[accountName] {
resources = append(resources, createDatabaseResource(accountName, db))
}
return resources, nil
},
EnginesFunc: func(ctx context.Context, accountName, engineName string) ([]mcp.ResourceContents, error) {
var resources []mcp.ResourceContents
for _, eng := range engines[accountName] {
resources = append(resources, createEngineResource(accountName, eng))
}
return resources, nil
},
}

// Create the tool with disableResources set to true
connectTool := tools.NewConnect(mock, mock, mock, validProof, true)

// Execute the handler
request := mcp.CallToolRequest{}
request.Params.Arguments = map[string]any{
"docs_proof": validProof,
}
result, err := connectTool.Handler(t.Context(), request)

// Assertions
require.NoError(t, err)
require.NotNil(t, result)
assert.False(t, result.IsError)

// Calculate expected total resources
expectedCount := len(accounts) // accounts
for _, dbs := range databases {
expectedCount += len(dbs) // databases
}
for _, engs := range engines {
expectedCount += len(engs) // engines
}

// Check if we got the expected number of resources
assert.Len(t, result.Content, expectedCount)

// Verify the content contains text content instead of embedded resources
for _, content := range result.Content {
textContent, ok := content.(mcp.TextContent)
require.True(t, ok, "Expected TextContent when disableResources is true")
assert.NotEmpty(t, textContent.Text)
}
}
10 changes: 6 additions & 4 deletions pkg/tools/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@ type DocsResourcesFetcher interface {
// Docs represents a tool for fetching and returning Firebolt documentation.
// It provides access to documentation articles that explain Firebolt concepts and functionality.
type Docs struct {
docsFetcher DocsResourcesFetcher // Fetches documentation resources
docsFetcher DocsResourcesFetcher // Fetches documentation resources
disableResources bool // Return text content instead of embedded resources
}

// NewDocs creates a new instance of the Docs tool with the provided documentation fetcher.
// It requires an implementation for fetching documentation articles.
func NewDocs(docsFetcher DocsResourcesFetcher) *Docs {
func NewDocs(docsFetcher DocsResourcesFetcher, disableResources bool) *Docs {
return &Docs{
docsFetcher: docsFetcher,
docsFetcher: docsFetcher,
disableResources: disableResources,
}
}

Expand Down Expand Up @@ -103,7 +105,7 @@ func (t *Docs) Handler(ctx context.Context, request mcp.CallToolRequest) (*mcp.C
return &mcp.CallToolResult{
Result: mcp.Result{},
Content: itertools.Map(results, func(i mcp.ResourceContents) mcp.Content {
return mcp.NewEmbeddedResource(i)
return textOrResourceContent(t.disableResources, i)
}),
IsError: false,
}, nil
Expand Down
Loading