Skip to content

Commit 9dba4d4

Browse files
speedstorm1 authored and copybara-github committed
feat: Support distillation tuning
FUTURE_COPYBARA_INTEGRATE_REVIEW=#1966 from googleapis:release-please--branches--main 8b4a9a7
PiperOrigin-RevId: 859146888
1 parent c9851d6 commit 9dba4d4

File tree

3 files changed

+282
-2
lines changed

3 files changed

+282
-2
lines changed

google/genai/tests/tunings/test_tune.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,12 @@
2020
from .. import pytest_helper
2121
import pytest
2222

23+
24+
VERTEX_HTTP_OPTIONS = {
25+
'api_version': 'v1beta1',
26+
'base_url': 'https://us-central1-autopush-aiplatform.sandbox.googleapis.com/',
27+
}
28+
2329
evaluation_config=genai_types.EvaluationConfig(
2430
metrics=[
2531
genai_types.Metric(name="bleu", prompt_template="test prompt template")
@@ -158,6 +164,26 @@
158164
),
159165
exception_if_mldev="vertex_dataset_resource parameter is not supported in Gemini API.",
160166
),
167+
pytest_helper.TestTableItem(
168+
name="test_tune_distillation",
169+
parameters=genai_types.CreateTuningJobParameters(
170+
base_model="meta/[email protected]",
171+
training_dataset=genai_types.TuningDataset(
172+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-openai-opposites.jsonl",
173+
),
174+
config=genai_types.CreateTuningJobConfig(
175+
method="DISTILLATION",
176+
base_teacher_model="deepseek-ai/deepseek-v3.1-maas",
177+
epoch_count=20,
178+
validation_dataset=genai_types.TuningValidationDataset(
179+
gcs_uri="gs://nathreya-oss-tuning-sdk-test/distillation-val-openai-opposites.jsonl",
180+
),
181+
output_uri="gs://nathreya-oss-tuning-sdk-test/ayushagra-distillation-test-folder",
182+
http_options=VERTEX_HTTP_OPTIONS,
183+
),
184+
),
185+
exception_if_mldev="parameter is not supported in Gemini API.",
186+
),
161187
]
162188

