Skip to content

Commit 6863617

Browse files
committed
cleanup
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent becac0e commit 6863617

File tree

3 files changed

+101
-60
lines changed

3 files changed

+101
-60
lines changed

src/pdl/pdl_context.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ class SerializeMode(StrEnum):
1212
GRANITEIO = "graniteio"
1313

1414

15-
class PDLContext():
15+
class PDLContext:
1616

1717
def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
1818
return []
@@ -60,7 +60,9 @@ def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
6060
return [x for xs in contexts for x in xs]
6161

6262

63-
def deserialize(context: list[dict[str, Any]]) -> DependentContext: # Only support dependent for now
63+
def deserialize(
64+
context: list[dict[str, Any]],
65+
) -> DependentContext: # Only support dependent for now
6466
ret: DependentContext = DependentContext(PdlList([]))
6567
for message in context:
6668
if isinstance(message, dict):

src/pdl/pdl_interpreter.py

Lines changed: 74 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -260,16 +260,21 @@ def process_block(
260260
trace=ErrorBlock(msg=exc.message, pdl__location=loc, program=block),
261261
) from exc
262262
result = PdlConst(v)
263-
background = DependentContext(PdlList([BaseMessage(
264-
{
265-
"role": state.role,
266-
"content": result,
267-
"defsite": ".".join(
268-
state.id_stack
269-
), # Warning: defsite for a literal value
270-
}
271-
)]
272-
))
263+
background = DependentContext(
264+
PdlList(
265+
[
266+
BaseMessage(
267+
{
268+
"role": state.role,
269+
"content": result,
270+
"defsite": ".".join(
271+
state.id_stack
272+
), # Warning: defsite for a literal value
273+
}
274+
)
275+
]
276+
)
277+
)
273278
trace = DataBlock(
274279
data=expr,
275280
pdl__result=result,
@@ -368,9 +373,13 @@ def process_advanced_block(
368373
state, scope, block, loc
369374
)
370375
result = lazy_apply(id_with_set_first_use_nanos(block.pdl__timing), result)
371-
background = DependentContext(PdlList(lazy_apply(
372-
id_with_set_first_use_nanos(block.pdl__timing), background.context
373-
)))
376+
background = DependentContext(
377+
PdlList(
378+
lazy_apply(
379+
id_with_set_first_use_nanos(block.pdl__timing), background.context
380+
)
381+
)
382+
)
374383
trace = trace.model_copy(update={"pdl__result": result})
375384
if block.parser is not None:
376385
parser = block.parser
@@ -547,7 +556,9 @@ def process_block_body(
547556
block.kind,
548557
append(obj_loc, k),
549558
)
550-
background = DependentContext(PdlList([background, value_background]))
559+
background = DependentContext(
560+
PdlList([background, value_background])
561+
)
551562
if (
552563
block.context is IndependentEnum.INDEPENDENT
553564
): # reset pdl_context
@@ -768,9 +779,9 @@ def process_block_body(
768779
]
769780
)
770781
scope = scope | {
771-
"pdl_context": DependentContext(PdlList([
772-
pdl_context_init, background
773-
]))
782+
"pdl_context": DependentContext(
783+
PdlList([pdl_context_init, background])
784+
)
774785
}
775786
if items is not None:
776787
for k in items.keys():
@@ -786,9 +797,9 @@ def process_block_body(
786797
block.repeat,
787798
repeat_loc,
788799
)
789-
saved_background = DependentContext(PdlList([
790-
saved_background, iteration_background
791-
]))
800+
saved_background = DependentContext(
801+
PdlList([saved_background, iteration_background])
802+
)
792803
if block.context is IndependentEnum.DEPENDENT:
793804
background = saved_background
794805
results.append(iteration_result)
@@ -1017,7 +1028,9 @@ def process_blocks( # pylint: disable=too-many-arguments,too-many-positional-ar
10171028
for i, block in enumerate(blocks):
10181029
iteration_state = iteration_state.with_iter(i)
10191030
scope = scope | {
1020-
"pdl_context": DependentContext(PdlList([pdl_context_init, background]))
1031+
"pdl_context": DependentContext(
1032+
PdlList([pdl_context_init, background])
1033+
)
10211034
}
10221035
new_loc = append(loc, "[" + str(i) + "]")
10231036
if iteration_type == IterationType.LASTOF and state.yield_result:
@@ -1029,9 +1042,9 @@ def process_blocks( # pylint: disable=too-many-arguments,too-many-positional-ar
10291042
t,
10301043
) = process_block(iteration_state, scope, block, new_loc)
10311044
results.append(iteration_result)
1032-
saved_background = DependentContext(PdlList([
1033-
saved_background, iteration_background
1034-
]))
1045+
saved_background = DependentContext(
1046+
PdlList([saved_background, iteration_background])
1047+
)
10351048
if context == IndependentEnum.DEPENDENT:
10361049
background = saved_background
10371050
trace.append(t) # type: ignore
@@ -1186,9 +1199,7 @@ def process_expr( # pylint: disable=too-many-return-statements
11861199
pdl__expr=expr, pdl__result=result, pdl__location=loc
11871200
)
11881201
if "pdl_context" in str(expr): # need to deserialize pdl_context
1189-
scope = scope | {
1190-
"pdl_context": saved_context
1191-
}
1202+
scope = scope | {"pdl_context": saved_context}
11921203
return (result, trace)
11931204

11941205

@@ -1614,12 +1625,17 @@ def process_call_code(
16141625
case "command":
16151626
try:
16161627
result = call_command(code_s, code_a)
1617-
background = DependentContext(PdlList([BaseMessage(
1618-
{
1619-
"role": state.role,
1620-
"content": result,
1621-
"defsite": block.pdl__id,
1622-
})]
1628+
background = DependentContext(
1629+
PdlList(
1630+
[
1631+
BaseMessage(
1632+
{
1633+
"role": state.role,
1634+
"content": result,
1635+
"defsite": block.pdl__id,
1636+
}
1637+
)
1638+
]
16231639
)
16241640
)
16251641
except Exception as exc:
@@ -1631,13 +1647,17 @@ def process_call_code(
16311647
case "jinja":
16321648
try:
16331649
result = call_jinja(code_s, scope)
1634-
background = DependentContext(PdlList([BaseMessage(
1635-
{
1636-
"role": state.role,
1637-
"content": result,
1638-
"defsite": block.pdl__id,
1639-
}
1640-
)]
1650+
background = DependentContext(
1651+
PdlList(
1652+
[
1653+
BaseMessage(
1654+
{
1655+
"role": state.role,
1656+
"content": result,
1657+
"defsite": block.pdl__id,
1658+
}
1659+
)
1660+
]
16411661
)
16421662
)
16431663
except Exception as exc:
@@ -1649,9 +1669,15 @@ def process_call_code(
16491669
case "pdl":
16501670
try:
16511671
result = call_pdl(code_s, scope)
1652-
background = DependentContext(PdlList([BaseMessage(
1653-
{"role": state.role, "content": result, "defsite": block.pdl__id} # type: ignore
1654-
)]))
1672+
background = DependentContext(
1673+
PdlList(
1674+
[
1675+
BaseMessage(
1676+
{"role": state.role, "content": result, "defsite": block.pdl__id} # type: ignore
1677+
)
1678+
]
1679+
)
1680+
)
16551681
except Exception as exc:
16561682
raise PDLRuntimeError(
16571683
f"PDL Code error: {repr(exc)}",
@@ -1830,8 +1856,11 @@ def process_input(
18301856
contents.append(line + "\n")
18311857
s = "".join(contents)
18321858
trace = block.model_copy(update={"pdl__result": s})
1833-
background: LazyMessages = DependentContext(PdlList(
1834-
[BaseMessage({"role": state.role, "content": s, "defsite": block.pdl__id})])) # type: ignore
1859+
background: LazyMessages = DependentContext(
1860+
PdlList(
1861+
[BaseMessage({"role": state.role, "content": s, "defsite": block.pdl__id})]
1862+
)
1863+
) # type: ignore
18351864
return PdlConst(s), background, scope, trace
18361865

18371866

tests/test_context.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,30 @@
2424

2525

2626
def test_p():
27-
assert p.serialize(SerializeMode.LITELLM) == [{"role": "user", "content": "hello"},
28-
{"role": "user", "content": "bye"},
29-
{"role": "user", "content": "hello1"},
30-
{"role": "user", "content": "bye1"}]
27+
assert p.serialize(SerializeMode.LITELLM) == [
28+
{"role": "user", "content": "hello"},
29+
{"role": "user", "content": "bye"},
30+
{"role": "user", "content": "hello1"},
31+
{"role": "user", "content": "bye1"},
32+
]
3133
assert p.serialize(SerializeMode.LITELLM) == p.serialize(SerializeMode.GRANITEIO)
3234

3335

3436
def test_p1():
35-
assert p1.serialize(SerializeMode.LITELLM) == [{"role": "user", "content": "hello"},
36-
{"role": "user", "content": "bye"},
37-
{"role": "user", "content": "hello2"},
38-
{"role": "user", "content": "bye2"}]
39-
40-
assert p1.serialize(SerializeMode.GRANITEIO) == [{"role": "user", "content": "hello"},
41-
{"role": "user", "content": "bye"},
42-
{"independent": [{"role": "user", "content": "hello2"},
43-
{"role": "user", "content": "bye2"}]}]
37+
assert p1.serialize(SerializeMode.LITELLM) == [
38+
{"role": "user", "content": "hello"},
39+
{"role": "user", "content": "bye"},
40+
{"role": "user", "content": "hello2"},
41+
{"role": "user", "content": "bye2"},
42+
]
43+
44+
assert p1.serialize(SerializeMode.GRANITEIO) == [
45+
{"role": "user", "content": "hello"},
46+
{"role": "user", "content": "bye"},
47+
{
48+
"independent": [
49+
{"role": "user", "content": "hello2"},
50+
{"role": "user", "content": "bye2"},
51+
]
52+
},
53+
]

0 commit comments

Comments
 (0)