Skip to content

Commit 29a6894

Browse files
feat(generativeai): New Samples for Model Tuning folder (#12605)
1 parent 539007c commit 29a6894

18 files changed

+1016
-0
lines changed
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
from vertexai.preview.evaluation import EvalResult
17+
18+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
19+
20+
21+
def create_evaluation_task() -> EvalResult:
22+
# [START generativeaionvertexai_create_evaluation_task]
23+
import pandas as pd
24+
25+
import vertexai
26+
from vertexai.preview.evaluation import EvalTask, MetricPromptTemplateExamples
27+
28+
# TODO(developer): Update and un-comment below line
29+
# PROJECT_ID = "your-project-id"
30+
vertexai.init(project=PROJECT_ID, location="us-central1")
31+
32+
eval_dataset = pd.DataFrame(
33+
{
34+
"instruction": [
35+
"Summarize the text in one sentence.",
36+
"Summarize the text such that a five-year-old can understand.",
37+
],
38+
"context": [
39+
"""As part of a comprehensive initiative to tackle urban congestion and foster
40+
sustainable urban living, a major city has revealed ambitious plans for an
41+
extensive overhaul of its public transportation system. The project aims not
42+
only to improve the efficiency and reliability of public transit but also to
43+
reduce the city\'s carbon footprint and promote eco-friendly commuting options.
44+
City officials anticipate that this strategic investment will enhance
45+
accessibility for residents and visitors alike, ushering in a new era of
46+
efficient, environmentally conscious urban transportation.""",
47+
"""A team of archaeologists has unearthed ancient artifacts shedding light on a
48+
previously unknown civilization. The findings challenge existing historical
49+
narratives and provide valuable insights into human history.""",
50+
],
51+
"response": [
52+
"A major city is revamping its public transportation system to fight congestion, reduce emissions, and make getting around greener and easier.",
53+
"Some people who dig for old things found some very special tools and objects that tell us about people who lived a long, long time ago! What they found is like a new puzzle piece that helps us understand how people used to live.",
54+
],
55+
}
56+
)
57+
58+
eval_task = EvalTask(
59+
dataset=eval_dataset,
60+
metrics=[
61+
MetricPromptTemplateExamples.Pointwise.SUMMARIZATION_QUALITY,
62+
MetricPromptTemplateExamples.Pointwise.GROUNDEDNESS,
63+
MetricPromptTemplateExamples.Pointwise.VERBOSITY,
64+
MetricPromptTemplateExamples.Pointwise.INSTRUCTION_FOLLOWING,
65+
],
66+
)
67+
68+
prompt_template = (
69+
"Instruction: {instruction}. Article: {context}. Summary: {response}"
70+
)
71+
result = eval_task.evaluate(prompt_template=prompt_template)
72+
73+
print("Summary Metrics:\n")
74+
75+
for key, value in result.summary_metrics.items():
76+
print(f"{key}: \t{value}")
77+
78+
print("\n\nMetrics Table:\n")
79+
print(result.metrics_table)
80+
# Example response:
81+
# Summary Metrics:
82+
# row_count: 2
83+
# summarization_quality/mean: 3.5
84+
# summarization_quality/std: 2.1213203435596424
85+
# ...
86+
87+
# [END generativeaionvertexai_create_evaluation_task]
88+
return result
89+
90+
91+
if __name__ == "__main__":
92+
create_evaluation_task()
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import create_evaluation_task_example
16+
17+
18+
def test_create_evaluation_task() -> None:
19+
response = create_evaluation_task_example.create_evaluation_task()
20+
assert response
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START generativeaionvertexai_sdk_distillation]
16+
from __future__ import annotations
17+
18+
import os
19+
20+
from typing import Optional
21+
22+
import vertexai
23+
from vertexai.preview.language_models import TextGenerationModel, TuningEvaluationSpec
24+
25+
26+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
27+
28+
29+
def distill_model(
30+
dataset: str,
31+
source_model: str,
32+
evaluation_dataset: Optional[str] = None,
33+
) -> None:
34+
"""Distill a new model using a teacher model and a dataset.
35+
Args:
36+
dataset (str): GCS URI of the JSONL file containing the training data.
37+
E.g., "gs://[BUCKET]/[FILENAME].jsonl".
38+
source_model (str): Name of the teacher model to distill from.
39+
E.g., "text-unicorn@001".
40+
evaluation_dataset (Optional[str]): GCS URI of the JSONL file containing the evaluation data.
41+
"""
42+
# TODO developer - override these parameters as needed:
43+
vertexai.init(project=PROJECT_ID, location="us-central1")
44+
45+
# Create a tuning evaluation specification with the evaluation dataset
46+
eval_spec = TuningEvaluationSpec(evaluation_data=evaluation_dataset)
47+
48+
# Load the student model from a pre-trained model
49+
student_model = TextGenerationModel.from_pretrained("text-bison@002")
50+
51+
# Start the distillation job using the teacher model and dataset
52+
distillation_job = student_model.distill_from(
53+
teacher_model=source_model,
54+
dataset=dataset,
55+
# Optional:
56+
train_steps=300, # Number of training steps to use when tuning the model.
57+
evaluation_spec=eval_spec,
58+
)
59+
60+
return distillation_job
61+
62+
63+
# [END generativeaionvertexai_sdk_distillation]
64+
65+
if __name__ == "__main__":
66+
distill_model(
67+
dataset="your-dataset-uri",
68+
source_model="your-source-model",
69+
evaluation_dataset="your-evaluation-dataset-uri",
70+
)
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import uuid
17+
18+
import distillation_example
19+
20+
from google.cloud import aiplatform
21+
from google.cloud import storage
22+
23+
from google.cloud.aiplatform.compat.types import pipeline_state
24+
25+
import pytest
26+
27+
from vertexai.preview.language_models import TextGenerationModel
28+
29+
_BUCKET = os.environ["CLOUD_STORAGE_BUCKET"]
30+
31+
32+
def get_model_display_name(tuned_model: TextGenerationModel) -> str:
33+
language_model_tuning_job = tuned_model._job
34+
pipeline_job = language_model_tuning_job._job
35+
return dict(pipeline_job._gca_resource.runtime_config.parameter_values)[
36+
"model_display_name"
37+
]
38+
39+
40+
def upload_to_gcs(bucket: str, name: str, data: str) -> None:
41+
client = storage.Client()
42+
bucket = client.get_bucket(bucket)
43+
blob = bucket.blob(name)
44+
blob.upload_from_string(data)
45+
46+
47+
def download_from_gcs(bucket: str, name: str) -> str:
48+
client = storage.Client()
49+
bucket = client.get_bucket(bucket)
50+
blob = bucket.blob(name)
51+
data = blob.download_as_bytes()
52+
return "\n".join(data.decode().splitlines()[:10])
53+
54+
55+
def delete_from_gcs(bucket: str, name: str) -> None:
56+
client = storage.Client()
57+
bucket = client.get_bucket(bucket)
58+
blob = bucket.blob(name)
59+
blob.delete()
60+
61+
62+
@pytest.fixture(scope="function")
63+
def training_data_filename() -> str:
64+
temp_filename = f"{uuid.uuid4()}.jsonl"
65+
data = download_from_gcs(
66+
"cloud-samples-data", "ai-platform/generative_ai/headline_classification.jsonl"
67+
)
68+
upload_to_gcs(_BUCKET, temp_filename, data)
69+
try:
70+
yield f"gs://{_BUCKET}/{temp_filename}"
71+
finally:
72+
delete_from_gcs(_BUCKET, temp_filename)
73+
74+
75+
def teardown_model(
76+
tuned_model: TextGenerationModel, training_data_filename: str
77+
) -> None:
78+
for tuned_model_name in tuned_model.list_tuned_model_names():
79+
model_registry = aiplatform.models.ModelRegistry(model=tuned_model_name)
80+
if (
81+
training_data_filename
82+
in model_registry.get_version_info("1").model_display_name
83+
):
84+
display_name = model_registry.get_version_info("1").model_display_name
85+
for endpoint in aiplatform.Endpoint.list():
86+
for _ in endpoint.list_models():
87+
if endpoint.display_name == display_name:
88+
endpoint.undeploy_all()
89+
endpoint.delete()
90+
aiplatform.Model(model_registry.model_resource_name).delete()
91+
92+
93+
@pytest.mark.skip("Blocked on b/277959219")
94+
def test_distill_model(training_data_filename: str) -> None:
95+
"""Takes approx. 60 minutes."""
96+
student_model = distillation_example.distill_model(
97+
dataset=training_data_filename,
98+
teacher_model="text-unicorn@001",
99+
evaluation_dataset=training_data_filename,
100+
)
101+
try:
102+
assert (
103+
student_model._job.state
104+
== pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED
105+
)
106+
finally:
107+
teardown_model(student_model, training_data_filename)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# [START generativeaionvertexai_evaluate_model]
16+
import os
17+
18+
from google.auth import default
19+
20+
import vertexai
21+
from vertexai.preview.language_models import (
22+
EvaluationTextClassificationSpec,
23+
TextGenerationModel,
24+
)
25+
26+
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
27+
28+
29+
def evaluate_model() -> object:
30+
"""Evaluate the performance of a generative AI model."""
31+
32+
# Set credentials for the pipeline components used in the evaluation task
33+
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
34+
35+
vertexai.init(project=PROJECT_ID, location="us-central1", credentials=credentials)
36+
37+
# Create a reference to a generative AI model
38+
model = TextGenerationModel.from_pretrained("text-bison@002")
39+
40+
# Define the evaluation specification for a text classification task
41+
task_spec = EvaluationTextClassificationSpec(
42+
ground_truth_data=[
43+
"gs://cloud-samples-data/ai-platform/generative_ai/llm_classification_bp_input_prompts_with_ground_truth.jsonl"
44+
],
45+
class_names=["nature", "news", "sports", "health", "startups"],
46+
target_column_name="ground_truth",
47+
)
48+
49+
# Evaluate the model
50+
eval_metrics = model.evaluate(task_spec=task_spec)
51+
print(eval_metrics)
52+
# Example response:
53+
# ...
54+
# PipelineJob run completed.
55+
# Resource name: projects/123456789/locations/us-central1/pipelineJobs/evaluation-llm-classification-...
56+
# EvaluationClassificationMetric(label_name=None, auPrc=0.53833705, auRoc=0.8...
57+
58+
return eval_metrics
59+
60+
61+
# [END generativeaionvertexai_evaluate_model]
62+
63+
64+
if __name__ == "__main__":
65+
evaluate_model()
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import backoff
16+
17+
import evaluate_model_example
18+
19+
from google.api_core.exceptions import ResourceExhausted
20+
21+
import pytest
22+
23+
24+
@pytest.mark.skip(
25+
reason="Model is giving 404 Not found error."
26+
"Need to investigate. Created an issue tracker is at "
27+
"python-docs-samples/issues/11264"
28+
)
29+
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
30+
def test_evaluate_model() -> None:
31+
eval_metrics = evaluate_model_example.evaluate_model()
32+
assert hasattr(eval_metrics, "auRoc")

0 commit comments

Comments
 (0)