Skip to content

Commit 8421334

Browse files
feat(genai): Update tuning samples to include automatic evaluations (#13550)
* feat: Update create tuning job to include tuning with automatic evaluation * Update genai/tuning/tuning_job_create.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update genai/tuning/tuning_with_checkpoints_create.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update genai/tuning/tuning_with_checkpoints_create.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update genai/tuning/tuning_job_create.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> * Update tuning_job_create.py * Update tuning_with_checkpoints_create.py * Update tuning_job_create.py * Update tuning_with_checkpoints_create.py * Update requirements.txt Update GenAI SDK version * Update test_tuning_examples.py * Update test_tuning_examples.py * Update test_tuning_examples.py * Update test_tuning_examples.py * Update tuning_with_checkpoints_create.py * Update tuning_job_create.py * Update test_tuning_examples.py * Update test_tuning_examples.py * Update test_tuning_examples.py --------- Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 15b35c1 commit 8421334

File tree

4 files changed

+79
-15
lines changed

4 files changed

+79
-15
lines changed

genai/tuning/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
google-genai==1.27.0
1+
google-genai==1.30.0

genai/tuning/test_tuning_examples.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from datetime import datetime as dt
16+
1517
from unittest.mock import call, MagicMock, patch
1618

19+
from google.cloud import storage
1720
from google.genai import types
21+
import pytest
1822

1923
import tuning_job_create
2024
import tuning_job_get
@@ -27,8 +31,24 @@
2731
import tuning_with_checkpoints_textgen_with_txt
2832

2933

34+
GCS_OUTPUT_BUCKET = "python-docs-samples-tests"
35+
36+
37+
@pytest.fixture(scope="session")
38+
def output_gcs_uri() -> str:
39+
prefix = f"text_output/{dt.now()}"
40+
41+
yield f"gs://{GCS_OUTPUT_BUCKET}/{prefix}"
42+
43+
storage_client = storage.Client()
44+
bucket = storage_client.get_bucket(GCS_OUTPUT_BUCKET)
45+
blobs = bucket.list_blobs(prefix=prefix)
46+
for blob in blobs:
47+
blob.delete()
48+
49+
3050
@patch("google.genai.Client")
31-
def test_tuning_job_create(mock_genai_client: MagicMock) -> None:
51+
def test_tuning_job_create(mock_genai_client: MagicMock, output_gcs_uri: str) -> None:
3252
# Mock the API response
3353
mock_tuning_job = types.TuningJob(
3454
name="test-tuning-job",
@@ -40,9 +60,9 @@ def test_tuning_job_create(mock_genai_client: MagicMock) -> None:
4060
)
4161
mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job
4262

43-
response = tuning_job_create.create_tuning_job()
63+
response = tuning_job_create.create_tuning_job(output_gcs_uri=output_gcs_uri)
4464

45-
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
65+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1beta1"))
4666
mock_genai_client.return_value.tunings.tune.assert_called_once()
4767
assert response == "test-tuning-job"
4868

@@ -121,7 +141,7 @@ def test_tuning_textgen_with_txt(mock_genai_client: MagicMock) -> None:
121141

122142

123143
@patch("google.genai.Client")
124-
def test_tuning_job_create_with_checkpoints(mock_genai_client: MagicMock) -> None:
144+
def test_tuning_job_create_with_checkpoints(mock_genai_client: MagicMock, output_gcs_uri: str) -> None:
125145
# Mock the API response
126146
mock_tuning_job = types.TuningJob(
127147
name="test-tuning-job",
@@ -137,9 +157,9 @@ def test_tuning_job_create_with_checkpoints(mock_genai_client: MagicMock) -> Non
137157
)
138158
mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job
139159

140-
response = tuning_with_checkpoints_create.create_with_checkpoints()
160+
response = tuning_with_checkpoints_create.create_with_checkpoints(output_gcs_uri=output_gcs_uri)
141161

142-
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1"))
162+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1beta1"))
143163
mock_genai_client.return_value.tunings.tune.assert_called_once()
144164
assert response == "test-tuning-job"
145165

genai/tuning/tuning_job_create.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,46 @@
1313
# limitations under the License.
1414

1515

16-
def create_tuning_job() -> str:
16+
def create_tuning_job(output_gcs_uri: str) -> str:
1717
# [START googlegenaisdk_tuning_job_create]
1818
import time
1919

2020
from google import genai
21-
from google.genai.types import HttpOptions, CreateTuningJobConfig, TuningDataset
21+
from google.genai.types import HttpOptions, CreateTuningJobConfig, TuningDataset, EvaluationConfig, OutputConfig, GcsDestination, Metric
2222

23-
client = genai.Client(http_options=HttpOptions(api_version="v1"))
23+
# TODO(developer): Update and un-comment below line
24+
# output_gcs_uri = "gs://your-bucket/your-prefix"
25+
26+
client = genai.Client(http_options=HttpOptions(api_version="v1beta1"))
2427

2528
training_dataset = TuningDataset(
2629
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl",
2730
)
31+
validation_dataset = TuningDataset(
32+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl",
33+
)
34+
35+
evaluation_config = EvaluationConfig(
36+
metrics=[
37+
Metric(
38+
name="FLUENCY",
39+
prompt_template="""Evaluate this {response}"""
40+
)
41+
],
42+
output_config=OutputConfig(
43+
gcs_destination=GcsDestination(
44+
output_uri_prefix=output_gcs_uri,
45+
)
46+
),
47+
)
2848

2949
tuning_job = client.tunings.tune(
3050
base_model="gemini-2.5-flash",
3151
training_dataset=training_dataset,
3252
config=CreateTuningJobConfig(
3353
tuned_model_display_name="Example tuning job",
54+
validation_dataset=validation_dataset,
55+
evaluation_config=evaluation_config,
3456
),
3557
)
3658

@@ -64,4 +86,4 @@ def create_tuning_job() -> str:
6486

6587

6688
if __name__ == "__main__":
67-
create_tuning_job()
89+
create_tuning_job(output_gcs_uri="gs://your-bucket/your-prefix")

genai/tuning/tuning_with_checkpoints_create.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,38 @@
1313
# limitations under the License.
1414

1515

16-
def create_with_checkpoints() -> str:
16+
def create_with_checkpoints(output_gcs_uri: str) -> str:
1717
# [START googlegenaisdk_tuning_with_checkpoints_create]
1818
import time
1919

2020
from google import genai
21-
from google.genai.types import HttpOptions, CreateTuningJobConfig, TuningDataset
21+
from google.genai.types import HttpOptions, CreateTuningJobConfig, TuningDataset, EvaluationConfig, OutputConfig, GcsDestination, Metric
2222

23-
client = genai.Client(http_options=HttpOptions(api_version="v1"))
23+
# TODO(developer): Update and un-comment below line
24+
# output_gcs_uri = "gs://your-bucket/your-prefix"
25+
26+
client = genai.Client(http_options=HttpOptions(api_version="v1beta1"))
2427

2528
training_dataset = TuningDataset(
2629
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl",
2730
)
31+
validation_dataset = TuningDataset(
32+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl",
33+
)
34+
35+
evaluation_config = EvaluationConfig(
36+
metrics=[
37+
Metric(
38+
name="FLUENCY",
39+
prompt_template="""Evaluate this {response}"""
40+
)
41+
],
42+
output_config=OutputConfig(
43+
gcs_destination=GcsDestination(
44+
output_uri_prefix=output_gcs_uri,
45+
)
46+
),
47+
)
2848

2949
tuning_job = client.tunings.tune(
3050
base_model="gemini-2.5-flash",
@@ -33,6 +53,8 @@ def create_with_checkpoints() -> str:
3353
tuned_model_display_name="Example tuning job",
3454
# Set to True to disable tuning intermediate checkpoints. Default is False.
3555
export_last_checkpoint_only=False,
56+
validation_dataset=validation_dataset,
57+
evaluation_config=evaluation_config,
3658
),
3759
)
3860

@@ -66,4 +88,4 @@ def create_with_checkpoints() -> str:
6688

6789

6890
if __name__ == "__main__":
69-
create_with_checkpoints()
91+
create_with_checkpoints(output_gcs_uri="gs://your-bucket/your-prefix")

0 commit comments

Comments
 (0)