Skip to content

Commit e8b1683

Browse files
feat: add code samples for continuous tuning (#13594)
* feat: add code samples for continuous tuning * feat: add code samples for continuous tuning * Update continuous_tuning_create.py * Update test_tuning_examples.py * Update requirements.txt * Rename continuous_tuning_create.py to tuning_with_pretuned_model.py * Update test_tuning_examples.py * Update tuning_with_pretuned_model.py * Update genai/tuning/tuning_with_pretuned_model.py Co-authored-by: Sampath Kumar <[email protected]> --------- Co-authored-by: Sampath Kumar <[email protected]>
1 parent d08bfe4 commit e8b1683

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

genai/tuning/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
google-genai==1.30.0
1+
google-genai==1.41.0

genai/tuning/test_tuning_examples.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import tuning_with_checkpoints_list_checkpoints
3030
import tuning_with_checkpoints_set_default_checkpoint
3131
import tuning_with_checkpoints_textgen_with_txt
32+
import tuning_with_pretuned_model
3233

3334

3435
GCS_OUTPUT_BUCKET = "python-docs-samples-tests"
@@ -306,3 +307,23 @@ def test_tuning_with_checkpoints_textgen_with_txt(mock_genai_client: MagicMock)
306307
call(model="test-endpoint-1", contents="Why is the sky blue?"),
307308
call(model="test-endpoint-2", contents="Why is the sky blue?"),
308309
]
310+
311+
312+
@patch("google.genai.Client")
313+
def test_tuning_with_pretuned_model(mock_genai_client: MagicMock) -> None:
314+
# Mock the API response
315+
mock_tuning_job = types.TuningJob(
316+
name="test-tuning-job",
317+
experiment="test-experiment",
318+
tuned_model=types.TunedModel(
319+
model="test-model-2",
320+
endpoint="test-endpoint"
321+
)
322+
)
323+
mock_genai_client.return_value.tunings.tune.return_value = mock_tuning_job
324+
325+
response = tuning_with_pretuned_model.create_continuous_tuning_job(tuned_model_name="test-model", checkpoint_id="1")
326+
327+
mock_genai_client.assert_called_once_with(http_options=types.HttpOptions(api_version="v1beta1"))
328+
mock_genai_client.return_value.tunings.tune.assert_called_once()
329+
assert response == "test-tuning-job"
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_continuous_tuning_job(tuned_model_name: str, checkpoint_id: str) -> str:
17+
# [START googlegenaisdk_tuning_with_pretuned_model]
18+
import time
19+
20+
from google import genai
21+
from google.genai.types import HttpOptions, TuningDataset, CreateTuningJobConfig
22+
23+
# TODO(developer): Update and un-comment below line
24+
# tuned_model_name = "projects/123456789012/locations/us-central1/models/1234567890@1"
25+
# checkpoint_id = "1"
26+
27+
client = genai.Client(http_options=HttpOptions(api_version="v1beta1"))
28+
29+
training_dataset = TuningDataset(
30+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_train_data.jsonl",
31+
)
32+
validation_dataset = TuningDataset(
33+
gcs_uri="gs://cloud-samples-data/ai-platform/generative_ai/gemini/text/sft_validation_data.jsonl",
34+
)
35+
36+
tuning_job = client.tunings.tune(
37+
base_model=tuned_model_name, # Note: Using a Tuned Model
38+
training_dataset=training_dataset,
39+
config=CreateTuningJobConfig(
40+
tuned_model_display_name="Example tuning job",
41+
validation_dataset=validation_dataset,
42+
pre_tuned_model_checkpoint_id=checkpoint_id,
43+
),
44+
)
45+
46+
running_states = set([
47+
"JOB_STATE_PENDING",
48+
"JOB_STATE_RUNNING",
49+
])
50+
51+
while tuning_job.state in running_states:
52+
print(tuning_job.state)
53+
tuning_job = client.tunings.get(name=tuning_job.name)
54+
time.sleep(60)
55+
56+
print(tuning_job.tuned_model.model)
57+
print(tuning_job.tuned_model.endpoint)
58+
print(tuning_job.experiment)
59+
# Example response:
60+
# projects/123456789012/locations/us-central1/models/1234567890@2
61+
# projects/123456789012/locations/us-central1/endpoints/123456789012345
62+
# projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678
63+
64+
if tuning_job.tuned_model.checkpoints:
65+
for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
66+
print(f"Checkpoint {i + 1}: ", checkpoint)
67+
# Example response:
68+
# Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
69+
# Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'
70+
71+
# [END googlegenaisdk_tuning_with_pretuned_model]
72+
return tuning_job.name
73+
74+
75+
if __name__ == "__main__":
76+
pre_tuned_model_name = input("Pre-tuned model name: ")
77+
pre_tuned_model_checkpoint_id = input("Pre-tuned model checkpoint id: ")
78+
create_continuous_tuning_job(pre_tuned_model_name, pre_tuned_model_checkpoint_id)

0 commit comments

Comments
 (0)