|
37 | 37 | except (ImportError, AttributeError): |
38 | 38 | Event = Any |
39 | 39 |
|
| 40 | + try: |
| 41 | + from google.adk.apps import App |
| 42 | + |
| 43 | + App = App |
| 44 | + except (ImportError, AttributeError): |
| 45 | + App = Any |
| 46 | + |
40 | 47 | try: |
41 | 48 | from google.adk.agents import BaseAgent |
42 | 49 |
|
@@ -449,7 +456,8 @@ class AdkApp: |
449 | 456 | def __init__( |
450 | 457 | self, |
451 | 458 | *, |
452 | | - agent: "BaseAgent", |
| 459 | + agent: "BaseAgent" = None, |
| 460 | + app: "App" = None, |
453 | 461 | app_name: Optional[str] = None, |
454 | 462 | plugins: Optional[List["BasePlugin"]] = None, |
455 | 463 | enable_tracing: Optional[bool] = None, |
@@ -505,10 +513,28 @@ def __init__( |
505 | 513 | ) |
506 | 514 | raise ValueError(msg) |
507 | 515 |
|
| 516 | + if not agent and not app: |
| 517 | + raise ValueError("One of `agent` or `app` must be provided.") |
| 518 | + if app: |
| 519 | + if app_name: |
| 520 | + raise ValueError( |
| 521 | + "When app is provided, app_name should not be provided." |
| 522 | + ) |
| 523 | + if agent: |
| 524 | + raise ValueError( |
| 525 | + "When app is provided, agent should not be provided." |
| 526 | + ) |
| 527 | + if plugins: |
| 528 | + raise ValueError( |
| 529 | + "When app is provided, plugins should not be provided and" |
| 530 | + " should be provided in the app instead." |
| 531 | + ) |
| 532 | + |
508 | 533 | self._tmpl_attrs: Dict[str, Any] = { |
509 | 534 | "project": initializer.global_config.project, |
510 | 535 | "location": initializer.global_config.location, |
511 | 536 | "agent": agent, |
| 537 | + "app": app, |
512 | 538 | "app_name": app_name, |
513 | 539 | "plugins": plugins, |
514 | 540 | "enable_tracing": enable_tracing, |
@@ -625,6 +651,7 @@ def clone(self): |
625 | 651 |
|
626 | 652 | return self.__class__( |
627 | 653 | agent=copy.deepcopy(self._tmpl_attrs.get("agent")), |
| 654 | + app=copy.deepcopy(self._tmpl_attrs.get("app")), |
628 | 655 | enable_tracing=self._tmpl_attrs.get("enable_tracing"), |
629 | 656 | app_name=self._tmpl_attrs.get("app_name"), |
630 | 657 | plugins=self._tmpl_attrs.get("plugins"), |
@@ -775,18 +802,28 @@ def tracing_enabled() -> bool: |
775 | 802 |
|
776 | 803 | self._tmpl_attrs["runner"] = Runner( |
777 | 804 | agent=self._tmpl_attrs.get("agent"), |
| 805 | + app=self._tmpl_attrs.get("app"), |
778 | 806 | plugins=self._tmpl_attrs.get("plugins"), |
779 | 807 | session_service=self._tmpl_attrs.get("session_service"), |
780 | 808 | artifact_service=self._tmpl_attrs.get("artifact_service"), |
781 | 809 | memory_service=self._tmpl_attrs.get("memory_service"), |
782 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 810 | + app_name=( |
| 811 | + self._tmpl_attrs.get("app").name |
| 812 | + if self._tmpl_attrs.get("app") |
| 813 | + else self._tmpl_attrs.get("app_name") |
| 814 | + ), |
783 | 815 | ) |
784 | 816 | self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() |
785 | 817 | self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() |
786 | 818 | self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() |
787 | 819 | self._tmpl_attrs["in_memory_runner"] = Runner( |
788 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 820 | + app_name=( |
| 821 | + self._tmpl_attrs.get("app").name |
| 822 | + if self._tmpl_attrs.get("app") |
| 823 | + else self._tmpl_attrs.get("app_name") |
| 824 | + ), |
789 | 825 | agent=self._tmpl_attrs.get("agent"), |
| 826 | + app=self._tmpl_attrs.get("app"), |
790 | 827 | plugins=self._tmpl_attrs.get("plugins"), |
791 | 828 | session_service=self._tmpl_attrs.get("in_memory_session_service"), |
792 | 829 | artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), |
|
0 commit comments