|
16 | 16 | import importlib |
17 | 17 | import json |
18 | 18 | import os |
19 | | -import cloudpickle |
20 | | -import sys |
21 | 19 | from unittest import mock |
22 | 20 | from typing import Optional |
23 | 21 |
|
24 | 22 | from google import auth |
25 | | -from google.auth import credentials as auth_credentials |
26 | | -from google.cloud import storage |
27 | 23 | import vertexai |
28 | | -from google.cloud import aiplatform |
29 | | -from google.cloud.aiplatform_v1 import types as aip_types |
30 | | -from google.cloud.aiplatform_v1.services import reasoning_engine_service |
31 | | -from google.cloud.aiplatform import base |
32 | 24 | from google.cloud.aiplatform import initializer |
33 | 25 | from vertexai.agent_engines import _utils |
34 | 26 | from vertexai import agent_engines |
35 | | -from vertexai.agent_engines.templates import adk as adk_template |
36 | | -from vertexai.agent_engines import _agent_engines |
37 | | -from google.api_core import operation as ga_operation |
38 | 27 | from google.genai import types |
39 | 28 | import pytest |
40 | 29 | import uuid |
@@ -86,52 +75,6 @@ def __init__(self, name: str, model: str): |
86 | 75 | "streaming_mode": "sse", |
87 | 76 | "max_llm_calls": 500, |
88 | 77 | } |
89 | | -_TEST_STAGING_BUCKET = "gs://test-bucket" |
90 | | -_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) |
91 | | -_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" |
92 | | -_TEST_RESOURCE_ID = "1028944691210842416" |
93 | | -_TEST_AGENT_ENGINE_RESOURCE_NAME = ( |
94 | | - f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" |
95 | | -) |
96 | | -_TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name" |
97 | | -_TEST_GCS_DIR_NAME = _agent_engines._DEFAULT_GCS_DIR_NAME |
98 | | -_TEST_BLOB_FILENAME = _agent_engines._BLOB_FILENAME |
99 | | -_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE |
100 | | -_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE |
101 | | -_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}".format( |
102 | | - _TEST_STAGING_BUCKET, |
103 | | - _TEST_GCS_DIR_NAME, |
104 | | - _TEST_BLOB_FILENAME, |
105 | | -) |
106 | | -_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}".format( |
107 | | - _TEST_STAGING_BUCKET, |
108 | | - _TEST_GCS_DIR_NAME, |
109 | | - _TEST_EXTRA_PACKAGES_FILE, |
110 | | -) |
111 | | -_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}".format( |
112 | | - _TEST_STAGING_BUCKET, |
113 | | - _TEST_GCS_DIR_NAME, |
114 | | - _TEST_REQUIREMENTS_FILE, |
115 | | -) |
116 | | -_TEST_AGENT_ENGINE_PACKAGE_SPEC = aip_types.ReasoningEngineSpec.PackageSpec( |
117 | | - python_version=f"{sys.version_info.major}.{sys.version_info.minor}", |
118 | | - pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI, |
119 | | - dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, |
120 | | - requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, |
121 | | -) |
122 | | -_ADK_AGENT_FRAMEWORK = adk_template.AdkApp.agent_framework |
123 | | -_TEST_AGENT_ENGINE_OBJ = aip_types.ReasoningEngine( |
124 | | - name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
125 | | - display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, |
126 | | - spec=aip_types.ReasoningEngineSpec( |
127 | | - package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC, |
128 | | - agent_framework=_ADK_AGENT_FRAMEWORK, |
129 | | - ), |
130 | | -) |
131 | | - |
132 | | -GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( |
133 | | - "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" |
134 | | -) |
135 | 78 |
|
136 | 79 |
|
137 | 80 | @pytest.fixture(scope="module") |
@@ -783,174 +726,3 @@ async def test_async_stream_query_invalid_message_type(self): |
783 | 726 | ): |
784 | 727 | async for _ in app.async_stream_query(user_id=_TEST_USER_ID, message=123): |
785 | 728 | pass |
786 | | - |
787 | | - |
788 | | -@pytest.fixture(scope="module") |
789 | | -def create_agent_engine_mock(): |
790 | | - with mock.patch.object( |
791 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
792 | | - "create_reasoning_engine", |
793 | | - ) as create_agent_engine_mock: |
794 | | - create_agent_engine_lro_mock = mock.Mock(ga_operation.Operation) |
795 | | - create_agent_engine_lro_mock.result.return_value = _TEST_AGENT_ENGINE_OBJ |
796 | | - create_agent_engine_mock.return_value = create_agent_engine_lro_mock |
797 | | - yield create_agent_engine_mock |
798 | | - |
799 | | - |
800 | | -@pytest.fixture(scope="module") |
801 | | -def get_agent_engine_mock(): |
802 | | - with mock.patch.object( |
803 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
804 | | - "get_reasoning_engine", |
805 | | - ) as get_agent_engine_mock: |
806 | | - api_client_mock = mock.Mock() |
807 | | - api_client_mock.get_reasoning_engine.return_value = _TEST_AGENT_ENGINE_OBJ |
808 | | - get_agent_engine_mock.return_value = api_client_mock |
809 | | - yield get_agent_engine_mock |
810 | | - |
811 | | - |
812 | | -@pytest.fixture(scope="module") |
813 | | -def cloud_storage_create_bucket_mock(): |
814 | | - with mock.patch.object(storage, "Client") as cloud_storage_mock: |
815 | | - bucket_mock = mock.Mock(spec=storage.Bucket) |
816 | | - bucket_mock.blob.return_value.open.return_value = "blob_file" |
817 | | - bucket_mock.blob.return_value.upload_from_filename.return_value = None |
818 | | - bucket_mock.blob.return_value.upload_from_string.return_value = None |
819 | | - |
820 | | - cloud_storage_mock.get_bucket = mock.Mock( |
821 | | - side_effect=ValueError("bucket not found") |
822 | | - ) |
823 | | - cloud_storage_mock.bucket.return_value = bucket_mock |
824 | | - cloud_storage_mock.create_bucket.return_value = bucket_mock |
825 | | - |
826 | | - yield cloud_storage_mock |
827 | | - |
828 | | - |
829 | | -@pytest.fixture(scope="module") |
830 | | -def cloudpickle_dump_mock(): |
831 | | - with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: |
832 | | - yield cloudpickle_dump_mock |
833 | | - |
834 | | - |
835 | | -@pytest.fixture(scope="module") |
836 | | -def cloudpickle_load_mock(): |
837 | | - with mock.patch.object(cloudpickle, "load") as cloudpickle_load_mock: |
838 | | - yield cloudpickle_load_mock |
839 | | - |
840 | | - |
841 | | -@pytest.fixture(scope="function") |
842 | | -def get_gca_resource_mock(): |
843 | | - with mock.patch.object( |
844 | | - base.VertexAiResourceNoun, |
845 | | - "_get_gca_resource", |
846 | | - ) as get_gca_resource_mock: |
847 | | - get_gca_resource_mock.return_value = _TEST_AGENT_ENGINE_OBJ |
848 | | - yield get_gca_resource_mock |
849 | | - |
850 | | - |
851 | | -# Function scope is required for the pytest parameterized tests. |
852 | | -@pytest.fixture(scope="function") |
853 | | -def update_agent_engine_mock(): |
854 | | - with mock.patch.object( |
855 | | - reasoning_engine_service.ReasoningEngineServiceClient, |
856 | | - "update_reasoning_engine", |
857 | | - ) as update_agent_engine_mock: |
858 | | - yield update_agent_engine_mock |
859 | | - |
860 | | - |
861 | | -@pytest.mark.usefixtures("google_auth_mock") |
862 | | -class TestAgentEngines: |
863 | | - def setup_method(self): |
864 | | - importlib.reload(initializer) |
865 | | - importlib.reload(aiplatform) |
866 | | - aiplatform.init( |
867 | | - project=_TEST_PROJECT, |
868 | | - location=_TEST_LOCATION, |
869 | | - credentials=_TEST_CREDENTIALS, |
870 | | - staging_bucket=_TEST_STAGING_BUCKET, |
871 | | - ) |
872 | | - |
873 | | - def teardown_method(self): |
874 | | - initializer.global_pool.shutdown(wait=True) |
875 | | - |
876 | | - @pytest.mark.parametrize( |
877 | | - "env_vars,expected_env_vars", |
878 | | - [ |
879 | | - ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
880 | | - (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
881 | | - ( |
882 | | - {"some_env": "some_val"}, |
883 | | - { |
884 | | - "some_env": "some_val", |
885 | | - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
886 | | - }, |
887 | | - ), |
888 | | - ( |
889 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
890 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
891 | | - ), |
892 | | - ], |
893 | | - ) |
894 | | - def test_create_default_telemetry_enablement( |
895 | | - self, |
896 | | - create_agent_engine_mock: mock.Mock, |
897 | | - cloud_storage_create_bucket_mock: mock.Mock, |
898 | | - cloudpickle_dump_mock: mock.Mock, |
899 | | - cloudpickle_load_mock: mock.Mock, |
900 | | - get_gca_resource_mock: mock.Mock, |
901 | | - env_vars: dict[str, str], |
902 | | - expected_env_vars: dict[str, str], |
903 | | - ): |
904 | | - agent_engines.create( |
905 | | - agent_engine=agent_engines.AdkApp(agent=_TEST_AGENT), |
906 | | - env_vars=env_vars, |
907 | | - ) |
908 | | - create_agent_engine_mock.assert_called_once() |
909 | | - deployment_spec = create_agent_engine_mock.call_args.kwargs[ |
910 | | - "reasoning_engine" |
911 | | - ].spec.deployment_spec |
912 | | - assert _utils.to_dict(deployment_spec)["env"] == [ |
913 | | - {"name": key, "value": value} for key, value in expected_env_vars.items() |
914 | | - ] |
915 | | - |
916 | | - @pytest.mark.parametrize( |
917 | | - "env_vars,expected_env_vars", |
918 | | - [ |
919 | | - ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
920 | | - (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
921 | | - ( |
922 | | - {"some_env": "some_val"}, |
923 | | - { |
924 | | - "some_env": "some_val", |
925 | | - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
926 | | - }, |
927 | | - ), |
928 | | - ( |
929 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
930 | | - {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
931 | | - ), |
932 | | - ], |
933 | | - ) |
934 | | - def test_update_default_telemetry_enablement( |
935 | | - self, |
936 | | - update_agent_engine_mock: mock.Mock, |
937 | | - cloud_storage_create_bucket_mock: mock.Mock, |
938 | | - cloudpickle_dump_mock: mock.Mock, |
939 | | - cloudpickle_load_mock: mock.Mock, |
940 | | - get_gca_resource_mock: mock.Mock, |
941 | | - get_agent_engine_mock: mock.Mock, |
942 | | - env_vars: dict[str, str], |
943 | | - expected_env_vars: dict[str, str], |
944 | | - ): |
945 | | - agent_engines.update( |
946 | | - resource_name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
947 | | - description="foobar", # avoid "At least one of ... must be specified" errors. |
948 | | - env_vars=env_vars, |
949 | | - ) |
950 | | - update_agent_engine_mock.assert_called_once() |
951 | | - deployment_spec = update_agent_engine_mock.call_args.kwargs[ |
952 | | - "request" |
953 | | - ].reasoning_engine.spec.deployment_spec |
954 | | - assert _utils.to_dict(deployment_spec)["env"] == [ |
955 | | - {"name": key, "value": value} for key, value in expected_env_vars.items() |
956 | | - ] |
0 commit comments