Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 48 additions & 33 deletions src/pdl/pdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,49 +628,64 @@ def set_default_model_params(
return params


def set_default_granite_model_parameters(
model_id: str,
def set_structured_decoding_parameters(
    spec: Any,
    parameters: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Add guided-decoding parameters derived from ``spec`` when none are set.

    When ``spec`` is given and the caller has configured neither a
    ``response_format`` nor a ``guided_decoding_backend``, ``spec`` is converted
    to a JSON schema with ``pdltype_to_jsonschema`` and stored under
    ``guided_json``, with ``guided_decoding_backend`` set to
    ``"lm-format-enforcer"``. Otherwise the parameters are returned unchanged.

    Args:
        spec: Type specification of the expected model output, or ``None``
            when the output is unconstrained.
        parameters: Model parameters supplied by the caller, or ``None``.

    Returns:
        The parameter dictionary. A dictionary passed in is updated in place
        and returned; ``None`` yields a new dictionary.
    """
    if parameters is None:
        parameters = {}

    if (
        spec is not None
        # Use .get(): the key is usually absent (always so for the empty
        # default above), and indexing would raise KeyError in that case.
        and parameters.get("response_format") is None
        and "guided_decoding_backend" not in parameters
    ):
        schema = pdltype_to_jsonschema(spec, True)
        parameters["guided_decoding_backend"] = "lm-format-enforcer"
        parameters["guided_json"] = schema
        # NOTE: alternative encoding kept for reference, currently disabled:
        # parameters["response_format"] = { "type": "json_schema", "json_schema": schema , "strict": True }
    return parameters


def set_default_granite_model_parameters(
model_id: str,
spec: Any,
parameters: Optional[dict[str, Any]],
) -> dict[str, Any]:
if parameters is None:
parameters = {}

# if "decoding_method" not in parameters:
# parameters["decoding_method"] = (
# DECODING_METHOD # pylint: disable=attribute-defined-outside-init
# )
# if "max_tokens" in parameters and parameters["max_tokens"] is None:
# parameters["max_tokens"] = (
# MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
# )
# if "min_new_tokens" not in parameters:
# parameters["min_new_tokens"] = (
# MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
# )
# if "repetition_penalty" not in parameters:
# parameters["repetition_penalty"] = (
# REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
# )
# if parameters["decoding_method"] == "sample":
# if "temperature" not in parameters:
# parameters["temperature"] = (
# TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
# )
# if "top_k" not in parameters:
# parameters["top_k"] = (
# TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
# )
# if "top_p" not in parameters:
# parameters["top_p"] = (
# TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
# )
if "granite-3.0" in model_id:
if "watsonx" in model_id:
if "decoding_method" not in parameters:
parameters["decoding_method"] = (
DECODING_METHOD # pylint: disable=attribute-defined-outside-init
)
if "max_tokens" in parameters and parameters["max_tokens"] is None:
parameters["max_tokens"] = (
MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
)
if "min_new_tokens" not in parameters:
parameters["min_new_tokens"] = (
MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
)
if "repetition_penalty" not in parameters:
parameters["repetition_penalty"] = (
REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
)
if parameters["decoding_method"] == "sample":
if "temperature" not in parameters:
parameters["temperature"] = (
TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
)
if "top_k" not in parameters:
parameters["top_k"] = (
TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
)
if "top_p" not in parameters:
parameters["top_p"] = (
TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
)
if "replicate" in model_id and "granite-3.0" in model_id:
if "temperature" not in parameters or parameters["temperature"] is None:
parameters["temperature"] = 0 # setting to decoding greedy
if "roles" not in parameters:
Expand Down
8 changes: 4 additions & 4 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,10 +1128,10 @@ def get_transformed_inputs(kwargs):
msg, raw_result = yield from generate_client_response(
state, concrete_block, model_input
)
if "input" in litellm_params:
append_log(state, "Model Input", litellm_params["input"])
else:
append_log(state, "Model Input", messages_to_str(model_input))
# if "input" in litellm_params:
append_log(state, "Model Input", litellm_params)
# else:
# append_log(state, "Model Input", messages_to_str(model_input))
background: Messages = [msg]
result = "" if msg["content"] is None else msg["content"]
append_log(state, "Model Output", result)
Expand Down
3 changes: 3 additions & 0 deletions src/pdl/pdl_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Message,
set_default_granite_model_parameters,
set_default_model_params,
set_structured_decoding_parameters,
)

# Load environment variables
Expand Down Expand Up @@ -155,6 +156,7 @@ def generate_text(
parameters = set_default_granite_model_parameters(
model_id, spec, parameters
)
parameters = set_structured_decoding_parameters(spec, parameters)
if parameters.get("mock_response") is not None:
litellm.suppress_debug_info = True
response = completion(
Expand All @@ -176,6 +178,7 @@ def generate_text_stream(
parameters = set_default_granite_model_parameters(
model_id, spec, parameters
)
parameters = set_structured_decoding_parameters(spec, parameters)
response = completion(
model=model_id,
messages=messages,
Expand Down