diff --git a/src/pdl/pdl_ast.py b/src/pdl/pdl_ast.py
index a21540504..36c45c660 100644
--- a/src/pdl/pdl_ast.py
+++ b/src/pdl/pdl_ast.py
@@ -628,49 +628,64 @@ def set_default_model_params(
     return params
 
 
-def set_default_granite_model_parameters(
-    model_id: str,
+def set_structured_decoding_parameters(
     spec: Any,
     parameters: Optional[dict[str, Any]],
 ) -> dict[str, Any]:
     if parameters is None:
         parameters = {}
-    if spec is not None:
+    if (
+        spec is not None
+        and parameters["response_format"] is None
+        and "guided_decoding_backend" not in parameters
+    ):
         schema = pdltype_to_jsonschema(spec, True)
         parameters["guided_decoding_backend"] = "lm-format-enforcer"
         parameters["guided_json"] = schema
+        # parameters["response_format"] = { "type": "json_schema", "json_schema": schema , "strict": True }
+    return parameters
+
+
+def set_default_granite_model_parameters(
+    model_id: str,
+    spec: Any,
+    parameters: Optional[dict[str, Any]],
+) -> dict[str, Any]:
+    if parameters is None:
+        parameters = {}
 
-    # if "decoding_method" not in parameters:
-    #     parameters["decoding_method"] = (
-    #         DECODING_METHOD  # pylint: disable=attribute-defined-outside-init
-    #     )
-    # if "max_tokens" in parameters and parameters["max_tokens"] is None:
-    #     parameters["max_tokens"] = (
-    #         MAX_NEW_TOKENS  # pylint: disable=attribute-defined-outside-init
-    #     )
-    # if "min_new_tokens" not in parameters:
-    #     parameters["min_new_tokens"] = (
-    #         MIN_NEW_TOKENS  # pylint: disable=attribute-defined-outside-init
-    #     )
-    # if "repetition_penalty" not in parameters:
-    #     parameters["repetition_penalty"] = (
-    #         REPETITION_PENATLY  # pylint: disable=attribute-defined-outside-init
-    #     )
-    # if parameters["decoding_method"] == "sample":
-    #     if "temperature" not in parameters:
-    #         parameters["temperature"] = (
-    #             TEMPERATURE_SAMPLING  # pylint: disable=attribute-defined-outside-init
-    #         )
-    #     if "top_k" not in parameters:
-    #         parameters["top_k"] = (
-    #             TOP_K_SAMPLING  # pylint: disable=attribute-defined-outside-init
-    #         )
-    #     if "top_p" not in parameters:
-    #         parameters["top_p"] = (
-    #             TOP_P_SAMPLING  # pylint: disable=attribute-defined-outside-init
-    #         )
-    if "granite-3.0" in model_id:
+    if "watsonx" in model_id:
+        if "decoding_method" not in parameters:
+            parameters["decoding_method"] = (
+                DECODING_METHOD  # pylint: disable=attribute-defined-outside-init
+            )
+        if "max_tokens" in parameters and parameters["max_tokens"] is None:
+            parameters["max_tokens"] = (
+                MAX_NEW_TOKENS  # pylint: disable=attribute-defined-outside-init
+            )
+        if "min_new_tokens" not in parameters:
+            parameters["min_new_tokens"] = (
+                MIN_NEW_TOKENS  # pylint: disable=attribute-defined-outside-init
+            )
+        if "repetition_penalty" not in parameters:
+            parameters["repetition_penalty"] = (
+                REPETITION_PENATLY  # pylint: disable=attribute-defined-outside-init
+            )
+        if parameters["decoding_method"] == "sample":
+            if "temperature" not in parameters:
+                parameters["temperature"] = (
+                    TEMPERATURE_SAMPLING  # pylint: disable=attribute-defined-outside-init
+                )
+            if "top_k" not in parameters:
+                parameters["top_k"] = (
+                    TOP_K_SAMPLING  # pylint: disable=attribute-defined-outside-init
+                )
+            if "top_p" not in parameters:
+                parameters["top_p"] = (
+                    TOP_P_SAMPLING  # pylint: disable=attribute-defined-outside-init
+                )
+    if "replicate" in model_id and "granite-3.0" in model_id:
         if "temperature" not in parameters or parameters["temperature"] is None:
             parameters["temperature"] = 0  # setting to decoding greedy
         if "roles" not in parameters:
diff --git a/src/pdl/pdl_interpreter.py b/src/pdl/pdl_interpreter.py
index b268d9d73..7af1007bc 100644
--- a/src/pdl/pdl_interpreter.py
+++ b/src/pdl/pdl_interpreter.py
@@ -1128,10 +1128,10 @@ def get_transformed_inputs(kwargs):
         msg, raw_result = yield from generate_client_response(
             state, concrete_block, model_input
         )
-        if "input" in litellm_params:
-            append_log(state, "Model Input", litellm_params["input"])
-        else:
-            append_log(state, "Model Input", messages_to_str(model_input))
+        # if "input" in litellm_params:
+        append_log(state, "Model Input", litellm_params)
+        # else:
+        #     append_log(state, "Model Input", messages_to_str(model_input))
         background: Messages = [msg]
         result = "" if msg["content"] is None else msg["content"]
         append_log(state, "Model Output", result)
diff --git a/src/pdl/pdl_llms.py b/src/pdl/pdl_llms.py
index b5715ea34..3ab5e7bac 100644
--- a/src/pdl/pdl_llms.py
+++ b/src/pdl/pdl_llms.py
@@ -14,6 +14,7 @@
     Message,
     set_default_granite_model_parameters,
     set_default_model_params,
+    set_structured_decoding_parameters,
 )
 
 # Load environment variables
@@ -155,6 +156,7 @@ def generate_text(
             parameters = set_default_granite_model_parameters(
                 model_id, spec, parameters
            )
+        parameters = set_structured_decoding_parameters(spec, parameters)
         if parameters.get("mock_response") is not None:
             litellm.suppress_debug_info = True
         response = completion(
@@ -176,6 +178,7 @@ def generate_text_stream(
             parameters = set_default_granite_model_parameters(
                 model_id, spec, parameters
             )
+        parameters = set_structured_decoding_parameters(spec, parameters)
         response = completion(
             model=model_id,
             messages=messages,
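
For orientation (not part of the patch): a minimal, self-contained sketch of how the new set_structured_decoding_parameters helper feeds guided decoding into the litellm call made by generate_text. The schema_from_spec stand-in below is hypothetical and replaces pdltype_to_jsonschema, and the guard uses .get() so the sketch runs even when "response_format" was never seeded; the patched code indexes that key directly.

from typing import Any, Optional


def schema_from_spec(spec: Any) -> dict[str, Any]:
    # Hypothetical stand-in for pdltype_to_jsonschema: any mapping from a
    # PDL spec to a JSON schema dict is enough for this sketch.
    return {"type": "object", "properties": {"answer": {"type": "string"}}}


def structured_decoding_sketch(
    spec: Any, parameters: Optional[dict[str, Any]]
) -> dict[str, Any]:
    # Mirrors the guard in the diff: only attach guided decoding when a spec
    # is present and the caller has not already chosen a response_format or
    # a guided decoding backend.
    if parameters is None:
        parameters = {}
    if (
        spec is not None
        and parameters.get("response_format") is None
        and "guided_decoding_backend" not in parameters
    ):
        parameters["guided_decoding_backend"] = "lm-format-enforcer"
        parameters["guided_json"] = schema_from_spec(spec)
    return parameters


if __name__ == "__main__":
    params = structured_decoding_sketch(
        spec={"answer": "str"},  # placeholder spec, not validated here
        parameters={"temperature": 0},
    )
    # This dict is what generate_text()/generate_text_stream() would forward
    # to litellm's completion(model=..., messages=..., **params).
    print(params)

Because generate_text applies the granite defaults first and the structured-decoding parameters second, an explicit response_format or guided_decoding_backend supplied by the caller short-circuits the guard and is left untouched.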