1+ from collections .abc import Sequence
12from enum import StrEnum
2- from typing import Any
3+ from typing import Any , Callable
34
4- from .pdl_lazy import (
5- PdlDict ,
6- PdlList ,
7- )
5+ from .pdl_lazy import PdlApply , PdlDict , PdlLazy , PdlList
86
97
108class SerializeMode (StrEnum ):
119 LITELLM = "litellm"
1210 GRANITEIO = "graniteio"
1311
1412
15- class PDLContext :
13+ class PDLContext ( Sequence ) :
1614
1715 def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
1816 return []
1917
18+ def __add__ (self , value : "PDLContext" ):
19+ return IndependentContext ([self , value ])
2020
21- class BaseMessage ( PDLContext ):
22- message : PdlDict [ str , Any ]
21+ def __mul__ ( self , value : " PDLContext" ):
22+ return DependentContext ([ self , value ])
2323
24- def __init__ (self , message : dict [str , Any ]):
25- if "role" not in message :
26- assert False
27- if "content" not in message :
28- assert False
29- self .message = PdlDict (message )
24+ def __len__ (self ):
25+ return 0
26+
27+ def __getitem__ (self , index : int | slice ): # pyright: ignore
28+ return []
29+
30+
31+ class SingletonContext (PDLContext ):
32+ message : PdlLazy [dict [str , Any ]]
33+
34+ def __init__ (self , message : PdlLazy [dict [str , Any ]]):
35+ self .message = message
3036
3137 def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
3238 result = self .message .result ()
3339 return [result ]
3440
41+ def __len__ (self ): # pyright: ignore
42+ return 1
3543
36- class IndependentContext (PDLContext ):
37- context : PdlList [PDLContext ]
44+ def __getitem__ (self , index : int | slice ): # pyright: ignore
45+ if index == 0 :
46+ return self .message .result ()
47+ print (index )
48+ assert False
49+
50+ def __repr__ (self ): # pyright: ignore
51+ return str (self .message .result ())
3852
39- def __init__ (self , context : PdlList [PDLContext ]):
40- self .context = context
53+
54+ class IndependentContext (PDLContext ):
55+ context : PdlLazy [list [PDLContext ]]
56+
57+ def __init__ (self , context : list [PDLContext ]):
58+ ret : list [PDLContext ] = []
59+ for item in context :
60+ if isinstance (item , IndependentContext ):
61+ ret += item .context .data
62+ elif isinstance (item , SingletonContext ):
63+ ret += [item ]
64+ else :
65+ # Not all elements of the list are Independent, so return
66+ self .context = PdlList (context )
67+ return
68+ # All elements of the list are Independent
69+ self .context = PdlList (ret )
4170
4271 def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
4372 result = self .context .result ()
@@ -47,31 +76,74 @@ def serialize(self, mode: SerializeMode) -> list[dict[str, Any]]:
4776 return [{"independent" : flat }]
4877 return flat
4978
79+ def __len__ (self ): # pyright: ignore
80+ return len (self .context .result ())
81+
82+ def __getitem__ (self , index : int | slice ): # pyright: ignore
83+ return self .serialize (SerializeMode .LITELLM )[index ]
84+
85+ def __repr__ (self ): # pyright: ignore
86+ ret = "{"
87+ ret += "," .join ([i .__repr__ () for i in self .context .result ()])
88+ return ret + "}"
5089
51- class DependentContext (PDLContext ):
52- context : PdlList [PDLContext ]
5390
54- def __init__ (self , context : PdlList [PDLContext ]):
55- self .context = context
91+ class DependentContext (PDLContext ):
92+ context : PdlLazy [list [PDLContext ]]
93+
94+ def __init__ (self , context : list [PDLContext ]):
95+ ret : list [PDLContext ] = []
96+ for item in context :
97+ if isinstance (item , DependentContext ):
98+ ret += item .context .data
99+ elif isinstance (item , SingletonContext ):
100+ ret += [item ]
101+ else :
102+ # Not all elements of the list are Dependent, so return
103+ self .context = PdlList (context )
104+ return
105+ # All elements of the list are Dependent
106+ self .context = PdlList (ret )
56107
57108 def serialize (self , mode : SerializeMode ) -> list [dict [str , Any ]]:
58109 result = self .context .result ()
59110 contexts = [m .serialize (mode ) for m in result ]
60- return [x for xs in contexts for x in xs ]
111+ res = [x for xs in contexts for x in xs ]
112+ return res
113+
114+ def __len__ (self ): # pyright: ignore
115+ return len (self .context .result ())
116+
117+ def __getitem__ (self , index : int | slice ): # pyright: ignore
118+ return self .serialize (SerializeMode .LITELLM )[index ]
119+
120+ def __repr__ (self ): # pyright: ignore
121+ ret = "["
122+ ret += "," .join ([i .__repr__ () for i in self .context .result ()])
123+ return ret + "]"
61124
62125
63126def deserialize (
64127 context : list [dict [str , Any ]],
65128) -> DependentContext : # Only support dependent for now
66- ret : DependentContext = DependentContext (PdlList ([]) )
129+ ret : DependentContext = DependentContext ([] )
67130 for message in context :
68131 if isinstance (message , dict ):
69- if "role" not in message :
70- assert False
71- if "content" not in message :
72- assert False
73- ret = DependentContext (PdlList ([ret , BaseMessage (message )]))
132+ ret = ret * SingletonContext (PdlDict (message ))
74133 else :
75- ret = DependentContext (PdlList ([ret , message ]))
76-
134+ ret = ret * message
77135 return ret
136+
137+
138+ def add_done_callback (
139+ f : Callable , p : PDLContext
140+ ): # Assuming that f is the identity function
141+ match p :
142+ case SingletonContext (message = m ):
143+ p .message = PdlApply (f , m )
144+ case DependentContext (context = c ):
145+ p .context = PdlApply (f , c )
146+ case IndependentContext (context = c ):
147+ p .context = PdlApply (f , c )
148+ case _:
149+ assert False
0 commit comments