@@ -628,49 +628,64 @@ def set_default_model_params(
628628 return params
629629
630630
631- def set_default_granite_model_parameters (
632- model_id : str ,
631+ def set_structured_decoding_parameters (
633632 spec : Any ,
634633 parameters : Optional [dict [str , Any ]],
635634) -> dict [str , Any ]:
636635 if parameters is None :
637636 parameters = {}
638637
639- if spec is not None :
638+ if (
639+ spec is not None
640+ and parameters ["response_format" ] is None
641+ and "guided_decoding_backend" not in parameters
642+ ):
640643 schema = pdltype_to_jsonschema (spec , True )
641644 parameters ["guided_decoding_backend" ] = "lm-format-enforcer"
642645 parameters ["guided_json" ] = schema
646+ # parameters["response_format"] = { "type": "json_schema", "json_schema": schema , "strict": True }
647+ return parameters
648+
649+
650+ def set_default_granite_model_parameters (
651+ model_id : str ,
652+ spec : Any ,
653+ parameters : Optional [dict [str , Any ]],
654+ ) -> dict [str , Any ]:
655+ if parameters is None :
656+ parameters = {}
643657
644- # if "decoding_method" not in parameters:
645- # parameters["decoding_method"] = (
646- # DECODING_METHOD # pylint: disable=attribute-defined-outside-init
647- # )
648- # if "max_tokens" in parameters and parameters["max_tokens"] is None:
649- # parameters["max_tokens"] = (
650- # MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
651- # )
652- # if "min_new_tokens" not in parameters:
653- # parameters["min_new_tokens"] = (
654- # MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
655- # )
656- # if "repetition_penalty" not in parameters:
657- # parameters["repetition_penalty"] = (
658- # REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
659- # )
660- # if parameters["decoding_method"] == "sample":
661- # if "temperature" not in parameters:
662- # parameters["temperature"] = (
663- # TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
664- # )
665- # if "top_k" not in parameters:
666- # parameters["top_k"] = (
667- # TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
668- # )
669- # if "top_p" not in parameters:
670- # parameters["top_p"] = (
671- # TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
672- # )
673- if "granite-3.0" in model_id :
658+ if "watsonx" in model_id :
659+ if "decoding_method" not in parameters :
660+ parameters ["decoding_method" ] = (
661+ DECODING_METHOD # pylint: disable=attribute-defined-outside-init
662+ )
663+ if "max_tokens" in parameters and parameters ["max_tokens" ] is None :
664+ parameters ["max_tokens" ] = (
665+ MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
666+ )
667+ if "min_new_tokens" not in parameters :
668+ parameters ["min_new_tokens" ] = (
669+ MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
670+ )
671+ if "repetition_penalty" not in parameters :
672+ parameters ["repetition_penalty" ] = (
673+ REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
674+ )
675+ if parameters ["decoding_method" ] == "sample" :
676+ if "temperature" not in parameters :
677+ parameters ["temperature" ] = (
678+ TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
679+ )
680+ if "top_k" not in parameters :
681+ parameters ["top_k" ] = (
682+ TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
683+ )
684+ if "top_p" not in parameters :
685+ parameters ["top_p" ] = (
686+ TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
687+ )
688+ if "replicate" in model_id and "granite-3.0" in model_id :
674689 if "temperature" not in parameters or parameters ["temperature" ] is None :
675690 parameters ["temperature" ] = 0 # setting to decoding greedy
676691 if "roles" not in parameters :
0 commit comments