@@ -213,6 +213,24 @@ def _CreateTuningJobConfig_to_mldev(
213213 if getv (from_object , ['beta' ]) is not None :
214214 raise ValueError ('beta parameter is not supported in Gemini API.' )
215215
216+ if getv (from_object , ['base_teacher_model' ]) is not None :
217+ raise ValueError (
218+ 'base_teacher_model parameter is not supported in Gemini API.'
219+ )
220+
221+ if getv (from_object , ['tuned_teacher_model_source' ]) is not None :
222+ raise ValueError (
223+ 'tuned_teacher_model_source parameter is not supported in Gemini API.'
224+ )
225+
226+ if getv (from_object , ['sft_loss_weight_multiplier' ]) is not None :
227+ raise ValueError (
228+ 'sft_loss_weight_multiplier parameter is not supported in Gemini API.'
229+ )
230+
231+ if getv (from_object , ['output_uri' ]) is not None :
232+ raise ValueError ('output_uri parameter is not supported in Gemini API.' )
233+
216234 return to_object
217235
218236
@@ -246,6 +264,16 @@ def _CreateTuningJobConfig_to_vertex(
246264 ),
247265 )
248266
267+ elif discriminator == 'DISTILLATION' :
268+ if getv (from_object , ['validation_dataset' ]) is not None :
269+ setv (
270+ parent_object ,
271+ ['distillationSpec' ],
272+ _TuningValidationDataset_to_vertex (
273+ getv (from_object , ['validation_dataset' ]), to_object , root_object
274+ ),
275+ )
276+
249277 if getv (from_object , ['tuned_model_display_name' ]) is not None :
250278 setv (
251279 parent_object ,
@@ -275,6 +303,14 @@ def _CreateTuningJobConfig_to_vertex(
275303 getv (from_object , ['epoch_count' ]),
276304 )
277305
306+ elif discriminator == 'DISTILLATION' :
307+ if getv (from_object , ['epoch_count' ]) is not None :
308+ setv (
309+ parent_object ,
310+ ['distillationSpec' , 'hyperParameters' , 'epochCount' ],
311+ getv (from_object , ['epoch_count' ]),
312+ )
313+
278314 discriminator = getv (root_object , ['config' , 'method' ])
279315 if discriminator is None :
280316 discriminator = 'SUPERVISED_FINE_TUNING'
@@ -298,6 +334,14 @@ def _CreateTuningJobConfig_to_vertex(
298334 getv (from_object , ['learning_rate_multiplier' ]),
299335 )
300336
337+ elif discriminator == 'DISTILLATION' :
338+ if getv (from_object , ['learning_rate_multiplier' ]) is not None :
339+ setv (
340+ parent_object ,
341+ ['distillationSpec' , 'hyperParameters' , 'learningRateMultiplier' ],
342+ getv (from_object , ['learning_rate_multiplier' ]),
343+ )
344+
301345 discriminator = getv (root_object , ['config' , 'method' ])
302346 if discriminator is None :
303347 discriminator = 'SUPERVISED_FINE_TUNING'
@@ -317,6 +361,14 @@ def _CreateTuningJobConfig_to_vertex(
317361 getv (from_object , ['export_last_checkpoint_only' ]),
318362 )
319363
364+ elif discriminator == 'DISTILLATION' :
365+ if getv (from_object , ['export_last_checkpoint_only' ]) is not None :
366+ setv (
367+ parent_object ,
368+ ['distillationSpec' , 'exportLastCheckpointOnly' ],
369+ getv (from_object , ['export_last_checkpoint_only' ]),
370+ )
371+
320372 discriminator = getv (root_object , ['config' , 'method' ])
321373 if discriminator is None :
322374 discriminator = 'SUPERVISED_FINE_TUNING'
@@ -336,6 +388,14 @@ def _CreateTuningJobConfig_to_vertex(
336388 getv (from_object , ['adapter_size' ]),
337389 )
338390
391+ elif discriminator == 'DISTILLATION' :
392+ if getv (from_object , ['adapter_size' ]) is not None :
393+ setv (
394+ parent_object ,
395+ ['distillationSpec' , 'hyperParameters' , 'adapterSize' ],
396+ getv (from_object , ['adapter_size' ]),
397+ )
398+
339399 if getv (from_object , ['batch_size' ]) is not None :
340400 raise ValueError ('batch_size parameter is not supported in Vertex AI.' )
341401
@@ -365,6 +425,16 @@ def _CreateTuningJobConfig_to_vertex(
365425 ),
366426 )
367427
428+ elif discriminator == 'DISTILLATION' :
429+ if getv (from_object , ['evaluation_config' ]) is not None :
430+ setv (
431+ parent_object ,
432+ ['distillationSpec' , 'evaluationConfig' ],
433+ _EvaluationConfig_to_vertex (
434+ getv (from_object , ['evaluation_config' ]), to_object , root_object
435+ ),
436+ )
437+
368438 if getv (from_object , ['labels' ]) is not None :
369439 setv (parent_object , ['labels' ], getv (from_object , ['labels' ]))
370440
@@ -375,6 +445,30 @@ def _CreateTuningJobConfig_to_vertex(
375445 getv (from_object , ['beta' ]),
376446 )
377447
448+ if getv (from_object , ['base_teacher_model' ]) is not None :
449+ setv (
450+ parent_object ,
451+ ['distillationSpec' , 'baseTeacherModel' ],
452+ getv (from_object , ['base_teacher_model' ]),
453+ )
454+
455+ if getv (from_object , ['tuned_teacher_model_source' ]) is not None :
456+ setv (
457+ parent_object ,
458+ ['distillationSpec' , 'tunedTeacherModelSource' ],
459+ getv (from_object , ['tuned_teacher_model_source' ]),
460+ )
461+
462+ if getv (from_object , ['sft_loss_weight_multiplier' ]) is not None :
463+ setv (
464+ parent_object ,
465+ ['distillationSpec' , 'hyperParameters' , 'sftLossWeightMultiplier' ],
466+ getv (from_object , ['sft_loss_weight_multiplier' ]),
467+ )
468+
469+ if getv (from_object , ['output_uri' ]) is not None :
470+ setv (parent_object , ['outputUri' ], getv (from_object , ['output_uri' ]))
471+
378472 return to_object
379473
380474
@@ -920,6 +1014,14 @@ def _TuningDataset_to_vertex(
9201014 getv (from_object , ['gcs_uri' ]),
9211015 )
9221016
1017+ elif discriminator == 'DISTILLATION' :
1018+ if getv (from_object , ['gcs_uri' ]) is not None :
1019+ setv (
1020+ parent_object ,
1021+ ['distillationSpec' , 'promptDatasetUri' ],
1022+ getv (from_object , ['gcs_uri' ]),
1023+ )
1024+
9231025 discriminator = getv (root_object , ['config' , 'method' ])
9241026 if discriminator is None :
9251027 discriminator = 'SUPERVISED_FINE_TUNING'
@@ -939,6 +1041,14 @@ def _TuningDataset_to_vertex(
9391041 getv (from_object , ['vertex_dataset_resource' ]),
9401042 )
9411043
1044+ elif discriminator == 'DISTILLATION' :
1045+ if getv (from_object , ['vertex_dataset_resource' ]) is not None :
1046+ setv (
1047+ parent_object ,
1048+ ['distillationSpec' , 'promptDatasetUri' ],
1049+ getv (from_object , ['vertex_dataset_resource' ]),
1050+ )
1051+
9421052 if getv (from_object , ['examples' ]) is not None :
9431053 raise ValueError ('examples parameter is not supported in Vertex AI.' )
9441054
@@ -1066,6 +1176,13 @@ def _TuningJob_from_vertex(
10661176 getv (from_object , ['preferenceOptimizationSpec' ]),
10671177 )
10681178
1179+ if getv (from_object , ['distillationSpec' ]) is not None :
1180+ setv (
1181+ to_object ,
1182+ ['distillation_spec' ],
1183+ getv (from_object , ['distillationSpec' ]),
1184+ )
1185+
10691186 if getv (from_object , ['tuningDataStats' ]) is not None :
10701187 setv (
10711188 to_object , ['tuning_data_stats' ], getv (from_object , ['tuningDataStats' ])
0 commit comments