Skip to content

Commit 6bb96c4

Browse files
committed
changes to default parameters, structured decoding support
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 6c8f509 commit 6bb96c4

File tree

3 files changed

+51
-39
lines changed

3 files changed

+51
-39
lines changed

src/pdl/pdl_ast.py

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,19 @@ def set_default_model_params(
627627
)
628628
return params
629629

630+
def set_structured_decoding_parameters(
631+
spec: Any,
632+
parameters: Optional[dict[str, Any]],
633+
) -> dict[str, Any]:
634+
if parameters is None:
635+
parameters = {}
636+
637+
if spec is not None and parameters["response_format"] is None and "guided_decoding_backend" not in parameters :
638+
schema = pdltype_to_jsonschema(spec, True)
639+
#parameters["guided_decoding_backend"] = "lm-format-enforcer"
640+
#parameters["guided_json"] = schema
641+
parameters["response_format"] = { "type": "json_schema", "json_schema": schema , "strict": True }
642+
return parameters
630643

631644
def set_default_granite_model_parameters(
632645
model_id: str,
@@ -636,41 +649,37 @@ def set_default_granite_model_parameters(
636649
if parameters is None:
637650
parameters = {}
638651

639-
if spec is not None:
640-
schema = pdltype_to_jsonschema(spec, True)
641-
parameters["guided_decoding_backend"] = "lm-format-enforcer"
642-
parameters["guided_json"] = schema
643-
644-
# if "decoding_method" not in parameters:
645-
# parameters["decoding_method"] = (
646-
# DECODING_METHOD # pylint: disable=attribute-defined-outside-init
647-
# )
648-
# if "max_tokens" in parameters and parameters["max_tokens"] is None:
649-
# parameters["max_tokens"] = (
650-
# MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
651-
# )
652-
# if "min_new_tokens" not in parameters:
653-
# parameters["min_new_tokens"] = (
654-
# MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
655-
# )
656-
# if "repetition_penalty" not in parameters:
657-
# parameters["repetition_penalty"] = (
658-
# REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
659-
# )
660-
# if parameters["decoding_method"] == "sample":
661-
# if "temperature" not in parameters:
662-
# parameters["temperature"] = (
663-
# TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
664-
# )
665-
# if "top_k" not in parameters:
666-
# parameters["top_k"] = (
667-
# TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
668-
# )
669-
# if "top_p" not in parameters:
670-
# parameters["top_p"] = (
671-
# TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
672-
# )
673-
if "granite-3.0" in model_id:
652+
if "watsonx" in model_id:
653+
if "decoding_method" not in parameters:
654+
parameters["decoding_method"] = (
655+
DECODING_METHOD # pylint: disable=attribute-defined-outside-init
656+
)
657+
if "max_tokens" in parameters and parameters["max_tokens"] is None:
658+
parameters["max_tokens"] = (
659+
MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
660+
)
661+
if "min_new_tokens" not in parameters:
662+
parameters["min_new_tokens"] = (
663+
MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
664+
)
665+
if "repetition_penalty" not in parameters:
666+
parameters["repetition_penalty"] = (
667+
REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
668+
)
669+
if parameters["decoding_method"] == "sample":
670+
if "temperature" not in parameters:
671+
parameters["temperature"] = (
672+
TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
673+
)
674+
if "top_k" not in parameters:
675+
parameters["top_k"] = (
676+
TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
677+
)
678+
if "top_p" not in parameters:
679+
parameters["top_p"] = (
680+
TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
681+
)
682+
if "replicate" in model_id and "granite-3.0" in model_id:
674683
if "temperature" not in parameters or parameters["temperature"] is None:
675684
parameters["temperature"] = 0 # setting to decoding greedy
676685
if "roles" not in parameters:

src/pdl/pdl_interpreter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,10 +1128,10 @@ def get_transformed_inputs(kwargs):
11281128
msg, raw_result = yield from generate_client_response(
11291129
state, concrete_block, model_input
11301130
)
1131-
if "input" in litellm_params:
1132-
append_log(state, "Model Input", litellm_params["input"])
1133-
else:
1134-
append_log(state, "Model Input", messages_to_str(model_input))
1131+
#if "input" in litellm_params:
1132+
append_log(state, "Model Input", litellm_params)
1133+
#else:
1134+
# append_log(state, "Model Input", messages_to_str(model_input))
11351135
background: Messages = [msg]
11361136
result = "" if msg["content"] is None else msg["content"]
11371137
append_log(state, "Model Output", result)

src/pdl/pdl_llms.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Message,
1515
set_default_granite_model_parameters,
1616
set_default_model_params,
17+
set_structured_decoding_parameters,
1718
)
1819

1920
# Load environment variables
@@ -155,6 +156,7 @@ def generate_text(
155156
parameters = set_default_granite_model_parameters(
156157
model_id, spec, parameters
157158
)
159+
parameters = set_structured_decoding_parameters(spec, parameters)
158160
if parameters.get("mock_response") is not None:
159161
litellm.suppress_debug_info = True
160162
response = completion(
@@ -176,6 +178,7 @@ def generate_text_stream(
176178
parameters = set_default_granite_model_parameters(
177179
model_id, spec, parameters
178180
)
181+
parameters = set_structured_decoding_parameters(spec, parameters)
179182
response = completion(
180183
model=model_id,
181184
messages=messages,

0 commit comments

Comments
 (0)