Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
427 changes: 378 additions & 49 deletions pdl-live-react/src/pdl_ast.d.ts

Large diffs are not rendered by default.

229 changes: 43 additions & 186 deletions src/pdl/pdl-schema.json

Large diffs are not rendered by default.

16 changes: 11 additions & 5 deletions src/pdl/pdl_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
BeforeValidator,
ConfigDict,
Field,
Json,
RootModel,
TypeAdapter,
)
Expand Down Expand Up @@ -236,15 +237,16 @@ class JsonSchemaTypePdlType(PdlType):


class ObjPdlType(PdlType):
"""Optional type."""
"""Object type."""

obj: Optional[dict[str, "PdlTypeType"]]


PdlTypeType = TypeAliasType(
"PdlTypeType",
Annotated[
"Union[BasePdlType," # pyright: ignore
"Union[None," # pyright: ignore
" BasePdlType,"
" EnumPdlType,"
" StrPdlType,"
" FloatPdlType,"
Expand All @@ -269,7 +271,7 @@ class Parser(BaseModel):
description: Optional[str] = None
"""Documentation associated to the parser.
"""
spec: Optional[PdlTypeType] = None
spec: PdlTypeType = None
"""Expected type of the parsed value.
"""

Expand Down Expand Up @@ -348,7 +350,7 @@ class Block(BaseModel):
description: Optional[str] = None
"""Documentation associated to the block.
"""
spec: Optional[PdlTypeType] = None
spec: PdlTypeType = None
"""Type specification of the result of the block.
"""
defs: dict[str, "BlockType"] = {}
Expand Down Expand Up @@ -416,8 +418,12 @@ class FunctionBlock(LeafBlock):
"""Functions parameters with their types.
"""
returns: "BlockType" = Field(..., alias="return")
"""Body of the function
"""Body of the function.
"""
signature: Optional[Json] = None
"""Function signature computed from the function definition.
"""

# Field for internal use
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(default=None, repr=False)

Expand Down
4 changes: 3 additions & 1 deletion src/pdl/pdl_dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,8 +317,10 @@ def expr_to_dict(expr: ExpressionType, json_compatible: bool):


def type_to_dict(t: PdlTypeType):
d: str | list | dict
d: None | str | list | dict
match t:
case None:
d = None
case "null" | "bool" | "str" | "float" | "int" | "list" | "obj":
d = t
case EnumPdlType():
Expand Down
13 changes: 13 additions & 0 deletions src/pdl/pdl_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
from .pdl_parser import PDLParseError, parse_file, parse_str # noqa: E402
from .pdl_python_repl import PythonREPL # noqa: E402
from .pdl_scheduler import yield_background, yield_result # noqa: E402
from .pdl_schema_utils import get_json_schema # noqa: E402
from .pdl_schema_validator import type_check_args, type_check_spec # noqa: E402
from .pdl_utils import ( # noqa: E402
GeneratorWrapper,
Expand Down Expand Up @@ -894,6 +895,16 @@ def process_block_body(
if block.def_ is not None:
scope = scope | {block.def_: closure}
closure.pdl__scope = scope
signature: dict[str, Any] = {"type": "function"}
if block.def_ is not None:
signature["name"] = block.def_
if block.description is not None:
signature["description"] = block.description
if block.function is not None:
signature["parameters"] = get_json_schema(block.function, False) or {}
else:
signature["parameters"] = {}
closure.signature = signature
result = PdlConst(closure)
background = PdlList([])
trace = closure.model_copy(update={})
Expand Down Expand Up @@ -976,6 +987,8 @@ def process_defs(
state = state.with_iter(idx)
state = state.with_yield_result(False)
state = state.with_yield_background(False)
if isinstance(block, FunctionBlock) and block.def_ is None:
block = block.model_copy(update={"def_": x})
result, _, _, block_trace = process_block(state, scope, block, newloc)
scope = scope | PdlDict({x: result})
defs_trace[x] = block_trace
Expand Down
2 changes: 1 addition & 1 deletion src/pdl/pdl_llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def generate_text_stream(


def set_structured_decoding_parameters(
spec: Optional[PdlTypeType],
spec: PdlTypeType,
parameters: Optional[dict[str, Any]],
) -> dict[str, Any]:
if parameters is None:
Expand Down
2 changes: 1 addition & 1 deletion src/pdl/pdl_schema_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def convert_to_json_type(a_type):


def pdltype_to_jsonschema(
pdl_type: Optional[PdlTypeType], additional_properties: bool
pdl_type: PdlTypeType, additional_properties: bool
) -> dict[str, Any]:
schema: dict[str, Any]
match pdl_type:
Expand Down
6 changes: 4 additions & 2 deletions src/pdl/pdl_schema_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@


def type_check_args(
args: Optional[dict[str, Any]], params: Optional[dict[str, Any]], loc
args: Optional[dict[str, Any]],
params: Optional[dict[str, PdlTypeType]],
loc,
) -> list[str]:
if (args == {} or args is None) and (params is None or params == {}):
return []
Expand All @@ -35,7 +37,7 @@ def type_check_args(
return type_check(args_copy, schema, loc)


def type_check_spec(result: Any, spec: Optional[PdlTypeType], loc) -> list[str]:
def type_check_spec(result: Any, spec: PdlTypeType, loc) -> list[str]:
schema = pdltype_to_jsonschema(spec, False)
if schema is None:
return ["Error obtaining a valid schema from spec"]
Expand Down
27 changes: 27 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,17 @@ def test_function_call():
assert text == "Hello world!"


def test_hello_signature():
result = exec_dict(hello_def, output="all")
closure = result["scope"]["hello"]
assert closure.signature == {
"name": hello_def["def"],
"description": hello_def["description"],
"type": "function",
"parameters": {},
}


hello_params = {
"description": "Call hello",
"text": [
Expand All @@ -39,6 +50,22 @@ def test_function_params():
assert text == "Hello World!"


def test_hello_params_signature():
result = exec_dict(hello_params, output="all")
closure = result["scope"]["hello"]
assert closure.signature == {
"name": hello_params["text"][0]["def"],
"description": hello_params["text"][0]["description"],
"type": "function",
"parameters": {
"type": "object",
"properties": {"name": {"type": "string"}},
"required": ["name"],
"additionalProperties": False,
},
}


hello_stutter = {
"description": "Repeat the context",
"text": [
Expand Down