2 changes: 2 additions & 0 deletions .github/CODEOWNERS
@@ -29,6 +29,8 @@
/cdn/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/compute/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/dns/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/gemma2/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/generative_ai/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/iam/cloud-client/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/kms/**/** @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
/media_cdn/**/* @GoogleCloudPlatform/dee-infra @GoogleCloudPlatform/python-samples-reviewers @GoogleCloudPlatform/cloud-samples-reviewers
79 changes: 79 additions & 0 deletions gemma2/gemma2_predict_gpu.py
@@ -0,0 +1,79 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")


def gemma2_predict_gpu(ENDPOINT_REGION: str, ENDPOINT_ID: str) -> str:
# [START generativeaionvertexai_gemma2_predict_gpu]
"""
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with GPU accelerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

    # Encapsulate the prompt in the request format expected by GPU deployments
# Example format: [{'inputs': 'Why is the sky blue?', 'parameters': {'temperature': 0.9}}]
input = {"inputs": prompt, "parameters": config}

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
endpoint=gemma2_end_point,
instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

# [END generativeaionvertexai_gemma2_predict_gpu]
return text_responses[0]


if __name__ == "__main__":
if len(sys.argv) != 3:
print(
"Usage: python gemma2_predict_gpu.py <GEMMA2_ENDPOINT_REGION> <GEMMA2_ENDPOINT_ID>"
)
sys.exit(1)

ENDPOINT_REGION = sys.argv[1]
ENDPOINT_ID = sys.argv[2]
gemma2_predict_gpu(ENDPOINT_REGION, ENDPOINT_ID)
80 changes: 80 additions & 0 deletions gemma2/gemma2_predict_tpu.py
@@ -0,0 +1,80 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")


def gemma2_predict_tpu(ENDPOINT_REGION: str, ENDPOINT_ID: str) -> str:
# [START generativeaionvertexai_gemma2_predict_tpu]
"""
    Sample to run inference on a Gemma2 model deployed to a Vertex AI endpoint with TPU accelerators.
"""

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value

# TODO(developer): Update & uncomment lines below
# PROJECT_ID = "your-project-id"
# ENDPOINT_REGION = "your-vertex-endpoint-region"
# ENDPOINT_ID = "your-vertex-endpoint-id"

# Default configuration
config = {"max_tokens": 1024, "temperature": 0.9, "top_p": 1.0, "top_k": 1}

# Prompt used in the prediction
prompt = "Why is the sky blue?"

    # Encapsulate the prompt in the request format expected by TPU deployments
# Example format: [{'prompt': 'Why is the sky blue?', 'temperature': 0.9}]
input = {"prompt": prompt}
input.update(config)

# Convert input message to a list of GAPIC instances for model input
instances = [json_format.ParseDict(input, Value())]

# Create a client
api_endpoint = f"{ENDPOINT_REGION}-aiplatform.googleapis.com"
client = aiplatform.gapic.PredictionServiceClient(
client_options={"api_endpoint": api_endpoint}
)

# Call the Gemma2 endpoint
gemma2_end_point = (
f"projects/{PROJECT_ID}/locations/{ENDPOINT_REGION}/endpoints/{ENDPOINT_ID}"
)
response = client.predict(
endpoint=gemma2_end_point,
instances=instances,
)
text_responses = response.predictions
print(text_responses[0])

# [END generativeaionvertexai_gemma2_predict_tpu]
return text_responses[0]


if __name__ == "__main__":
if len(sys.argv) != 3:
print(
"Usage: python gemma2_predict_tpu.py <GEMMA2_ENDPOINT_REGION> <GEMMA2_ENDPOINT_ID>"
)
sys.exit(1)

ENDPOINT_REGION = sys.argv[1]
ENDPOINT_ID = sys.argv[2]
gemma2_predict_tpu(ENDPOINT_REGION, ENDPOINT_ID)
102 changes: 102 additions & 0 deletions gemma2/gemma2_test.py
@@ -0,0 +1,102 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import MutableSequence, Optional
from unittest import mock
from unittest.mock import MagicMock

from google.cloud.aiplatform_v1.types import prediction_service
import google.protobuf.struct_pb2 as struct_pb2
from google.protobuf.struct_pb2 import Value

from gemma2_predict_gpu import gemma2_predict_gpu
from gemma2_predict_tpu import gemma2_predict_tpu

# Global variables
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
GPU_ENDPOINT_REGION = "us-east1"
GPU_ENDPOINT_ID = "123456789" # Mock ID used to check if GPU was called

TPU_ENDPOINT_REGION = "us-west1"
TPU_ENDPOINT_ID = "987654321" # Mock ID used to check if TPU was called

# MOCKED RESPONSE
MODEL_RESPONSES = """
The sky appears blue due to a phenomenon called **Rayleigh scattering**.

**Here's how it works:**

1. **Sunlight:** Sunlight is composed of all the colors of the rainbow.

2. **Earth's Atmosphere:** When sunlight enters the Earth's atmosphere, it collides with tiny particles like nitrogen and oxygen molecules.

3. **Scattering:** These particles scatter the sunlight in all directions. However, blue light (which has a shorter wavelength) is scattered more effectively than other colors.

4. **Our Perception:** As a result, we see a blue sky because the scattered blue light reaches our eyes from all directions.

**Why not other colors?**

* **Violet light** has an even shorter wavelength than blue and is scattered even more. However, our eyes are less sensitive to violet light, so we perceive the sky as blue.
* **Longer wavelengths** like red, orange, and yellow are scattered less and travel more directly through the atmosphere. This is why we see these colors during sunrise and sunset, when sunlight has to travel through more of the atmosphere.
"""


# Mocked predict method: verifies that the request payload uses the format expected by the target endpoint (GPU or TPU)
def mock_predict(
endpoint: Optional[str] = None,
instances: Optional[MutableSequence[struct_pb2.Value]] = None,
) -> prediction_service.PredictResponse:
gpu_endpoint = f"projects/{PROJECT_ID}/locations/{GPU_ENDPOINT_REGION}/endpoints/{GPU_ENDPOINT_ID}"
tpu_endpoint = f"projects/{PROJECT_ID}/locations/{TPU_ENDPOINT_REGION}/endpoints/{TPU_ENDPOINT_ID}"
instance_fields = instances[0].struct_value.fields

if endpoint == gpu_endpoint:
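        # GPU requests wrap the prompt in "inputs" and nest generation settings under a "parameters" struct.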
assert "string_value" in instance_fields["inputs"]
assert "struct_value" in instance_fields["parameters"]
parameters = instance_fields["parameters"].struct_value.fields
assert "number_value" in parameters["max_tokens"]
assert "number_value" in parameters["temperature"]
assert "number_value" in parameters["top_p"]
assert "number_value" in parameters["top_k"]
elif endpoint == tpu_endpoint:
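        # TPU requests send the prompt and generation settings as flat top-level fields.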
assert "string_value" in instance_fields["prompt"]
assert "number_value" in instance_fields["max_tokens"]
assert "number_value" in instance_fields["temperature"]
assert "number_value" in instance_fields["top_p"]
assert "number_value" in instance_fields["top_k"]
    else:
        assert False, f"Unexpected endpoint: {endpoint}"

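    # Return a canned response containing the mocked model output.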
response = prediction_service.PredictResponse()
response.predictions.append(Value(string_value=MODEL_RESPONSES))
return response


@mock.patch("google.cloud.aiplatform.gapic.PredictionServiceClient")
def test_gemma2_predict_gpu(mock_client: MagicMock) -> None:
mock_client_instance = mock_client.return_value
mock_client_instance.predict = mock_predict

response = gemma2_predict_gpu(GPU_ENDPOINT_REGION, GPU_ENDPOINT_ID)
assert "Rayleigh scattering" in response


@mock.patch("google.cloud.aiplatform.gapic.PredictionServiceClient")
def test_gemma2_predict_tpu(mock_client: MagicMock) -> None:
mock_client_instance = mock_client.return_value
mock_client_instance.predict = mock_predict

response = gemma2_predict_tpu(TPU_ENDPOINT_REGION, TPU_ENDPOINT_ID)
assert "Rayleigh scattering" in response
42 changes: 42 additions & 0 deletions gemma2/noxfile_config.py
@@ -0,0 +1,42 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Default TEST_CONFIG_OVERRIDE for python repos.

# You can copy this file into your directory, then it will be imported from
# the noxfile.py.

# The source of truth:
# https://github.com/GoogleCloudPlatform/python-docs-samples/blob/main/noxfile_config.py

TEST_CONFIG_OVERRIDE = {
# You can opt out from the test for specific Python versions.
"ignored_versions": ["2.7", "3.7", "3.9", "3.10", "3.11"],
# Old samples are opted out of enforcing Python type hints
# All new samples should feature them
"enforce_type_hints": True,
# An envvar key for determining the project id to use. Change it
# to 'BUILD_SPECIFIC_GCLOUD_PROJECT' if you want to opt in using a
# build specific Cloud project. You can also use your own string
# to use your own Cloud project.
"gcloud_project_env": "GOOGLE_CLOUD_PROJECT",
# 'gcloud_project_env': 'BUILD_SPECIFIC_GCLOUD_PROJECT',
# If you need to use a specific version of pip,
# change pip_version_override to the string representation
# of the version number, for example, "20.2.4"
"pip_version_override": None,
# A dictionary you want to inject into your test. Don't put any
# secrets here. These values will override predefined values.
"envs": {},
}
1 change: 1 addition & 0 deletions gemma2/requirements-test.txt
@@ -0,0 +1 @@
pytest==8.3.3
2 changes: 2 additions & 0 deletions gemma2/requirements.txt
@@ -0,0 +1,2 @@
google-cloud-aiplatform[all]==1.64.0
protobuf==5.28.1