Skip to content

Commit e1ff4fc

Browse files
author
James Su
committed
feat: add code samples for preference tuning
1 parent eff0040 commit e1ff4fc

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
def create_tuning_job() -> str:
    """Create a preference-tuning job and wait until it stops running.

    Submits a ``PREFERENCE_TUNING`` job on ``gemini-2.5-flash`` with the
    example training and validation datasets, polls the job once a minute
    while it is pending or running, then prints the tuned model, its
    endpoint, the experiment and any checkpoints.

    Returns:
        The name of the tuning job.
    """
    # [START googlegenaisdk_tuning_job_create]
    import time

    from google import genai
    from google.genai.types import HttpOptions, CreateTuningJobConfig, TuningDataset

    client = genai.Client(http_options=HttpOptions(api_version="v1"))

    training_dataset = TuningDataset(
        gcs_uri="gs://mybucket/preference_tuning/data/train_data.jsonl",
    )
    validation_dataset = TuningDataset(
        gcs_uri="gs://mybucket/preference_tuning/data/validation_data.jsonl",
    )

    # Refer to https://docs.cloud.google.com/vertex-ai/generative-ai/docs/models/gemini-use-continuous-tuning#google-gen-ai-sdk
    # for an example of continuous tuning from an SFT-tuned model.
    tuning_job = client.tunings.tune(
        base_model="gemini-2.5-flash",
        training_dataset=training_dataset,
        config=CreateTuningJobConfig(
            tuned_model_display_name="Example tuning job",
            method="PREFERENCE_TUNING",
            validation_dataset=validation_dataset,
        ),
    )

    running_states = {
        "JOB_STATE_PENDING",
        "JOB_STATE_RUNNING",
    }

    while tuning_job.state in running_states:
        print(tuning_job.state)
        # Sleep before refreshing: the loop then exits as soon as a
        # non-running state is fetched instead of waiting one more minute.
        time.sleep(60)
        tuning_job = client.tunings.get(name=tuning_job.name)

    print(tuning_job.tuned_model.model)
    print(tuning_job.tuned_model.endpoint)
    print(tuning_job.experiment)
    # Example response:
    # projects/123456789012/locations/us-central1/models/1234567890@1
    # projects/123456789012/locations/us-central1/endpoints/123456789012345
    # projects/123456789012/locations/us-central1/metadataStores/default/contexts/tuning-experiment-2025010112345678

    if tuning_job.tuned_model.checkpoints:
        for i, checkpoint in enumerate(tuning_job.tuned_model.checkpoints):
            print(f"Checkpoint {i + 1}: ", checkpoint)
        # Example response:
        # Checkpoint 1: checkpoint_id='1' epoch=1 step=10 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789000000'
        # Checkpoint 2: checkpoint_id='2' epoch=2 step=20 endpoint='projects/123456789012/locations/us-central1/endpoints/123456789012345'

    # [END googlegenaisdk_tuning_job_create]
    return tuning_job.name
71+
72+
73+
if __name__ == "__main__":
    # Run the sample when this file is executed directly as a script.
    create_tuning_job()

0 commit comments

Comments
 (0)