Skip to content

Commit bbb9d73

Browse files
committed
Refactor cache, dedupe code
1 parent ad9cab3 commit bbb9d73

File tree

2 files changed

+16
-26
lines changed

2 files changed

+16
-26
lines changed

contrib/mark3labs/mcp-go/hooks.go

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,27 @@ func (h *Hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.Ini
6969
h.spanCache.Set(id, taskSpan, ttlcache.DefaultTTL)
7070
}
7171

72-
// onAfterInitialize is called after initialization is successfully executed.
73-
func (h *Hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) {
72+
// finishSpanWithIO retrieves a span from cache, annotates it with input/output, and finishes it.
73+
func finishSpanWithIO[Req any, Res any](h *Hooks, id any, request Req, result Res) {
7474
if item := h.spanCache.Get(id); item != nil {
7575
span := item.Value()
76-
if taskSpan, ok := span.(*llmobs.TaskSpan); ok {
76+
if annotator, ok := span.(textIOAnnotator); ok {
7777
inputJSON, _ := json.Marshal(request)
78-
var outputText string
79-
if result != nil {
80-
resultJSON, _ := json.Marshal(result)
81-
outputText = string(resultJSON)
82-
}
78+
resultJSON, _ := json.Marshal(result)
79+
outputText := string(resultJSON)
8380

84-
taskSpan.AnnotateTextIO(string(inputJSON), outputText)
85-
taskSpan.Finish()
81+
annotator.AnnotateTextIO(string(inputJSON), outputText)
82+
span.Finish()
8683
}
8784
h.spanCache.Delete(id)
8885
}
8986
}
9087

88+
// onAfterInitialize is called after initialization is successfully executed.
89+
func (h *Hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) {
90+
finishSpanWithIO(h, id, request, result)
91+
}
92+
9193
// onBeforeCallTool is called before a tool is executed.
9294
func (h *Hooks) onBeforeCallTool(ctx context.Context, id any, request *mcp.CallToolRequest) {
9395
toolSpan, _ := llmobs.StartToolSpan(ctx, request.Params.Name)
@@ -96,21 +98,7 @@ func (h *Hooks) onBeforeCallTool(ctx context.Context, id any, request *mcp.CallT
9698

9799
// onAfterCallTool is called after a tool is successfully executed.
98100
func (h *Hooks) onAfterCallTool(ctx context.Context, id any, request *mcp.CallToolRequest, result *mcp.CallToolResult) {
99-
if item := h.spanCache.Get(id); item != nil {
100-
span := item.Value()
101-
if toolSpan, ok := span.(*llmobs.ToolSpan); ok {
102-
inputJSON, _ := json.Marshal(request)
103-
var outputText string
104-
if result != nil {
105-
resultJSON, _ := json.Marshal(result)
106-
outputText = string(resultJSON)
107-
}
108-
109-
toolSpan.AnnotateTextIO(string(inputJSON), outputText)
110-
toolSpan.Finish()
111-
}
112-
h.spanCache.Delete(id)
113-
}
101+
finishSpanWithIO(h, id, request, result)
114102
}
115103

116104
// textIOAnnotator mirrors the internal textIOSpan interface

contrib/mark3labs/mcp-go/hooks_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ func TestNewHooks(t *testing.T) {
2727

2828
hooks := NewHooks()
2929
assert.NotNil(t, hooks)
30-
assert.NotNil(t, hooks.toolCache)
30+
assert.NotNil(t, hooks.spanCache)
3131

3232
hooks.Stop()
3333
}
@@ -43,6 +43,8 @@ func TestAddHooks(t *testing.T) {
4343
ddHooks.AddHooks(serverHooks)
4444

4545
// Verify hooks were added
46+
assert.Len(t, serverHooks.OnBeforeInitialize, 1)
47+
assert.Len(t, serverHooks.OnAfterInitialize, 1)
4648
assert.Len(t, serverHooks.OnBeforeCallTool, 1)
4749
assert.Len(t, serverHooks.OnAfterCallTool, 1)
4850
assert.Len(t, serverHooks.OnError, 1)

0 commit comments

Comments
 (0)