@@ -628,49 +628,64 @@ def set_default_model_params(
628
628
return params
629
629
630
630
631
- def set_default_granite_model_parameters (
632
- model_id : str ,
631
+ def set_structured_decoding_parameters (
633
632
spec : Any ,
634
633
parameters : Optional [dict [str , Any ]],
635
634
) -> dict [str , Any ]:
636
635
if parameters is None :
637
636
parameters = {}
638
637
639
- if spec is not None :
638
+ if (
639
+ spec is not None
640
+ and parameters ["response_format" ] is None
641
+ and "guided_decoding_backend" not in parameters
642
+ ):
640
643
schema = pdltype_to_jsonschema (spec , True )
641
644
parameters ["guided_decoding_backend" ] = "lm-format-enforcer"
642
645
parameters ["guided_json" ] = schema
646
+ # parameters["response_format"] = { "type": "json_schema", "json_schema": schema , "strict": True }
647
+ return parameters
648
+
649
+
650
+ def set_default_granite_model_parameters (
651
+ model_id : str ,
652
+ spec : Any ,
653
+ parameters : Optional [dict [str , Any ]],
654
+ ) -> dict [str , Any ]:
655
+ if parameters is None :
656
+ parameters = {}
643
657
644
- # if "decoding_method" not in parameters:
645
- # parameters["decoding_method"] = (
646
- # DECODING_METHOD # pylint: disable=attribute-defined-outside-init
647
- # )
648
- # if "max_tokens" in parameters and parameters["max_tokens"] is None:
649
- # parameters["max_tokens"] = (
650
- # MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
651
- # )
652
- # if "min_new_tokens" not in parameters:
653
- # parameters["min_new_tokens"] = (
654
- # MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
655
- # )
656
- # if "repetition_penalty" not in parameters:
657
- # parameters["repetition_penalty"] = (
658
- # REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
659
- # )
660
- # if parameters["decoding_method"] == "sample":
661
- # if "temperature" not in parameters:
662
- # parameters["temperature"] = (
663
- # TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
664
- # )
665
- # if "top_k" not in parameters:
666
- # parameters["top_k"] = (
667
- # TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
668
- # )
669
- # if "top_p" not in parameters:
670
- # parameters["top_p"] = (
671
- # TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
672
- # )
673
- if "granite-3.0" in model_id :
658
+ if "watsonx" in model_id :
659
+ if "decoding_method" not in parameters :
660
+ parameters ["decoding_method" ] = (
661
+ DECODING_METHOD # pylint: disable=attribute-defined-outside-init
662
+ )
663
+ if "max_tokens" in parameters and parameters ["max_tokens" ] is None :
664
+ parameters ["max_tokens" ] = (
665
+ MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
666
+ )
667
+ if "min_new_tokens" not in parameters :
668
+ parameters ["min_new_tokens" ] = (
669
+ MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
670
+ )
671
+ if "repetition_penalty" not in parameters :
672
+ parameters ["repetition_penalty" ] = (
673
+ REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
674
+ )
675
+ if parameters ["decoding_method" ] == "sample" :
676
+ if "temperature" not in parameters :
677
+ parameters ["temperature" ] = (
678
+ TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
679
+ )
680
+ if "top_k" not in parameters :
681
+ parameters ["top_k" ] = (
682
+ TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
683
+ )
684
+ if "top_p" not in parameters :
685
+ parameters ["top_p" ] = (
686
+ TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
687
+ )
688
+ if "replicate" in model_id and "granite-3.0" in model_id :
674
689
if "temperature" not in parameters or parameters ["temperature" ] is None :
675
690
parameters ["temperature" ] = 0 # setting to decoding greedy
676
691
if "roles" not in parameters :
0 commit comments