@@ -249,6 +249,118 @@ async def grade_story_async(story: str, app_ctx: Optional[AppContext] = None) ->
249249
250250 return result
251251
252+ # Add custom tool to get token usage for a workflow
253+ @mcp .tool (
254+ name = "get_token_usage" ,
255+ structured_output = True ,
256+ description = """
257+ Get detailed token usage information for a specific workflow run.
258+ This provides a comprehensive breakdown of token usage including:
259+ - Total tokens used across all LLM calls within the workflow
260+ - Breakdown by model provider and specific models
261+ - Hierarchical usage tree showing usage at each level (workflow -> agent -> llm)
262+ - Total cost estimate based on model pricing
263+ Args:
264+ workflow_id: Optional workflow ID (if multiple workflows have the same name)
265+ run_id: Optional ID of the workflow run to get token usage for
266+ workflow_name: Optional name of the workflow (used as fallback)
267+ Returns:
268+ Detailed token usage information for the specific workflow run
269+ """ ,
270+ )
271+ async def get_workflow_token_usage (
272+ workflow_id : str | None = None ,
273+ run_id : str | None = None ,
274+ workflow_name : str | None = None ,
275+ ) -> Dict [str , Any ]:
276+ """Get token usage information for a specific workflow run."""
277+ context = app .context
278+
279+ if not context .token_counter :
280+ return {
281+ "error" : "Token counter not available" ,
282+ "message" : "Token tracking is not enabled for this application" ,
283+ }
284+
285+ # Find the specific workflow node
286+ workflow_node = await context .token_counter .get_workflow_node (
287+ name = workflow_name , workflow_id = workflow_id , run_id = run_id
288+ )
289+
290+ if not workflow_node :
291+ return {
292+ "error" : "Workflow not found" ,
293+ "message" : f"Could not find workflow with run_id='{ run_id } '" ,
294+ }
295+
296+ # Get the aggregated usage for this workflow
297+ workflow_usage = workflow_node .aggregate_usage ()
298+
299+ # Calculate cost for this workflow
300+ workflow_cost = context .token_counter ._calculate_node_cost (workflow_node )
301+
302+ # Build the response
303+ result = {
304+ "workflow" : {
305+ "name" : workflow_node .name ,
306+ "run_id" : workflow_node .metadata .get ("run_id" ),
307+ "workflow_id" : workflow_node .metadata .get ("workflow_id" ),
308+ },
309+ "usage" : {
310+ "input_tokens" : workflow_usage .input_tokens ,
311+ "output_tokens" : workflow_usage .output_tokens ,
312+ "total_tokens" : workflow_usage .total_tokens ,
313+ },
314+ "cost" : round (workflow_cost , 4 ),
315+ "model_breakdown" : {},
316+ "usage_tree" : workflow_node .to_dict (),
317+ }
318+
319+ # Get model breakdown for this workflow
320+ model_usage = {}
321+
322+ def collect_model_usage (node : TokenNode ):
323+ """Recursively collect model usage from a node tree"""
324+ if node .usage .model_name :
325+ model_name = node .usage .model_name
326+ provider = node .usage .model_info .provider if node .usage .model_info else None
327+
328+ # Use tuple as key to handle same model from different providers
329+ model_key = (model_name , provider )
330+
331+ if model_key not in model_usage :
332+ model_usage [model_key ] = {
333+ "model_name" : model_name ,
334+ "provider" : provider ,
335+ "input_tokens" : 0 ,
336+ "output_tokens" : 0 ,
337+ "total_tokens" : 0 ,
338+ }
339+
340+ model_usage [model_key ]["input_tokens" ] += node .usage .input_tokens
341+ model_usage [model_key ]["output_tokens" ] += node .usage .output_tokens
342+ model_usage [model_key ]["total_tokens" ] += node .usage .total_tokens
343+
344+ for child in node .children :
345+ collect_model_usage (child )
346+
347+ collect_model_usage (workflow_node )
348+
349+ # Calculate costs for each model and format for output
350+ for (model_name , provider ), usage in model_usage .items ():
351+ cost = context .token_counter .calculate_cost (
352+ model_name , usage ["input_tokens" ], usage ["output_tokens" ], provider
353+ )
354+
355+ # Create display key with provider info if available
356+ display_key = f"{ model_name } ({ provider } )" if provider else model_name
357+
358+ result ["model_breakdown" ][display_key ] = {
359+ ** usage ,
360+ "cost" : round (cost , 4 ),
361+ }
362+
363+ return result
252364
253365async def main ():
254366 parser = argparse .ArgumentParser ()
0 commit comments