163189
pytestmark = pytest_helper.setup(

google/genai/tunings.py

Lines changed: 117 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -213,6 +213,24 @@ def _CreateTuningJobConfig_to_mldev(
213213
if getv(from_object, ['beta']) is not None:
214214
raise ValueError('beta parameter is not supported in Gemini API.')
215215

216+
if getv(from_object, ['base_teacher_model']) is not None:
217+
raise ValueError(
218+
'base_teacher_model parameter is not supported in Gemini API.'
219+
)
220+
221+
if getv(from_object, ['tuned_teacher_model_source']) is not None:
222+
raise ValueError(
223+
'tuned_teacher_model_source parameter is not supported in Gemini API.'
224+
)
225+
226+
if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
227+
raise ValueError(
228+
'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
229+
)
230+
231+
if getv(from_object, ['output_uri']) is not None:
232+
raise ValueError('output_uri parameter is not supported in Gemini API.')
233+
216234
return to_object
217235

218236

@@ -246,6 +264,16 @@ def _CreateTuningJobConfig_to_vertex(
246264
),
247265
)
248266

267+
elif discriminator == 'DISTILLATION':
268+
if getv(from_object, ['validation_dataset']) is not None:
269+
setv(
270+
parent_object,
271+
['distillationSpec'],
272+
_TuningValidationDataset_to_vertex(
273+
getv(from_object, ['validation_dataset']), to_object, root_object
274+
),
275+
)
276+
249277
if getv(from_object, ['tuned_model_display_name']) is not None:
250278
setv(
251279
parent_object,
@@ -275,6 +303,14 @@ def _CreateTuningJobConfig_to_vertex(
275303
getv(from_object, ['epoch_count']),
276304
)
277305

306+
elif discriminator == 'DISTILLATION':
307+
if getv(from_object, ['epoch_count']) is not None:
308+
setv(
309+
parent_object,
310+
['distillationSpec', 'hyperParameters', 'epochCount'],
311+
getv(from_object, ['epoch_count']),
312+
)
313+
278314
discriminator = getv(root_object, ['config', 'method'])
279315
if discriminator is None:
280316
discriminator = 'SUPERVISED_FINE_TUNING'
@@ -298,6 +334,14 @@ def _CreateTuningJobConfig_to_vertex(
298334
getv(from_object, ['learning_rate_multiplier']),
299335
)
300336

337+
elif discriminator == 'DISTILLATION':
338+
if getv(from_object, ['learning_rate_multiplier']) is not None:
339+
setv(
340+
parent_object,
341+
['distillationSpec', 'hyperParameters', 'learningRateMultiplier'],
342+
getv(from_object, ['learning_rate_multiplier']),
343+
)
344+
301345
discriminator = getv(root_object, ['config', 'method'])
302346
if discriminator is None:
303347
discriminator = 'SUPERVISED_FINE_TUNING'
@@ -317,6 +361,14 @@ def _CreateTuningJobConfig_to_vertex(
317361
getv(from_object, ['export_last_checkpoint_only']),
318362
)
319363

364+
elif discriminator == 'DISTILLATION':
365+
if getv(from_object, ['export_last_checkpoint_only']) is not None:
366+
setv(
367+
parent_object,
368+
['distillationSpec', 'exportLastCheckpointOnly'],
369+
getv(from_object, ['export_last_checkpoint_only']),
370+
)
371+
320372
discriminator = getv(root_object, ['config', 'method'])
321373
if discriminator is None:
322374
discriminator = 'SUPERVISED_FINE_TUNING'
@@ -336,6 +388,14 @@ def _CreateTuningJobConfig_to_vertex(
336388
getv(from_object, ['adapter_size']),
337389
)
338390

391+
elif discriminator == 'DISTILLATION':
392+
if getv(from_object, ['adapter_size']) is not None:
393+
setv(
394+
parent_object,
395+
['distillationSpec', 'hyperParameters', 'adapterSize'],
396+
getv(from_object, ['adapter_size']),
397+
)
398+
339399
if getv(from_object, ['batch_size']) is not None:
340400
raise ValueError('batch_size parameter is not supported in Vertex AI.')
341401

@@ -365,6 +425,16 @@ def _CreateTuningJobConfig_to_vertex(
365425
),
366426
)
367427

428+
elif discriminator == 'DISTILLATION':
429+
if getv(from_object, ['evaluation_config']) is not None:
430+
setv(
431+
parent_object,
432+
['distillationSpec', 'evaluationConfig'],
433+
_EvaluationConfig_to_vertex(
434+
getv(from_object, ['evaluation_config']), to_object, root_object
435+
),
436+
)
437+
368438
if getv(from_object, ['labels']) is not None:
369439
setv(parent_object, ['labels'], getv(from_object, ['labels']))
370440

@@ -375,6 +445,30 @@ def _CreateTuningJobConfig_to_vertex(
375445
getv(from_object, ['beta']),
376446
)
377447

448+
if getv(from_object, ['base_teacher_model']) is not None:
449+
setv(
450+
parent_object,
451+
['distillationSpec', 'baseTeacherModel'],
452+
getv(from_object, ['base_teacher_model']),
453+
)
454+
455+
if getv(from_object, ['tuned_teacher_model_source']) is not None:
456+
setv(
457+
parent_object,
458+
['distillationSpec', 'tunedTeacherModelSource'],
459+
getv(from_object, ['tuned_teacher_model_source']),
460+
)
461+
462+
if getv(from_object, ['sft_loss_weight_multiplier']) is not None:
463+
setv(
464+
parent_object,
465+
['distillationSpec', 'hyperParameters', 'sftLossWeightMultiplier'],
466+
getv(from_object, ['sft_loss_weight_multiplier']),
467+
)
468+
469+
if getv(from_object, ['output_uri']) is not None:
470+
setv(parent_object, ['outputUri'], getv(from_object, ['output_uri']))
471+
378472
return to_object
379473

380474

@@ -920,6 +1014,14 @@ def _TuningDataset_to_vertex(
9201014
getv(from_object, ['gcs_uri']),
9211015
)
9221016

1017+
elif discriminator == 'DISTILLATION':
1018+
if getv(from_object, ['gcs_uri']) is not None:
1019+
setv(
1020+
parent_object,
1021+
['distillationSpec', 'promptDatasetUri'],
1022+
getv(from_object, ['gcs_uri']),
1023+
)
1024+
9231025
discriminator = getv(root_object, ['config', 'method'])
9241026
if discriminator is None:
9251027
discriminator = 'SUPERVISED_FINE_TUNING'
@@ -939,6 +1041,14 @@ def _TuningDataset_to_vertex(
9391041
getv(from_object, ['vertex_dataset_resource']),
9401042
)
9411043

1044+
elif discriminator == 'DISTILLATION':
1045+
if getv(from_object, ['vertex_dataset_resource']) is not None:
1046+
setv(
1047+
parent_object,
1048+
['distillationSpec', 'promptDatasetUri'],
1049+
getv(from_object, ['vertex_dataset_resource']),
1050+
)
1051+
9421052
if getv(from_object, ['examples']) is not None:
9431053
raise ValueError('examples parameter is not supported in Vertex AI.')
9441054

@@ -1066,6 +1176,13 @@ def _TuningJob_from_vertex(
10661176
getv(from_object, ['preferenceOptimizationSpec']),
10671177
)
10681178

1179+
if getv(from_object, ['distillationSpec']) is not None:
1180+
setv(
1181+
to_object,
1182+
['distillation_spec'],
1183+
getv(from_object, ['distillationSpec']),
1184+
)
1185+
10691186
if getv(from_object, ['tuningDataStats']) is not None:
10701187
setv(
10711188
to_object, ['tuning_data_stats'], getv(from_object, ['tuningDataStats'])

0 commit comments

Comments
 (0)