Skip to content

Commit 1fbfde6

Browse files
authored
feat: allow unnamed arguments when calling a PDL function in Jinja (#1076)
Signed-off-by: Louis Mandel <[email protected]>
1 parent 642ce44 commit 1fbfde6

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

src/pdl/pdl_interpreter.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,24 @@ class ClosureBlock(FunctionBlock):
193193
pdl__scope: SkipJsonSchema[Optional[ScopeType]] = Field(repr=False)
194194
pdl__state: SkipJsonSchema[InterpreterState] = Field(repr=False)
195195

196-
def __call__(self, **kwds):
196+
def __call__(self, *args, **kwargs):
197197
state = self.pdl__state.with_yield_result(False).with_yield_background(False)
198198
current_context = state.current_pdl_context.ref
199+
if len(args) > 0:
200+
keys = self.function.keys() if self.function is not None else {}
201+
if len(keys) < len(args):
202+
if self.signature is not None and self.signature.get("name", "") != "":
203+
err = f"Too many arguments to the call of {self.signature['name']}"
204+
else:
205+
err = "Too many arguments to the call"
206+
raise PDLRuntimeExpressionError(
207+
err,
208+
loc=self.pdl__location,
209+
trace=self.model_copy(),
210+
)
211+
kwargs = dict(zip(keys, args)) | kwargs
199212
result, _, _ = execute_call(
200-
state, current_context, self, kwds, empty_block_location
213+
state, current_context, self, kwargs, self.pdl__location
201214
)
202215
return result
203216

tests/test_function.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,19 @@ def test_call_from_code_01():
136136
x:
137137
y:
138138
return:
139-
${x + 1}
139+
${x + y}
140140
array:
141141
- call: ${f}
142142
args:
143143
x: 1
144-
y: 1
144+
y: 2
145145
- ${ f(x=1, y=2) }
146146
- lang: python
147147
code:
148148
result = f(x=1, y=2)
149149
"""
150150
result = exec_str(prog)
151-
assert result == [2, 2, 2]
151+
assert result == [3, 3, 3]
152152

153153

154154
def test_call_from_code_02():
@@ -234,3 +234,26 @@ def test_call_from_code_04():
234234
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.1.text.0'}]",
235235
"Bye[{'role': 'user', 'content': 'Hello', 'pdl__defsite': 'lastOf.0'},{'role': 'user', 'content': 'Bye', 'pdl__defsite': 'lastOf.1.array.2.text.0.code'}]",
236236
]
237+
238+
239+
def test_call_from_code_05():
240+
prog = """
241+
defs:
242+
f:
243+
function:
244+
x:
245+
y:
246+
return:
247+
${x - y}
248+
array:
249+
- call: ${f}
250+
args:
251+
x: 2
252+
y: 1
253+
- ${ f(2, 1) }
254+
- lang: python
255+
code:
256+
result = f(2, 1)
257+
"""
258+
result = exec_str(prog)
259+
assert result == [1, 1, 1]

tests/test_type_checking.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -605,23 +605,3 @@ def test_deprecated(capsys: pytest.CaptureFixture[str]):
605605
do_test_stderr(
606606
capsys, prog, ["Deprecated type syntax: use integer instead of int.", ""]
607607
)
608-
609-
610-
def test_function_call_jinja_19():
611-
prog = """
612-
defs:
613-
f:
614-
function:
615-
x:
616-
y:
617-
return:
618-
${x + 1}
619-
array:
620-
- call: ${f}
621-
args:
622-
x: 1
623-
y: 1
624-
- ${ f(1, 2) }
625-
"""
626-
with pytest.raises(PDLRuntimeError):
627-
exec_str(prog)

0 commit comments

Comments
 (0)