@@ -627,6 +627,19 @@ def set_default_model_params(
627627 )
628628 return params
629629
def set_structured_decoding_parameters(
    spec: Any,
    parameters: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Request JSON-schema constrained output when a spec is given.

    A ``response_format`` entry derived from ``spec`` is added to
    ``parameters`` unless the caller already configured structured output,
    i.e. supplied a non-``None`` ``response_format`` or a
    ``guided_decoding_backend``.

    Args:
        spec: Type specification converted with ``pdltype_to_jsonschema``;
            ``None`` means no structured decoding is requested.
        parameters: Model parameters. ``None`` is treated as an empty dict.
            A dict passed in is updated in place.

    Returns:
        The parameters dict (the same object that was passed in, or a new
        dict when ``parameters`` was ``None``).
    """
    if parameters is None:
        parameters = {}

    # Use .get(): the key is usually absent (always so for the empty default),
    # and an absent key must be treated like an explicit None, not a KeyError.
    if (
        spec is not None
        and parameters.get("response_format") is None
        and "guided_decoding_backend" not in parameters
    ):
        schema = pdltype_to_jsonschema(spec, True)
        # The schema is sent via "response_format" rather than through the
        # "guided_decoding_backend"/"guided_json" parameters.
        # NOTE(review): confirm the target API expects "strict" at this level
        # and the bare schema under "json_schema".
        parameters["response_format"] = {
            "type": "json_schema",
            "json_schema": schema,
            "strict": True,
        }
    return parameters
630643
631644def set_default_granite_model_parameters (
632645 model_id : str ,
@@ -636,41 +649,37 @@ def set_default_granite_model_parameters(
636649 if parameters is None :
637650 parameters = {}
638651
639- if spec is not None :
640- schema = pdltype_to_jsonschema (spec , True )
641- parameters ["guided_decoding_backend" ] = "lm-format-enforcer"
642- parameters ["guided_json" ] = schema
643-
644- # if "decoding_method" not in parameters:
645- # parameters["decoding_method"] = (
646- # DECODING_METHOD # pylint: disable=attribute-defined-outside-init
647- # )
648- # if "max_tokens" in parameters and parameters["max_tokens"] is None:
649- # parameters["max_tokens"] = (
650- # MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
651- # )
652- # if "min_new_tokens" not in parameters:
653- # parameters["min_new_tokens"] = (
654- # MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
655- # )
656- # if "repetition_penalty" not in parameters:
657- # parameters["repetition_penalty"] = (
658- # REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
659- # )
660- # if parameters["decoding_method"] == "sample":
661- # if "temperature" not in parameters:
662- # parameters["temperature"] = (
663- # TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
664- # )
665- # if "top_k" not in parameters:
666- # parameters["top_k"] = (
667- # TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
668- # )
669- # if "top_p" not in parameters:
670- # parameters["top_p"] = (
671- # TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
672- # )
673- if "granite-3.0" in model_id :
652+ if "watsonx" in model_id :
653+ if "decoding_method" not in parameters :
654+ parameters ["decoding_method" ] = (
655+ DECODING_METHOD # pylint: disable=attribute-defined-outside-init
656+ )
657+ if "max_tokens" in parameters and parameters ["max_tokens" ] is None :
658+ parameters ["max_tokens" ] = (
659+ MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
660+ )
661+ if "min_new_tokens" not in parameters :
662+ parameters ["min_new_tokens" ] = (
663+ MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
664+ )
665+ if "repetition_penalty" not in parameters :
666+ parameters ["repetition_penalty" ] = (
667+ REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
668+ )
669+ if parameters ["decoding_method" ] == "sample" :
670+ if "temperature" not in parameters :
671+ parameters ["temperature" ] = (
672+ TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
673+ )
674+ if "top_k" not in parameters :
675+ parameters ["top_k" ] = (
676+ TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
677+ )
678+ if "top_p" not in parameters :
679+ parameters ["top_p" ] = (
680+ TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
681+ )
682+ if "replicate" in model_id and "granite-3.0" in model_id :
674683 if "temperature" not in parameters or parameters ["temperature" ] is None :
675684 parameters ["temperature" ] = 0 # setting to decoding greedy
676685 if "roles" not in parameters :
0 commit comments