|
29 | 29 | MessageWrapper, |
30 | 30 | ModelMiddleware, |
31 | 31 | ) |
| 32 | +from genkit.blocks.resource import ResourceInput, find_matching_resource, resolve_resources |
32 | 33 | from genkit.blocks.tools import ToolInterruptError |
33 | 34 | from genkit.codec import dump_dict |
34 | 35 | from genkit.core.action import ActionRunContext |
35 | 36 | from genkit.core.error import GenkitError, StatusName |
36 | 37 | from genkit.core.registry import Action, ActionKind, Registry |
37 | 38 | from genkit.core.typing import ( |
| 39 | + DocumentData, |
| 40 | + DocumentPart, |
38 | 41 | GenerateActionOptions, |
39 | 42 | GenerateRequest, |
40 | 43 | GenerateResponse, |
@@ -101,6 +104,9 @@ async def generate_action( |
101 | 104 |
|
102 | 105 | raw_request, formatter = apply_format(raw_request, format_def) |
103 | 106 |
|
| 107 | + if raw_request.resources: |
| 108 | + raw_request = await apply_resources(registry, raw_request) |
| 109 | + |
104 | 110 | assert_valid_tool_names(tools) |
105 | 111 |
|
106 | 112 | ( |
@@ -412,6 +418,131 @@ def apply_transfer_preamble(next_request: GenerateActionOptions, preamble: Gener |
412 | 418 | return next_request |
413 | 419 |
|
414 | 420 |
|
| 421 | +def _extract_resource_uri(resource_obj: Any) -> str | None: |
| 422 | + """Extract URI from a resource object. |
| 423 | +
|
| 424 | + Handles various Pydantic wrapper structures (Resource, Resource1, RootModel, dict). |
| 425 | +
|
| 426 | + Args: |
| 427 | + resource_obj: The resource object to extract URI from. |
| 428 | +
|
| 429 | + Returns: |
| 430 | + The extracted URI string, or None if not found. |
| 431 | + """ |
| 432 | + # Direct uri attribute (Resource1, ResourceInput, etc.) |
| 433 | + if hasattr(resource_obj, 'uri'): |
| 434 | + return resource_obj.uri |
| 435 | + |
| 436 | + # Unwrap RootModel structures |
| 437 | + if hasattr(resource_obj, 'root'): |
| 438 | + return _extract_resource_uri(resource_obj.root) |
| 439 | + |
| 440 | + # Unwrap nested resource attribute |
| 441 | + if hasattr(resource_obj, 'resource'): |
| 442 | + return _extract_resource_uri(resource_obj.resource) |
| 443 | + |
| 444 | + # Handle dict representation |
| 445 | + if isinstance(resource_obj, dict) and 'uri' in resource_obj: |
| 446 | + return resource_obj['uri'] |
| 447 | + |
| 448 | + return None |
| 449 | + |
| 450 | + |
| 451 | +async def apply_resources(registry: Registry, raw_request: GenerateActionOptions) -> GenerateActionOptions: |
| 452 | + """Applies resources to the request messages by hydrating resource parts. |
| 453 | +
|
| 454 | + Args: |
| 455 | + registry: The registry to use for resolving resources. |
| 456 | + raw_request: The generation request. |
| 457 | +
|
| 458 | + Returns: |
| 459 | + The updated generation request with hydrated resources. |
| 460 | + """ |
| 461 | + # Quick check if any message has a resource part |
| 462 | + has_resource = False |
| 463 | + for msg in raw_request.messages: |
| 464 | + for part in msg.content: |
| 465 | + if part.root.resource: |
| 466 | + has_resource = True |
| 467 | + break |
| 468 | + if has_resource: |
| 469 | + break |
| 470 | + |
| 471 | + if not has_resource: |
| 472 | + return raw_request |
| 473 | + |
| 474 | + # Resolve all declared resources |
| 475 | + resources = [] |
| 476 | + if raw_request.resources: |
| 477 | + resources = await resolve_resources(registry, raw_request.resources) |
| 478 | + |
| 479 | + updated_messages = [] |
| 480 | + for msg in raw_request.messages: |
| 481 | + if not any(p.root.resource for p in msg.content): |
| 482 | + updated_messages.append(msg) |
| 483 | + continue |
| 484 | + |
| 485 | + updated_content = [] |
| 486 | + for part in msg.content: |
| 487 | + if not part.root.resource: |
| 488 | + updated_content.append(part) |
| 489 | + continue |
| 490 | + |
| 491 | + resource_obj = part.root.resource |
| 492 | + |
| 493 | + # Extract URI from the resource object |
| 494 | + # The resource can be wrapped in various Pydantic structures (Resource, Resource1, etc.) |
| 495 | + ref_uri = _extract_resource_uri(resource_obj) |
| 496 | + if not ref_uri: |
| 497 | + logger.warning( |
| 498 | + f'Unable to extract URI from resource part: {type(resource_obj).__name__}. ' |
| 499 | + f'Resource part will be skipped.' |
| 500 | + ) |
| 501 | + continue |
| 502 | + |
| 503 | + # Find matching resource action |
| 504 | + if not resources: |
| 505 | + raise GenkitError( |
| 506 | + status='NOT_FOUND', |
| 507 | + message=f'failed to find matching resource for {ref_uri}', |
| 508 | + ) |
| 509 | + |
| 510 | + from genkit.blocks.resource import ResourceInput, find_matching_resource |
| 511 | + |
| 512 | + # Normalize to ResourceInput for matching |
| 513 | + resource_input = ResourceInput(uri=ref_uri) |
| 514 | + resource_action = await find_matching_resource(registry, resources, resource_input) |
| 515 | + |
| 516 | + if not resource_action: |
| 517 | + raise GenkitError( |
| 518 | + status='NOT_FOUND', |
| 519 | + message=f'failed to find matching resource for {ref_uri}', |
| 520 | + ) |
| 521 | + |
| 522 | + # Execute the resource |
| 523 | + # Create a simple context for the resource execution |
| 524 | + resource_ctx = ActionRunContext(on_chunk=None, context=None) |
| 525 | + response = await resource_action.arun(resource_input, resource_ctx) |
| 526 | + |
| 527 | + # response.response is ResourceOutput which has .content (list of Parts) |
| 528 | + # It usually returns a dict if coming from dynamic_resource (model_dump called) |
| 529 | + output_content = None |
| 530 | + if hasattr(response.response, 'content'): |
| 531 | + output_content = response.response.content |
| 532 | + elif isinstance(response.response, dict) and 'content' in response.response: |
| 533 | + output_content = response.response['content'] |
| 534 | + |
| 535 | + if output_content: |
| 536 | + updated_content.extend(output_content) |
| 537 | + |
| 538 | + updated_messages.append(Message(role=msg.role, content=updated_content, metadata=msg.metadata)) |
| 539 | + |
| 540 | + # Return a new request with updated messages |
| 541 | + new_request = raw_request.model_copy() |
| 542 | + new_request.messages = updated_messages |
| 543 | + return new_request |
| 544 | + |
| 545 | + |
415 | 546 | def assert_valid_tool_names(raw_request: GenerateActionOptions): |
416 | 547 | """Assert that tool names in the request are valid. |
417 | 548 |
|
|
0 commit comments