Skip to content

Commit b5e4534

Browse files
authored
changes to default parameters, structured decoding support (#205)
* changes to default parameters, structured decoding support Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 6c8f509 commit b5e4534

File tree

3 files changed

+55
-37
lines changed

3 files changed

+55
-37
lines changed

src/pdl/pdl_ast.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -628,49 +628,64 @@ def set_default_model_params(
628628
return params
629629

630630

631-
def set_default_granite_model_parameters(
632-
model_id: str,
631+
def set_structured_decoding_parameters(
    spec: Any,
    parameters: Optional[dict[str, Any]],
) -> dict[str, Any]:
    """Add guided-decoding settings derived from ``spec`` when none are set.

    When ``spec`` is given and the caller has configured neither a
    ``response_format`` nor a ``guided_decoding_backend``, the JSON schema
    produced by ``pdltype_to_jsonschema(spec, True)`` is stored under
    ``guided_json`` and ``guided_decoding_backend`` is set to
    ``"lm-format-enforcer"``.

    Args:
        spec: Type specification for the expected model output, or ``None``
            when the output is unconstrained.
        parameters: Model parameters supplied by the caller. ``None`` is
            treated as an empty dict. A dict that is passed in is updated
            in place and returned.

    Returns:
        The parameter dict, with guided-decoding keys added when applicable.
    """
    if parameters is None:
        parameters = {}

    # Use .get() so that a missing "response_format" key is treated the same
    # as an explicit None; indexing would raise KeyError for a parameter dict
    # that simply never mentions the key (including the empty default above).
    if (
        spec is not None
        and parameters.get("response_format") is None
        and "guided_decoding_backend" not in parameters
    ):
        schema = pdltype_to_jsonschema(spec, True)
        parameters["guided_decoding_backend"] = "lm-format-enforcer"
        parameters["guided_json"] = schema
        # NOTE: a "response_format" of type "json_schema" built from the same
        # schema was left disabled by the author; it is not set here.
    return parameters
648+
649+
650+
def set_default_granite_model_parameters(
651+
model_id: str,
652+
spec: Any,
653+
parameters: Optional[dict[str, Any]],
654+
) -> dict[str, Any]:
655+
if parameters is None:
656+
parameters = {}
643657

644-
# if "decoding_method" not in parameters:
645-
# parameters["decoding_method"] = (
646-
# DECODING_METHOD # pylint: disable=attribute-defined-outside-init
647-
# )
648-
# if "max_tokens" in parameters and parameters["max_tokens"] is None:
649-
# parameters["max_tokens"] = (
650-
# MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
651-
# )
652-
# if "min_new_tokens" not in parameters:
653-
# parameters["min_new_tokens"] = (
654-
# MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
655-
# )
656-
# if "repetition_penalty" not in parameters:
657-
# parameters["repetition_penalty"] = (
658-
# REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
659-
# )
660-
# if parameters["decoding_method"] == "sample":
661-
# if "temperature" not in parameters:
662-
# parameters["temperature"] = (
663-
# TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
664-
# )
665-
# if "top_k" not in parameters:
666-
# parameters["top_k"] = (
667-
# TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
668-
# )
669-
# if "top_p" not in parameters:
670-
# parameters["top_p"] = (
671-
# TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
672-
# )
673-
if "granite-3.0" in model_id:
658+
if "watsonx" in model_id:
659+
if "decoding_method" not in parameters:
660+
parameters["decoding_method"] = (
661+
DECODING_METHOD # pylint: disable=attribute-defined-outside-init
662+
)
663+
if "max_tokens" in parameters and parameters["max_tokens"] is None:
664+
parameters["max_tokens"] = (
665+
MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
666+
)
667+
if "min_new_tokens" not in parameters:
668+
parameters["min_new_tokens"] = (
669+
MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
670+
)
671+
if "repetition_penalty" not in parameters:
672+
parameters["repetition_penalty"] = (
673+
REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
674+
)
675+
if parameters["decoding_method"] == "sample":
676+
if "temperature" not in parameters:
677+
parameters["temperature"] = (
678+
TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
679+
)
680+
if "top_k" not in parameters:
681+
parameters["top_k"] = (
682+
TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
683+
)
684+
if "top_p" not in parameters:
685+
parameters["top_p"] = (
686+
TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
687+
)
688+
if "replicate" in model_id and "granite-3.0" in model_id:
674689
if "temperature" not in parameters or parameters["temperature"] is None:
675690
parameters["temperature"] = 0 # setting to decoding greedy
676691
if "roles" not in parameters:

src/pdl/pdl_interpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,10 +1128,10 @@ def get_transformed_inputs(kwargs):
11281128
msg, raw_result = yield from generate_client_response(
11291129
state, concrete_block, model_input
11301130
)
1131-
if "input" in litellm_params:
1132-
append_log(state, "Model Input", litellm_params["input"])
1133-
else:
1134-
append_log(state, "Model Input", messages_to_str(model_input))
1131+
# if "input" in litellm_params:
1132+
append_log(state, "Model Input", litellm_params)
1133+
# else:
1134+
# append_log(state, "Model Input", messages_to_str(model_input))
11351135
background: Messages = [msg]
11361136
result = "" if msg["content"] is None else msg["content"]
11371137
append_log(state, "Model Output", result)

src/pdl/pdl_llms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Message,
1515
set_default_granite_model_parameters,
1616
set_default_model_params,
17+
set_structured_decoding_parameters,
1718
)
1819

1920
# Load environment variables
@@ -155,6 +156,7 @@ def generate_text(
155156
parameters = set_default_granite_model_parameters(
156157
model_id, spec, parameters
157158
)
159+
parameters = set_structured_decoding_parameters(spec, parameters)
158160
if parameters.get("mock_response") is not None:
159161
litellm.suppress_debug_info = True
160162
response = completion(
@@ -176,6 +178,7 @@ def generate_text_stream(
176178
parameters = set_default_granite_model_parameters(
177179
model_id, spec, parameters
178180
)
181+
parameters = set_structured_decoding_parameters(spec, parameters)
179182
response = completion(
180183
model=model_id,
181184
messages=messages,

0 commit comments

Comments
 (0)