@@ -62,14 +62,10 @@ def _(synalinks):
6262 from synalinks import ops
6363
6464 class Thinking (synalinks .DataModel ):
65- thinking : str = synalinks .Field (
66- description = "Your step by step thinking process"
67- )
65+ thinking : str = synalinks .Field (description = "Your step by step thinking process" )
6866
6967 class CritiqueWithReward (synalinks .DataModel ):
70- critique : str = synalinks .Field (
71- description = "The step by step critique"
72- )
68+ critique : str = synalinks .Field (description = "The step by step critique" )
7369 reward : float = synalinks .Field (
7470 description = "The reward corresponding to the critique between [0.0, 1.0]" ,
7571 le = 1.0 ,
@@ -108,25 +104,23 @@ def __init__(
108104 self .stop_threshold = stop_threshold
109105 self .max_iterations = max_iterations
110106 self .critique_program = critique_program
111- self .prompt_template = prompt_template
107+ self .prompt_template = prompt_template
112108 self .examples = examples
113109 self .hints = hints
114110 self .use_inputs_schema = use_inputs_schema
115111 self .use_outputs_schema = use_outputs_schema
116112 if not self .critique_program :
117113 # If no critique program is provided
118114 # We compute the reward in the thinking step
119- thinking_data_model = \
120- Thinking \
121- + synalinks .SymbolicDataModel (
122- schema = self . schema
123- ) + CritiqueWithReward
115+ thinking_data_model = (
116+ Thinking
117+ + synalinks .SymbolicDataModel (schema = self . schema )
118+ + CritiqueWithReward
119+ )
124120 else :
125- thinking_data_model = \
126- Thinking \
127- + synalinks .SymbolicDataModel (
128- schema = self .schema
129- )
121+ thinking_data_model = Thinking + synalinks .SymbolicDataModel (
122+ schema = self .schema
123+ )
130124 # This is for generating the intermediary steps
131125 self .thinking = synalinks .Generator (
132126 data_model = thinking_data_model ,
@@ -136,7 +130,7 @@ def __init__(
136130 hints = self .hints ,
137131 use_inputs_schema = self .use_inputs_schema ,
138132 use_outputs_schema = self .use_outputs_schema ,
139- name = self .name + "_thinking_generator" ,
133+ name = self .name + "_thinking_generator" ,
140134 )
141135 # This is going to be the final generator
142136 self .generator = synalinks .Generator (
@@ -147,7 +141,7 @@ def __init__(
147141 hints = self .hints ,
148142 use_inputs_schema = self .use_inputs_schema ,
149143 use_outputs_schema = self .use_outputs_schema ,
150- name = self .name + "_generator" ,
144+ name = self .name + "_generator" ,
151145 )
152146
153147 async def call (self , inputs , training = False ):
@@ -167,9 +161,7 @@ async def call(self, inputs, training=False):
167161 if reward > self .stop_threshold :
168162 break
169163 inputs = await ops .concat (
170- inputs ,
171- thinking ,
172- name = self .name + f"_thinking_{ i } "
164+ inputs , thinking , name = self .name + f"_thinking_{ i } "
173165 )
174166 return await self .generator (inputs )
175167
@@ -190,7 +182,7 @@ def get_config(self):
190182 "name" : self .name ,
191183 "description" : self .description ,
192184 "trainable" : self .trainable ,
193- }
185+ }
194186 language_model_config = {
195187 "language_model" : synalinks .saving .serialize_synalinks_object (
196188 self .language_model ,
@@ -224,6 +216,7 @@ def from_config(cls, config):
224216 critique_program = critique_program ,
225217 ** config ,
226218 )
219+
227220 return BacktrackingOfThought , CritiqueWithReward , Thinking , ops
228221
229222
0 commit comments