|
37 | 37 | except (ImportError, AttributeError): |
38 | 38 | Event = Any |
39 | 39 |
|
| 40 | + try: |
| 41 | + from google.adk.apps import App |
| 42 | + |
| 43 | + App = App |
| 44 | + except (ImportError, AttributeError): |
| 45 | + App = Any |
| 46 | + |
40 | 47 | try: |
41 | 48 | from google.adk.agents import BaseAgent |
42 | 49 |
|
@@ -449,7 +456,8 @@ class AdkApp: |
449 | 456 | def __init__( |
450 | 457 | self, |
451 | 458 | *, |
452 | | - agent: "BaseAgent", |
| 459 | + agent: "BaseAgent" = None, |
| 460 | + app: "App" = None, |
453 | 461 | app_name: Optional[str] = None, |
454 | 462 | plugins: Optional[List["BasePlugin"]] = None, |
455 | 463 | enable_tracing: Optional[bool] = None, |
@@ -505,10 +513,28 @@ def __init__( |
505 | 513 | ) |
506 | 514 | raise ValueError(msg) |
507 | 515 |
|
| 516 | + if not agent and not app: |
| 517 | + raise ValueError("One of `agent` or `app` must be provided.") |
| 518 | + if app: |
| 519 | + if app_name: |
| 520 | + raise ValueError( |
| 521 | + "When app is provided, app_name should not be provided." |
| 522 | + ) |
| 523 | + if agent: |
| 524 | + raise ValueError( |
| 525 | + "When app is provided, agent should not be provided." |
| 526 | + ) |
| 527 | + if plugins: |
| 528 | + raise ValueError( |
| 529 | + "When app is provided, plugins should not be provided and" |
| 530 | + " should be provided in the app instead." |
| 531 | + ) |
| 532 | + |
508 | 533 | self._tmpl_attrs: Dict[str, Any] = { |
509 | 534 | "project": initializer.global_config.project, |
510 | 535 | "location": initializer.global_config.location, |
511 | 536 | "agent": agent, |
| 537 | + "app": app, |
512 | 538 | "app_name": app_name, |
513 | 539 | "plugins": plugins, |
514 | 540 | "enable_tracing": enable_tracing, |
@@ -625,6 +651,7 @@ def clone(self): |
625 | 651 |
|
626 | 652 | return self.__class__( |
627 | 653 | agent=copy.deepcopy(self._tmpl_attrs.get("agent")), |
| 654 | + app=copy.deepcopy(self._tmpl_attrs.get("app")), |
628 | 655 | enable_tracing=self._tmpl_attrs.get("enable_tracing"), |
629 | 656 | app_name=self._tmpl_attrs.get("app_name"), |
630 | 657 | plugins=self._tmpl_attrs.get("plugins"), |
@@ -775,18 +802,24 @@ def tracing_enabled() -> bool: |
775 | 802 |
|
776 | 803 | self._tmpl_attrs["runner"] = Runner( |
777 | 804 | agent=self._tmpl_attrs.get("agent"), |
| 805 | + app=self._tmpl_attrs.get("app"), |
778 | 806 | plugins=self._tmpl_attrs.get("plugins"), |
779 | 807 | session_service=self._tmpl_attrs.get("session_service"), |
780 | 808 | artifact_service=self._tmpl_attrs.get("artifact_service"), |
781 | 809 | memory_service=self._tmpl_attrs.get("memory_service"), |
782 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 810 | + app_name=( |
| 811 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("app_name") |
| 812 | + ), |
783 | 813 | ) |
784 | 814 | self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() |
785 | 815 | self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() |
786 | 816 | self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() |
787 | 817 | self._tmpl_attrs["in_memory_runner"] = Runner( |
788 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 818 | + app_name=( |
| 819 | + None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("app_name") |
| 820 | + ), |
789 | 821 | agent=self._tmpl_attrs.get("agent"), |
| 822 | + app=self._tmpl_attrs.get("app"), |
790 | 823 | plugins=self._tmpl_attrs.get("plugins"), |
791 | 824 | session_service=self._tmpl_attrs.get("in_memory_session_service"), |
792 | 825 | artifact_service=self._tmpl_attrs.get("in_memory_artifact_service"), |
@@ -968,12 +1001,13 @@ async def streaming_agent_run_with_events(self, request_json: str): |
968 | 1001 | self.set_up() |
969 | 1002 | session_service = self._tmpl_attrs.get("in_memory_session_service") |
970 | 1003 | artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") |
| 1004 | + app = self._tmpl_attrs.get("app") |
971 | 1005 | # Try to get the session, if it doesn't exist, create a new one. |
972 | 1006 | session = None |
973 | 1007 | if request.session_id: |
974 | 1008 | try: |
975 | 1009 | session = await session_service.get_session( |
976 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1010 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
977 | 1011 | user_id=request.user_id, |
978 | 1012 | session_id=request.session_id, |
979 | 1013 | ) |
@@ -1006,8 +1040,9 @@ async def streaming_agent_run_with_events(self, request_json: str): |
1006 | 1040 | yield converted_event |
1007 | 1041 | finally: |
1008 | 1042 | if session and not request.session_id: |
| 1043 | + app = self._tmpl_attrs.get("app") |
1009 | 1044 | await session_service.delete_session( |
1010 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1045 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1011 | 1046 | user_id=request.user_id, |
1012 | 1047 | session_id=session.id, |
1013 | 1048 | ) |
@@ -1039,8 +1074,9 @@ async def async_get_session( |
1039 | 1074 | """ |
1040 | 1075 | if not self._tmpl_attrs.get("session_service"): |
1041 | 1076 | self.set_up() |
| 1077 | + app = self._tmpl_attrs.get("app") |
1042 | 1078 | session = await self._tmpl_attrs.get("session_service").get_session( |
1043 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1079 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1044 | 1080 | user_id=user_id, |
1045 | 1081 | session_id=session_id, |
1046 | 1082 | **kwargs, |
@@ -1116,8 +1152,9 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): |
1116 | 1152 | """ |
1117 | 1153 | if not self._tmpl_attrs.get("session_service"): |
1118 | 1154 | self.set_up() |
| 1155 | + app = self._tmpl_attrs.get("app") |
1119 | 1156 | return await self._tmpl_attrs.get("session_service").list_sessions( |
1120 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1157 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1121 | 1158 | user_id=user_id, |
1122 | 1159 | **kwargs, |
1123 | 1160 | ) |
@@ -1188,8 +1225,9 @@ async def async_create_session( |
1188 | 1225 | """ |
1189 | 1226 | if not self._tmpl_attrs.get("session_service"): |
1190 | 1227 | self.set_up() |
| 1228 | + app = self._tmpl_attrs.get("app") |
1191 | 1229 | session = await self._tmpl_attrs.get("session_service").create_session( |
1192 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1230 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1193 | 1231 | user_id=user_id, |
1194 | 1232 | session_id=session_id, |
1195 | 1233 | state=state, |
@@ -1269,8 +1307,9 @@ async def async_delete_session( |
1269 | 1307 | """ |
1270 | 1308 | if not self._tmpl_attrs.get("session_service"): |
1271 | 1309 | self.set_up() |
| 1310 | + app = self._tmpl_attrs.get("app") |
1272 | 1311 | await self._tmpl_attrs.get("session_service").delete_session( |
1273 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1312 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1274 | 1313 | user_id=user_id, |
1275 | 1314 | session_id=session_id, |
1276 | 1315 | **kwargs, |
@@ -1359,8 +1398,9 @@ async def async_search_memory(self, *, user_id: str, query: str): |
1359 | 1398 | """ |
1360 | 1399 | if not self._tmpl_attrs.get("memory_service"): |
1361 | 1400 | self.set_up() |
| 1401 | + app = self._tmpl_attrs.get("app") |
1362 | 1402 | return await self._tmpl_attrs.get("memory_service").search_memory( |
1363 | | - app_name=self._tmpl_attrs.get("app_name"), |
| 1403 | + app_name=app.name if app else self._tmpl_attrs.get("app_name"), |
1364 | 1404 | user_id=user_id, |
1365 | 1405 | query=query, |
1366 | 1406 | ) |
|
0 commit comments