
Commit 00f2970

Re-adding the fastmcp @mcp.tool call
1 parent 200fc15 commit 00f2970

File tree

1 file changed: +112 -0 lines changed
  • examples/mcp_agent_server/asyncio


examples/mcp_agent_server/asyncio/main.py

Lines changed: 112 additions & 0 deletions
@@ -249,6 +249,118 @@ async def grade_story_async(story: str, app_ctx: Optional[AppContext] = None) ->

    return result

+# Add custom tool to get token usage for a workflow
+@mcp.tool(
+    name="get_token_usage",
+    structured_output=True,
+    description="""
+    Get detailed token usage information for a specific workflow run.
+    This provides a comprehensive breakdown of token usage including:
+    - Total tokens used across all LLM calls within the workflow
+    - Breakdown by model provider and specific models
+    - Hierarchical usage tree showing usage at each level (workflow -> agent -> llm)
+    - Total cost estimate based on model pricing
+    Args:
+        workflow_id: Optional workflow ID (if multiple workflows have the same name)
+        run_id: Optional ID of the workflow run to get token usage for
+        workflow_name: Optional name of the workflow (used as fallback)
+    Returns:
+        Detailed token usage information for the specific workflow run
+    """,
+)
+async def get_workflow_token_usage(
+    workflow_id: str | None = None,
+    run_id: str | None = None,
+    workflow_name: str | None = None,
+) -> Dict[str, Any]:
+    """Get token usage information for a specific workflow run."""
+    context = app.context
+
+    if not context.token_counter:
+        return {
+            "error": "Token counter not available",
+            "message": "Token tracking is not enabled for this application",
+        }
+
+    # Find the specific workflow node
+    workflow_node = await context.token_counter.get_workflow_node(
+        name=workflow_name, workflow_id=workflow_id, run_id=run_id
+    )
+
+    if not workflow_node:
+        return {
+            "error": "Workflow not found",
+            "message": f"Could not find workflow with run_id='{run_id}'",
+        }
+
+    # Get the aggregated usage for this workflow
+    workflow_usage = workflow_node.aggregate_usage()
+
+    # Calculate cost for this workflow
+    workflow_cost = context.token_counter._calculate_node_cost(workflow_node)
+
+    # Build the response
+    result = {
+        "workflow": {
+            "name": workflow_node.name,
+            "run_id": workflow_node.metadata.get("run_id"),
+            "workflow_id": workflow_node.metadata.get("workflow_id"),
+        },
+        "usage": {
+            "input_tokens": workflow_usage.input_tokens,
+            "output_tokens": workflow_usage.output_tokens,
+            "total_tokens": workflow_usage.total_tokens,
+        },
+        "cost": round(workflow_cost, 4),
+        "model_breakdown": {},
+        "usage_tree": workflow_node.to_dict(),
+    }
+
+    # Get model breakdown for this workflow
+    model_usage = {}
+
+    def collect_model_usage(node: TokenNode):
+        """Recursively collect model usage from a node tree"""
+        if node.usage.model_name:
+            model_name = node.usage.model_name
+            provider = node.usage.model_info.provider if node.usage.model_info else None
+
+            # Use tuple as key to handle same model from different providers
+            model_key = (model_name, provider)
+
+            if model_key not in model_usage:
+                model_usage[model_key] = {
+                    "model_name": model_name,
+                    "provider": provider,
+                    "input_tokens": 0,
+                    "output_tokens": 0,
+                    "total_tokens": 0,
+                }
+
+            model_usage[model_key]["input_tokens"] += node.usage.input_tokens
+            model_usage[model_key]["output_tokens"] += node.usage.output_tokens
+            model_usage[model_key]["total_tokens"] += node.usage.total_tokens
+
+        for child in node.children:
+            collect_model_usage(child)
+
+    collect_model_usage(workflow_node)
+
+    # Calculate costs for each model and format for output
+    for (model_name, provider), usage in model_usage.items():
+        cost = context.token_counter.calculate_cost(
+            model_name, usage["input_tokens"], usage["output_tokens"], provider
+        )

+        # Create display key with provider info if available
+        display_key = f"{model_name} ({provider})" if provider else model_name
+
+        result["model_breakdown"][display_key] = {
+            **usage,
+            "cost": round(cost, 4),
+        }
+
+    return result

 async def main():
     parser = argparse.ArgumentParser()
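
For context, here is a minimal sketch of how a client could invoke the new get_token_usage tool once the server is running, using the MCP Python SDK's stdio client. The launch command, its arguments, and the run id below are placeholders (assumptions, not part of this commit); adjust them to however you normally start examples/mcp_agent_server/asyncio/main.py.

import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def check_token_usage() -> None:
    # Assumption: the example server is launched over stdio with this command;
    # substitute whatever you normally use to run main.py.
    server = StdioServerParameters(command="uv", args=["run", "main.py"])

    async with stdio_client(server) as (read_stream, write_stream):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()

            # Call the tool added in this commit. All three arguments are optional;
            # run_id is the most specific selector ("example-run-id" is a placeholder).
            result = await session.call_tool(
                "get_token_usage",
                arguments={"run_id": "example-run-id"},
            )
            print(result.content)


if __name__ == "__main__":
    asyncio.run(check_token_usage())

Because the tool is registered with structured_output=True, the returned dict should also come back as structured content, so clients that understand structured tool results can read the usage, cost, and model_breakdown fields directly.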

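For reference, an illustrative sketch of the response shape this tool builds. The keys mirror the result dict in the diff above; every value below is a made-up placeholder.

# Illustrative only: the key layout mirrors the result dict assembled above;
# all values are invented placeholders, and usage_tree is truncated.
example_response = {
    "workflow": {
        "name": "WorkflowName",            # placeholder
        "run_id": "example-run-id",        # placeholder
        "workflow_id": "example-workflow-id",
    },
    "usage": {
        "input_tokens": 1234,
        "output_tokens": 567,
        "total_tokens": 1801,
    },
    "cost": 0.0123,
    "model_breakdown": {
        "example-model (example-provider)": {
            "model_name": "example-model",
            "provider": "example-provider",
            "input_tokens": 1234,
            "output_tokens": 567,
            "total_tokens": 1801,
            "cost": 0.0123,
        },
    },
    "usage_tree": {},  # full TokenNode.to_dict() output in practice
}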
0 commit comments