|
5 | 5 | import httpx |
6 | 6 | import pickle |
7 | 7 | import yaml |
8 | | -from typing import Optional |
9 | | - |
| 8 | +import logging |
10 | 9 | from pathlib import Path |
11 | | -from typing import Any, Union, List |
| 10 | +from typing import Union, List, Any, Optional |
12 | 11 | import json |
13 | 12 | from importlib.resources import files, as_file |
14 | 13 | from pydantic_ai.models.openai import OpenAIChatModel |
15 | | -from pydantic_ai.models.anthropic import AnthropicModel |
16 | 14 | from pydantic_ai.models.google import GoogleModel |
17 | 15 | from pydantic_ai.models.huggingface import HuggingFaceModel |
18 | 16 | from pydantic_ai.models.groq import GroqModel |
19 | 17 | from pydantic_ai.models.mistral import MistralModel |
20 | 18 | from fasta2a import Skill |
21 | 19 |
|
| 20 | +from llama_index.core.embeddings import BaseEmbedding |
| 21 | +from llama_index.embeddings.openai import OpenAIEmbedding |
| 22 | + |
22 | 23 | try: |
| 24 | + from llama_index.embeddings.ollama import OllamaEmbedding |
| 25 | +except ImportError: |
| 26 | + OllamaEmbedding = None |
23 | 27 |
|
| 28 | +try: |
24 | 29 | from openai import AsyncOpenAI |
25 | 30 | from pydantic_ai.providers.openai import OpenAIProvider |
26 | 31 | except ImportError: |
|
42 | 47 | MistralProvider = None |
43 | 48 |
|
44 | 49 | try: |
| 50 | + from pydantic_ai.models.anthropic import AnthropicModel |
45 | 51 | from anthropic import AsyncAnthropic |
46 | 52 | from pydantic_ai.providers.anthropic import AnthropicProvider |
47 | 53 | except ImportError: |
| 54 | + AnthropicModel = None |
48 | 55 | AsyncAnthropic = None |
49 | 56 | AnthropicProvider = None |
50 | 57 |
|
51 | 58 |
|
52 | | -from llama_index.core.embeddings import BaseEmbedding |
53 | | -from llama_index.embeddings.openai import OpenAIEmbedding |
54 | | - |
55 | | -try: |
56 | | - from llama_index.embeddings.ollama import OllamaEmbedding |
57 | | -except ImportError: |
58 | | - OllamaEmbedding = None |
| 59 | +logger = logging.getLogger(__name__) |
59 | 60 |
|
60 | 61 |
|
61 | 62 | def to_integer(string: Union[str, int] = None) -> int: |
@@ -373,6 +374,74 @@ def create_model( |
373 | 374 | return OpenAIChatModel(model_name=model_id, provider="openai") |
374 | 375 |
|
375 | 376 |
|
| 377 | +def extract_tool_tags(tool_def: Any) -> List[str]: |
| 378 | + """ |
| 379 | + Extracts tags from a tool definition object. |
| 380 | +
|
| 381 | + Found structure in debug: |
| 382 | + tool_def.name (str) |
| 383 | + tool_def.meta (dict) -> {'fastmcp': {'tags': ['tag']}} |
| 384 | +
|
| 385 | + This function checks multiple paths to be robust: |
| 386 | + 1. tool_def.meta['fastmcp']['tags'] |
| 387 | + 2. tool_def.meta['tags'] |
| 388 | + 3. tool_def.metadata['tags'] (legacy/alternative wrapper) |
| 389 | + 4. tool_def.metadata.get('meta')... (nested path) |
| 390 | + """ |
| 391 | + tags_list = [] |
| 392 | + |
| 393 | + meta = getattr(tool_def, "meta", None) |
| 394 | + if isinstance(meta, dict): |
| 395 | + fastmcp = meta.get("fastmcp") or meta.get("_fastmcp") or {} |
| 396 | + tags_list = fastmcp.get("tags", []) |
| 397 | + if tags_list: |
| 398 | + return tags_list |
| 399 | + |
| 400 | + tags_list = meta.get("tags", []) |
| 401 | + if tags_list: |
| 402 | + return tags_list |
| 403 | + |
| 404 | + metadata = getattr(tool_def, "metadata", None) |
| 405 | + if isinstance(metadata, dict): |
| 406 | + tags_list = metadata.get("tags", []) |
| 407 | + if tags_list: |
| 408 | + return tags_list |
| 409 | + |
| 410 | + meta_nested = metadata.get("meta") or {} |
| 411 | + fastmcp = meta_nested.get("fastmcp") or meta_nested.get("_fastmcp") or {} |
| 412 | + tags_list = fastmcp.get("tags", []) |
| 413 | + if tags_list: |
| 414 | + return tags_list |
| 415 | + |
| 416 | + tags_list = meta_nested.get("tags", []) |
| 417 | + if tags_list: |
| 418 | + return tags_list |
| 419 | + |
| 420 | + tags_list = getattr(tool_def, "tags", []) |
| 421 | + if isinstance(tags_list, list) and tags_list: |
| 422 | + return tags_list |
| 423 | + |
| 424 | + return [] |
| 425 | + |
| 426 | + |
| 427 | +def tool_in_tag(tool_def: Any, tag: str) -> bool: |
| 428 | + """ |
| 429 | + Checks if a tool belongs to a specific tag. |
| 430 | + """ |
| 431 | + tool_tags = extract_tool_tags(tool_def) |
| 432 | + if tag in tool_tags: |
| 433 | + return True |
| 434 | + else: |
| 435 | + return False |
| 436 | + |
| 437 | + |
| 438 | +def filter_tools_by_tag(tools: List[Any], tag: str) -> List[Any]: |
| 439 | + """ |
| 440 | + Filters a list of tools for a given tag. |
| 441 | + """ |
| 442 | + return [t for t in tools if tool_in_tag(t, tag)] |
| 443 | + |
| 444 | + |
376 | 445 | def get_embedding_model() -> BaseEmbedding: |
377 | 446 | """ |
378 | 447 | Get the embedding model based on environment variables. |
|
0 commit comments