|
16 | 16 | import importlib |
17 | 17 | import json |
18 | 18 | import os |
| 19 | +import cloudpickle |
| 20 | +import sys |
19 | 21 | from unittest import mock |
20 | 22 | from typing import Optional |
21 | 23 |
|
22 | 24 | from google import auth |
| 25 | +from google.auth import credentials as auth_credentials |
| 26 | +from google.cloud import storage |
23 | 27 | import vertexai |
| 28 | +from google.cloud import aiplatform |
| 29 | +from google.cloud.aiplatform_v1 import types as aip_types |
| 30 | +from google.cloud.aiplatform_v1.services import reasoning_engine_service |
| 31 | +from google.cloud.aiplatform import base |
24 | 32 | from google.cloud.aiplatform import initializer |
25 | 33 | from vertexai.agent_engines import _utils |
26 | 34 | from vertexai import agent_engines |
| 35 | +from vertexai.agent_engines.templates import adk as adk_template |
| 36 | +from vertexai.agent_engines import _agent_engines |
| 37 | +from google.api_core import operation as ga_operation |
27 | 38 | from google.genai import types |
28 | 39 | import pytest |
29 | 40 | import uuid |
@@ -75,6 +86,52 @@ def __init__(self, name: str, model: str): |
75 | 86 | "streaming_mode": "sse", |
76 | 87 | "max_llm_calls": 500, |
77 | 88 | } |
| 89 | +_TEST_STAGING_BUCKET = "gs://test-bucket" |
| 90 | +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) |
| 91 | +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" |
| 92 | +_TEST_RESOURCE_ID = "1028944691210842416" |
| 93 | +_TEST_AGENT_ENGINE_RESOURCE_NAME = ( |
| 94 | + f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" |
| 95 | +) |
| 96 | +_TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name" |
| 97 | +_TEST_GCS_DIR_NAME = _agent_engines._DEFAULT_GCS_DIR_NAME |
| 98 | +_TEST_BLOB_FILENAME = _agent_engines._BLOB_FILENAME |
| 99 | +_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE |
| 100 | +_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE |
| 101 | +_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}".format( |
| 102 | + _TEST_STAGING_BUCKET, |
| 103 | + _TEST_GCS_DIR_NAME, |
| 104 | + _TEST_BLOB_FILENAME, |
| 105 | +) |
| 106 | +_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}".format( |
| 107 | + _TEST_STAGING_BUCKET, |
| 108 | + _TEST_GCS_DIR_NAME, |
| 109 | + _TEST_EXTRA_PACKAGES_FILE, |
| 110 | +) |
| 111 | +_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}".format( |
| 112 | + _TEST_STAGING_BUCKET, |
| 113 | + _TEST_GCS_DIR_NAME, |
| 114 | + _TEST_REQUIREMENTS_FILE, |
| 115 | +) |
| 116 | +_TEST_AGENT_ENGINE_PACKAGE_SPEC = aip_types.ReasoningEngineSpec.PackageSpec( |
| 117 | + python_version=f"{sys.version_info.major}.{sys.version_info.minor}", |
| 118 | + pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI, |
| 119 | + dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, |
| 120 | + requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, |
| 121 | +) |
| 122 | +_ADK_AGENT_FRAMEWORK = adk_template.AdkApp.agent_framework |
| 123 | +_TEST_AGENT_ENGINE_OBJ = aip_types.ReasoningEngine( |
| 124 | + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
| 125 | + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, |
| 126 | + spec=aip_types.ReasoningEngineSpec( |
| 127 | + package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC, |
| 128 | + agent_framework=_ADK_AGENT_FRAMEWORK, |
| 129 | + ), |
| 130 | +) |
| 131 | + |
| 132 | +GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( |
| 133 | + "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" |
| 134 | +) |
78 | 135 |
|
79 | 136 |
|
80 | 137 | @pytest.fixture(scope="module") |
@@ -727,3 +784,174 @@ async def test_async_stream_query_invalid_message_type(self): |
727 | 784 | ): |
728 | 785 | async for _ in app.async_stream_query(user_id=_TEST_USER_ID, message=123): |
729 | 786 | pass |
| 787 | + |
| 788 | + |
| 789 | +@pytest.fixture(scope="module") |
| 790 | +def create_agent_engine_mock(): |
| 791 | + with mock.patch.object( |
| 792 | + reasoning_engine_service.ReasoningEngineServiceClient, |
| 793 | + "create_reasoning_engine", |
| 794 | + ) as create_agent_engine_mock: |
| 795 | + create_agent_engine_lro_mock = mock.Mock(ga_operation.Operation) |
| 796 | + create_agent_engine_lro_mock.result.return_value = _TEST_AGENT_ENGINE_OBJ |
| 797 | + create_agent_engine_mock.return_value = create_agent_engine_lro_mock |
| 798 | + yield create_agent_engine_mock |
| 799 | + |
| 800 | + |
| 801 | +@pytest.fixture(scope="module") |
| 802 | +def get_agent_engine_mock(): |
| 803 | + with mock.patch.object( |
| 804 | + reasoning_engine_service.ReasoningEngineServiceClient, |
| 805 | + "get_reasoning_engine", |
| 806 | + ) as get_agent_engine_mock: |
| 807 | + api_client_mock = mock.Mock() |
| 808 | + api_client_mock.get_reasoning_engine.return_value = _TEST_AGENT_ENGINE_OBJ |
| 809 | + get_agent_engine_mock.return_value = api_client_mock |
| 810 | + yield get_agent_engine_mock |
| 811 | + |
| 812 | + |
| 813 | +@pytest.fixture(scope="module") |
| 814 | +def cloud_storage_create_bucket_mock(): |
| 815 | + with mock.patch.object(storage, "Client") as cloud_storage_mock: |
| 816 | + bucket_mock = mock.Mock(spec=storage.Bucket) |
| 817 | + bucket_mock.blob.return_value.open.return_value = "blob_file" |
| 818 | + bucket_mock.blob.return_value.upload_from_filename.return_value = None |
| 819 | + bucket_mock.blob.return_value.upload_from_string.return_value = None |
| 820 | + |
| 821 | + cloud_storage_mock.get_bucket = mock.Mock( |
| 822 | + side_effect=ValueError("bucket not found") |
| 823 | + ) |
| 824 | + cloud_storage_mock.bucket.return_value = bucket_mock |
| 825 | + cloud_storage_mock.create_bucket.return_value = bucket_mock |
| 826 | + |
| 827 | + yield cloud_storage_mock |
| 828 | + |
| 829 | + |
| 830 | +@pytest.fixture(scope="module") |
| 831 | +def cloudpickle_dump_mock(): |
| 832 | + with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: |
| 833 | + yield cloudpickle_dump_mock |
| 834 | + |
| 835 | + |
| 836 | +@pytest.fixture(scope="module") |
| 837 | +def cloudpickle_load_mock(): |
| 838 | + with mock.patch.object(cloudpickle, "load") as cloudpickle_load_mock: |
| 839 | + yield cloudpickle_load_mock |
| 840 | + |
| 841 | + |
| 842 | +@pytest.fixture(scope="function") |
| 843 | +def get_gca_resource_mock(): |
| 844 | + with mock.patch.object( |
| 845 | + base.VertexAiResourceNoun, |
| 846 | + "_get_gca_resource", |
| 847 | + ) as get_gca_resource_mock: |
| 848 | + get_gca_resource_mock.return_value = _TEST_AGENT_ENGINE_OBJ |
| 849 | + yield get_gca_resource_mock |
| 850 | + |
| 851 | + |
| 852 | +# Function scope is required for the pytest parameterized tests. |
| 853 | +@pytest.fixture(scope="function") |
| 854 | +def update_agent_engine_mock(): |
| 855 | + with mock.patch.object( |
| 856 | + reasoning_engine_service.ReasoningEngineServiceClient, |
| 857 | + "update_reasoning_engine", |
| 858 | + ) as update_agent_engine_mock: |
| 859 | + yield update_agent_engine_mock |
| 860 | + |
| 861 | + |
| 862 | +@pytest.mark.usefixtures("google_auth_mock") |
| 863 | +class TestAgentEngines: |
| 864 | + def setup_method(self): |
| 865 | + importlib.reload(initializer) |
| 866 | + importlib.reload(aiplatform) |
| 867 | + aiplatform.init( |
| 868 | + project=_TEST_PROJECT, |
| 869 | + location=_TEST_LOCATION, |
| 870 | + credentials=_TEST_CREDENTIALS, |
| 871 | + staging_bucket=_TEST_STAGING_BUCKET, |
| 872 | + ) |
| 873 | + |
| 874 | + def teardown_method(self): |
| 875 | + initializer.global_pool.shutdown(wait=True) |
| 876 | + |
| 877 | + @pytest.mark.parametrize( |
| 878 | + "env_vars,expected_env_vars", |
| 879 | + [ |
| 880 | + ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
| 881 | + (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
| 882 | + ( |
| 883 | + {"some_env": "some_val"}, |
| 884 | + { |
| 885 | + "some_env": "some_val", |
| 886 | + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
| 887 | + }, |
| 888 | + ), |
| 889 | + ( |
| 890 | + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
| 891 | + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
| 892 | + ), |
| 893 | + ], |
| 894 | + ) |
| 895 | + def test_create_default_telemetry_enablement( |
| 896 | + self, |
| 897 | + create_agent_engine_mock: mock.Mock, |
| 898 | + cloud_storage_create_bucket_mock: mock.Mock, |
| 899 | + cloudpickle_dump_mock: mock.Mock, |
| 900 | + cloudpickle_load_mock: mock.Mock, |
| 901 | + get_gca_resource_mock: mock.Mock, |
| 902 | + env_vars: dict[str, str], |
| 903 | + expected_env_vars: dict[str, str], |
| 904 | + ): |
| 905 | + agent_engines.create( |
| 906 | + agent_engine=agent_engines.AdkApp(agent=_TEST_AGENT), |
| 907 | + env_vars=env_vars, |
| 908 | + ) |
| 909 | + create_agent_engine_mock.assert_called_once() |
| 910 | + deployment_spec = create_agent_engine_mock.call_args.kwargs[ |
| 911 | + "reasoning_engine" |
| 912 | + ].spec.deployment_spec |
| 913 | + assert _utils.to_dict(deployment_spec)["env"] == [ |
| 914 | + {"name": key, "value": value} for key, value in expected_env_vars.items() |
| 915 | + ] |
| 916 | + |
| 917 | + @pytest.mark.parametrize( |
| 918 | + "env_vars,expected_env_vars", |
| 919 | + [ |
| 920 | + ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
| 921 | + (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}), |
| 922 | + ( |
| 923 | + {"some_env": "some_val"}, |
| 924 | + { |
| 925 | + "some_env": "some_val", |
| 926 | + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true", |
| 927 | + }, |
| 928 | + ), |
| 929 | + ( |
| 930 | + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
| 931 | + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, |
| 932 | + ), |
| 933 | + ], |
| 934 | + ) |
| 935 | + def test_update_default_telemetry_enablement( |
| 936 | + self, |
| 937 | + update_agent_engine_mock: mock.Mock, |
| 938 | + cloud_storage_create_bucket_mock: mock.Mock, |
| 939 | + cloudpickle_dump_mock: mock.Mock, |
| 940 | + cloudpickle_load_mock: mock.Mock, |
| 941 | + get_gca_resource_mock: mock.Mock, |
| 942 | + get_agent_engine_mock: mock.Mock, |
| 943 | + env_vars: dict[str, str], |
| 944 | + expected_env_vars: dict[str, str], |
| 945 | + ): |
| 946 | + agent_engines.update( |
| 947 | + resource_name=_TEST_AGENT_ENGINE_RESOURCE_NAME, |
| 948 | + description="foobar", # avoid "At least one of ... must be specified" errors. |
| 949 | + env_vars=env_vars, |
| 950 | + ) |
| 951 | + update_agent_engine_mock.assert_called_once() |
| 952 | + deployment_spec = update_agent_engine_mock.call_args.kwargs[ |
| 953 | + "request" |
| 954 | + ].reasoning_engine.spec.deployment_spec |
| 955 | + assert _utils.to_dict(deployment_spec)["env"] == [ |
| 956 | + {"name": key, "value": value} for key, value in expected_env_vars.items() |
| 957 | + ] |
0 commit comments