Skip to content

Commit ae37f8b

Browse files
authored
support for structured decoding, removed default watsonx parameters (#168)
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent 004a9c1 commit ae37f8b

File tree

6 files changed

+77
-52
lines changed

6 files changed

+77
-52
lines changed

src/pdl/pdl_ast.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
)
1313
from pydantic import BaseModel, ConfigDict, Field, RootModel
1414

15+
from .pdl_schema_utils import pdltype_to_jsonschema
16+
1517
ScopeType: TypeAlias = dict[str, Any]
1618

1719
ExpressionType: TypeAlias = Any
@@ -617,44 +619,49 @@ def set_default_model_params(
617619

618620
def set_default_granite_model_parameters(
619621
model_id: str,
622+
spec: Any,
620623
parameters: Optional[dict[str, Any]],
621624
) -> dict[str, Any]:
622625
if parameters is None:
623626
parameters = {}
624627

625-
if "decoding_method" not in parameters:
626-
parameters["decoding_method"] = (
627-
DECODING_METHOD # pylint: disable=attribute-defined-outside-init
628-
)
629-
if "max_tokens" in parameters and parameters["max_tokens"] is None:
630-
parameters["max_tokens"] = (
631-
MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
632-
)
633-
if "min_new_tokens" not in parameters:
634-
parameters["min_new_tokens"] = (
635-
MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
636-
)
637-
if "repetition_penalty" not in parameters:
638-
parameters["repetition_penalty"] = (
639-
REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
640-
)
641-
if parameters["decoding_method"] == "sample":
642-
if "temperature" not in parameters:
643-
parameters["temperature"] = (
644-
TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
645-
)
646-
if "top_k" not in parameters:
647-
parameters["top_k"] = (
648-
TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
649-
)
650-
if "top_p" not in parameters:
651-
parameters["top_p"] = (
652-
TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
653-
)
628+
if spec is not None:
629+
schema = pdltype_to_jsonschema(spec, True)
630+
parameters["guided_decoding_backend"] = "lm-format-enforcer"
631+
parameters["guided_json"] = schema
632+
633+
# if "decoding_method" not in parameters:
634+
# parameters["decoding_method"] = (
635+
# DECODING_METHOD # pylint: disable=attribute-defined-outside-init
636+
# )
637+
# if "max_tokens" in parameters and parameters["max_tokens"] is None:
638+
# parameters["max_tokens"] = (
639+
# MAX_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
640+
# )
641+
# if "min_new_tokens" not in parameters:
642+
# parameters["min_new_tokens"] = (
643+
# MIN_NEW_TOKENS # pylint: disable=attribute-defined-outside-init
644+
# )
645+
# if "repetition_penalty" not in parameters:
646+
# parameters["repetition_penalty"] = (
647+
# REPETITION_PENATLY # pylint: disable=attribute-defined-outside-init
648+
# )
649+
# if parameters["decoding_method"] == "sample":
650+
# if "temperature" not in parameters:
651+
# parameters["temperature"] = (
652+
# TEMPERATURE_SAMPLING # pylint: disable=attribute-defined-outside-init
653+
# )
654+
# if "top_k" not in parameters:
655+
# parameters["top_k"] = (
656+
# TOP_K_SAMPLING # pylint: disable=attribute-defined-outside-init
657+
# )
658+
# if "top_p" not in parameters:
659+
# parameters["top_p"] = (
660+
# TOP_P_SAMPLING # pylint: disable=attribute-defined-outside-init
661+
# )
654662
if "granite-3.0" in model_id:
655663
if "temperature" not in parameters or parameters["temperature"] is None:
656664
parameters["temperature"] = 0 # setting to decoding greedy
657-
658665
if "roles" not in parameters:
659666
parameters["roles"] = {
660667
"system": {

src/pdl/pdl_interpreter.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,6 +1127,7 @@ def generate_client_response_streaming(
11271127
msg_stream = LitellmModel.generate_text_stream(
11281128
model_id=block.model,
11291129
messages=model_input,
1130+
spec=block.spec,
11301131
parameters=litellm_parameters_to_dict(block.parameters),
11311132
)
11321133
case _:
@@ -1186,6 +1187,7 @@ def generate_client_response_single(
11861187
msg, raw_result = LitellmModel.generate_text(
11871188
model_id=block.model,
11881189
messages=model_input,
1190+
spec=block.spec,
11891191
parameters=litellm_parameters_to_dict(block.parameters),
11901192
)
11911193
if state.yield_result:

src/pdl/pdl_llms.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,10 +148,13 @@ def get_model() -> None:
148148
def generate_text(
149149
model_id: str,
150150
messages: list[Message],
151+
spec: Any,
151152
parameters: dict[str, Any],
152153
) -> tuple[Message, Any]:
153154
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
154-
parameters = set_default_granite_model_parameters(model_id, parameters)
155+
parameters = set_default_granite_model_parameters(
156+
model_id, spec, parameters
157+
)
155158
if parameters.get("mock_response") is not None:
156159
litellm.suppress_debug_info = True
157160
response = completion(
@@ -169,10 +172,13 @@ def generate_text(
169172
def generate_text_stream(
170173
model_id: str,
171174
messages: list[Message],
175+
spec: Any,
172176
parameters: dict[str, Any],
173177
) -> Generator[Message, Any, Any]:
174178
if "granite" in model_id and "granite-20b-code-instruct-r1.1" not in model_id:
175-
parameters = set_default_granite_model_parameters(model_id, parameters)
179+
parameters = set_default_granite_model_parameters(
180+
model_id, spec, parameters
181+
)
176182
response = completion(
177183
model=model_id,
178184
messages=messages,

src/pdl/pdl_schema_utils.py

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def convert_to_json_type(a_type):
3030

3131

3232
def pdltype_to_jsonschema( # pylint: disable=too-many-return-statements
33-
pdl_type: str | dict[str, Any] | list
33+
pdl_type: str | dict[str, Any] | list, additional_properties: bool
3434
) -> dict[str, Any]:
3535
match pdl_type:
3636
case {"enum": choices}:
@@ -46,7 +46,10 @@ def pdltype_to_jsonschema( # pylint: disable=too-many-return-statements
4646
case {"int": dict() as details}:
4747
return {"type": "integer", **details}
4848
case {"list": str() as type_name}:
49-
return {"type": "array", "items": pdltype_to_jsonschema(type_name)}
49+
return {
50+
"type": "array",
51+
"items": pdltype_to_jsonschema(type_name, additional_properties),
52+
}
5053
case {"list": dict() as details}:
5154
ikws = ["enum", *_PDLTYPE_TO_JSONSCHEMA_NAME.keys()]
5255
items_details = {k: v for k, v in details.items() if k in ikws}
@@ -55,43 +58,50 @@ def pdltype_to_jsonschema( # pylint: disable=too-many-return-statements
5558
other_details = {k: v for k, v in details.items() if k not in ikws}
5659
return {
5760
"type": "array",
58-
"items": pdltype_to_jsonschema(items_details),
61+
"items": pdltype_to_jsonschema(items_details, additional_properties),
5962
**other_details,
6063
}
6164
case list() as type_list:
6265
if len(type_list) != 1:
6366
raise ValueError(f"invalid PDL type {pdl_type}")
6467
return {
6568
"type": "array",
66-
"items": pdltype_to_jsonschema(type_list[0]),
69+
"items": pdltype_to_jsonschema(type_list[0], additional_properties),
6770
}
6871
case {"obj": dict() as pdl_props}:
69-
return get_json_schema_object(pdl_props)
72+
return get_json_schema_object(pdl_props, additional_properties)
7073
case dict() as pdl_props:
71-
return get_json_schema_object(pdl_props)
74+
return get_json_schema_object(pdl_props, additional_properties)
7275
raise ValueError(f"invalid PDL type {pdl_type}")
7376

7477

75-
def get_json_schema_object(pdl_props: dict) -> dict[str, Any]:
78+
def get_json_schema_object(pdl_props: dict, additional_properties) -> dict[str, Any]:
7679
props = {}
7780
required = []
7881
for name, prop_type in pdl_props.items():
7982
if isinstance(prop_type, dict) and "optional" in prop_type:
80-
props[name] = pdltype_to_jsonschema(prop_type["optional"])
83+
props[name] = pdltype_to_jsonschema(
84+
prop_type["optional"], additional_properties
85+
)
8186
else:
82-
props[name] = pdltype_to_jsonschema(prop_type)
87+
props[name] = pdltype_to_jsonschema(prop_type, additional_properties)
8388
required.append(name)
84-
return {
85-
"type": "object",
86-
"properties": props,
87-
"required": required,
88-
"additionalProperties": False,
89-
}
89+
if additional_properties is False:
90+
return {
91+
"type": "object",
92+
"properties": props,
93+
"required": required,
94+
"additionalProperties": False,
95+
}
96+
97+
return {"type": "object", "properties": props, "required": required}
9098

9199

92-
def get_json_schema(params: dict[str, Any]) -> Optional[dict[str, Any]]:
100+
def get_json_schema(
101+
params: dict[str, Any], additional_properties
102+
) -> Optional[dict[str, Any]]:
93103
try:
94-
result = pdltype_to_jsonschema({"obj": params})
104+
result = pdltype_to_jsonschema({"obj": params}, additional_properties)
95105
return result
96106
except ValueError as e:
97107
warnings.warn(e.args[0])

src/pdl/pdl_schema_validator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@ def type_check_args(args: dict[str, Any], params: dict[str, Any], loc) -> list[s
2222
if "pdl_context" in args_copy:
2323
# params_copy["pdl_context"] = [{"role": "str?", "content": "str"}]
2424
params_copy["pdl_context"] = ["obj"]
25-
schema = get_json_schema(params_copy)
25+
schema = get_json_schema(params_copy, False)
2626
if schema is None:
2727
return ["Error obtaining a valid schema from function parameters definition"]
2828
return type_check(args_copy, schema, loc)
2929

3030

3131
def type_check_spec(result: Any, spec: str | dict[str, Any] | list, loc) -> list[str]:
32-
schema = pdltype_to_jsonschema(spec)
32+
schema = pdltype_to_jsonschema(spec, False)
3333
if schema is None:
3434
return ["Error obtaining a valid schema from spec"]
3535
return type_check(result, schema, loc)

tests/test_type_checking.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@
145145
def test_pdltype_to_jsonschema():
146146
for t in _PDLTYPE_TO_JSONSCHEMA_TESTS:
147147
pdl_type = yaml.safe_load(t["pdl_type"])
148-
json_schema = pdltype_to_jsonschema(pdl_type)
148+
json_schema = pdltype_to_jsonschema(pdl_type, False)
149149
assert json_schema == t["json_schema"]
150150

151151

0 commit comments

Comments
 (0)