|
26 | 26 | from llama_stack_api.openai_responses import ( |
27 | 27 | OpenAIResponseInputToolChoice as ToolChoice, |
28 | 28 | ) |
| 29 | +from llama_stack_api.openai_responses import ( |
| 30 | + OpenAIResponseInputToolChoiceAllowedTools as AllowedTools, |
| 31 | +) |
29 | 32 | from llama_stack_api.openai_responses import ( |
30 | 33 | OpenAIResponseInputToolChoiceMode as ToolChoiceMode, |
31 | 34 | ) |
@@ -417,6 +420,162 @@ def extract_vector_store_ids_from_tools( |
417 | 420 | return vector_store_ids |
418 | 421 |
|
419 | 422 |
|
def tool_matches_allowed_entry(tool: InputTool, entry: dict[str, str]) -> bool:
    """Check a tool against a single allowlist entry.

    Every key of the entry is looked up as an attribute on the tool. The
    tool matches only when each of those attributes exists, is not None, and
    equals the entry value either directly or through its ``str()`` form.
    An entry with no keys therefore matches every tool.

    Parameters:
        tool: A configured input tool.
        entry: One allowlist entry from allowed_tools.tools.

    Returns:
        True if all entry keys match the tool, False otherwise.
    """

    def _field_matches(field_name: str, expected: str) -> bool:
        # A missing attribute and an attribute set to None are both misses.
        actual = getattr(tool, field_name, None)
        if actual is None:
            return False
        # Accept a direct match or a match on the string rendering
        # (e.g. an int attribute compared with its textual allowlist value).
        return not (actual != expected and str(actual) != expected)

    # all() stops at the first mismatching key.
    return all(
        _field_matches(field_name, expected)
        for field_name, expected in entry.items()
    )
| 442 | + |
| 443 | + |
def group_mcp_tools_by_server(
    entries: list[dict[str, str]],
) -> dict[str, Optional[list[str]]]:
    """Group MCP tool filters by server_label.

    Rules:
        - Non-MCP entries (``type`` other than "mcp") are ignored.
        - Entries without a non-empty server_label are ignored.
        - If any entry for a server has no "name" key, that server is
          unrestricted (None), regardless of entry order.
        - Otherwise, unique tool names are collected in first-seen order.

    Parameters:
        entries: Allowlist entries from allowed_tools.tools.

    Returns:
        Dict mapping server_label -> None (unrestricted) OR list of allowed
        tool names. Keys appear in the order each server was first seen, so
        the result is deterministic for a given input.
    """
    # Per server: None once unrestricted, otherwise an insertion-ordered dict
    # used as an ordered set of tool names (O(1) de-duplication).
    grouped: dict[str, Optional[dict[str, None]]] = {}
    for entry in entries:
        if entry.get("type") != "mcp":
            continue
        server = entry.get("server_label")
        if not server:
            continue
        if "name" not in entry:
            # Unrestricted entry wins over any names seen before or after.
            grouped[server] = None
            continue
        names = grouped.setdefault(server, {})
        if names is not None:
            names[entry["name"]] = None

    return {
        server: None if names is None else list(names)
        for server, names in grouped.items()
    }
| 491 | + |
| 492 | + |
def mcp_strip_name_from_allowlist_entries(
    allowed_entries: list[dict[str, str]],
) -> list[dict[str, str]]:
    """Return copies of the entries with 'name' dropped from MCP entries.

    Entries whose ``type`` is "mcp" are copied without their "name" key;
    all other entries are copied unchanged. The input list and its dicts
    are never mutated.

    Parameters:
        allowed_entries: Allowlist entries from allowed_tools.tools.

    Returns:
        A new list of new dicts, in the same order as the input.
    """

    def _without_mcp_name(entry: dict[str, str]) -> dict[str, str]:
        if entry.get("type") == "mcp":
            return {key: value for key, value in entry.items() if key != "name"}
        return entry.copy()

    return [_without_mcp_name(entry) for entry in allowed_entries]
| 506 | + |
| 507 | + |
def mcp_project_allowed_tools_to_names(
    tool: InputToolMCP, names: list[str]
) -> list[str] | None:
    """Intersect narrowed names with what the MCP tool already permits.

    The tool's own ``allowed_tools`` may be None (no restriction), a list of
    tool names, or an object exposing ``tool_names`` (itself a list or None).

    Parameters:
        tool: The MCP tool whose existing ``allowed_tools`` is consulted.
        names: Tool names requested by the allowlist for this server.

    Returns:
        The permitted tool names in the order they appear in ``names``
        (duplicates removed), or None if ``names`` is empty or the
        intersection is empty.
    """
    if not names:
        return None

    # Normalize the tool's own restriction to "None or an iterable of names".
    allowed = tool.allowed_tools
    if allowed is not None and not isinstance(allowed, list):
        allowed = allowed.tool_names

    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # result is deterministic (a plain set would yield an arbitrary order).
    requested = dict.fromkeys(names)
    if allowed is None:
        permitted = list(requested)
    else:
        allowed_set = set(allowed)
        permitted = [name for name in requested if name in allowed_set]

    return permitted or None
| 534 | + |
| 535 | + |
def filter_tools_by_allowed_entries(
    tools: list[InputTool],
    allowed_entries: list[dict[str, str]],
) -> list[InputTool]:
    """Filter tools based on allowlist entries.

    - An empty allowlist yields an empty result.
    - A tool is kept only if it matches at least one entry; MCP entries are
      matched with their "name" key removed, since "name" selects tools on
      the server rather than the server itself.
    - Non-MCP tools that match are kept unchanged.
    - A matching MCP tool whose server has name filters is replaced by a
      copy whose ``allowed_tools`` is the intersection of those names with
      what the tool already permits; it is dropped when that intersection
      is empty.

    Parameters:
        tools: Candidate tools to filter.
        allowed_entries: Allowlist entries from allowed_tools.tools.

    Returns:
        The tools that survive the allowlist, in their original order.
    """
    if not allowed_entries:
        return []

    names_per_server = group_mcp_tools_by_server(allowed_entries)
    match_entries = mcp_strip_name_from_allowlist_entries(allowed_entries)

    kept: list[InputTool] = []
    for candidate in tools:
        is_allowed = any(
            tool_matches_allowed_entry(candidate, entry) for entry in match_entries
        )
        if not is_allowed:
            continue

        if candidate.type == "mcp":
            mcp_candidate = cast(InputToolMCP, candidate)
            # None means no name filter applies to this server.
            name_filter = names_per_server.get(mcp_candidate.server_label)
            if name_filter is not None:
                surviving = mcp_project_allowed_tools_to_names(
                    mcp_candidate, name_filter
                )
                if surviving is None:
                    # Nothing on this server is both requested and permitted.
                    continue
                candidate = mcp_candidate.model_copy(
                    update={"allowed_tools": surviving}
                )

        kept.append(candidate)

    return kept
| 577 | + |
| 578 | + |
420 | 579 | def resolve_vector_store_ids( |
421 | 580 | vector_store_ids: list[str], byok_rags: list[ByokRag] |
422 | 581 | ) -> list[str]: |
@@ -1330,54 +1489,69 @@ async def resolve_tool_choice( |
1330 | 1489 | mcp_headers: Optional[McpHeaders] = None, |
1331 | 1490 | request_headers: Optional[Mapping[str, str]] = None, |
1332 | 1491 | ) -> tuple[Optional[list[InputTool]], Optional[ToolChoice], Optional[list[str]]]: |
1333 | | - """Resolve tools and tool_choice for the Responses API. |
| 1492 | + """Resolve tools and tool choice for the Responses API. |
| 1493 | +
|
| 1494 | + When tool choice disables tools, always return Nones so Llama Stack |
| 1495 | + sees no tools, even if the request listed tools. |
1334 | 1496 |
|
1335 | | - If the request includes tools, uses them as-is and derives vector_store_ids |
1336 | | - from tool configs; otherwise loads tools via prepare_tools (using all |
1337 | | - configured vector stores) and honors tool_choice "none" via the no_tools |
1338 | | - flag. When no tools end up configured, tool_choice is cleared to None. |
| 1497 | + Allowed-tools mode: filter tools to the allowlist and narrow tool choice to |
| 1498 | + auto or required from the allowlist mode. |
| 1499 | +
|
| 1500 | + Otherwise: use request tools (with filtering) and derive vector store IDs, or |
| 1501 | + load tools via prepare_tools, then filter. Clear tool choice when no tools |
| 1502 | + remain. |
1339 | 1503 |
|
1340 | 1504 | Args: |
1341 | | - tools: Tools from the request, or None to use LCORE-configured tools. |
1342 | | - tool_choice: Requested tool choice (e.g. auto, required, none) or None. |
1343 | | - token: User token for MCP/auth. |
1344 | | - mcp_headers: Optional MCP headers to propagate. |
1345 | | - request_headers: Optional request headers for tool resolution. |
| 1505 | + tools: Request tools, or None for LCORE-configured tools. |
| 1506 | + tool_choice: Requested strategy, or None. |
| 1507 | + token: User token for MCP and auth. |
| 1508 | + mcp_headers: Optional MCP headers. |
| 1509 | + request_headers: Optional headers for tool resolution. |
1346 | 1510 |
|
1347 | 1511 | Returns: |
1348 | | - A tuple of (prepared_tools, prepared_tool_choice, vector_store_ids): |
1349 | | - prepared_tools is the list of tools to use, or None if none configured; |
1350 | | - prepared_tool_choice is the resolved tool choice, or None when there |
1351 | | - are no tools; vector_store_ids is extracted from tools (in user-facing format) |
1352 | | - when provided, otherwise None. |
| 1512 | + Prepared tools, resolved tool choice, and vector store IDs (user-facing), |
| 1513 | + each possibly None. |
1353 | 1514 | """ |
| 1515 | + # If tool_choice is "none", no tools are allowed |
| 1516 | + if isinstance(tool_choice, ToolChoiceMode) and tool_choice == ToolChoiceMode.none: |
| 1517 | + return None, None, None |
| 1518 | + |
| 1519 | + # Extract the allowed filters if specified and overwrite tool choice mode |
| 1520 | + allowed_filters: Optional[list[dict[str, str]]] = None |
| 1521 | + if isinstance(tool_choice, AllowedTools): |
| 1522 | + allowed_filters = tool_choice.tools |
| 1523 | + tool_choice = ToolChoiceMode(tool_choice.mode) |
| 1524 | + |
1354 | 1525 | prepared_tools: Optional[list[InputTool]] = None |
1355 | | - client = AsyncLlamaStackClientHolder().get_client() |
1356 | | - if tools: # explicitly specified in request |
1357 | | - # Per-request override of vector stores (user-facing rag_ids) |
1358 | | - vector_store_ids = extract_vector_store_ids_from_tools(tools) |
1359 | | - # Translate user-facing rag_ids to llama-stack vector_store_ids in each file_search tool |
| 1526 | + if tools is not None: # explicitly specified in request |
1360 | 1527 | byok_rags = configuration.configuration.byok_rag |
1361 | 1528 | prepared_tools = translate_tools_vector_store_ids(tools, byok_rags) |
| 1529 | + if allowed_filters is not None: |
| 1530 | + prepared_tools = filter_tools_by_allowed_entries( |
| 1531 | + prepared_tools, allowed_filters |
| 1532 | + ) |
| 1533 | + if not prepared_tools: |
| 1534 | + return None, None, None |
| 1535 | + vector_store_ids_list = extract_vector_store_ids_from_tools(prepared_tools) |
| 1536 | + vector_store_ids = vector_store_ids_list if vector_store_ids_list else None |
1362 | 1537 | prepared_tool_choice = tool_choice or ToolChoiceMode.auto |
1363 | 1538 | else: |
1364 | | - # Vector stores were not overwritten in request, use all configured vector stores |
1365 | 1539 | vector_store_ids = None |
1366 | | - # Get all tools configured in LCORE (returns None or non-empty list) |
1367 | | - no_tools = ( |
1368 | | - isinstance(tool_choice, ToolChoiceMode) |
1369 | | - and tool_choice == ToolChoiceMode.none |
1370 | | - ) |
1371 | | - # Vector stores are prepared in llama-stack format |
| 1540 | + client = AsyncLlamaStackClientHolder().get_client() |
1372 | 1541 | prepared_tools = await prepare_tools( |
1373 | 1542 | client=client, |
1374 | | - vector_store_ids=vector_store_ids, # allow all configured vector stores |
1375 | | - no_tools=no_tools, |
| 1543 | + vector_store_ids=vector_store_ids, |
| 1544 | + no_tools=False, |
1376 | 1545 | token=token, |
1377 | 1546 | mcp_headers=mcp_headers, |
1378 | 1547 | request_headers=request_headers, |
1379 | 1548 | ) |
1380 | | - # If there are no tools, tool_choice cannot be set at all - LLS implicit behavior |
| 1549 | + if allowed_filters is not None and prepared_tools: |
| 1550 | + prepared_tools = filter_tools_by_allowed_entries( |
| 1551 | + prepared_tools, allowed_filters |
| 1552 | + ) |
| 1553 | + if not prepared_tools: |
| 1554 | + prepared_tools = None |
1381 | 1555 | prepared_tool_choice = tool_choice if prepared_tools else None |
1382 | 1556 |
|
1383 | 1557 | return prepared_tools, prepared_tool_choice, vector_store_ids |
0 commit comments