@@ -78,25 +78,25 @@ And inside ``configuration_callback()`` implement the response to the configuration request
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ SustainML ML Model Provider Node Implementation."""
-
+
 from sustainml_py.nodes.MLModelNode import MLModelNode
-
+
 # Manage signaling
 import os
 import signal
 import threading
 import time
 import json
-
+
 from rdftool.ModelONNXCodebase import model
 from rdftool.rdfCode import load_graph, get_models_for_problem, get_models_for_problem_and_tag
-
-from rag.rag_backend import answer_question
-
+
+from rag.rag_backend import answer_question, get_allowed_models_for_problem
+
 # Whether to go on spinning or interrupt
 running = False
-
-
+
+
 # Load the list of unsupported models
 def load_unsupported_models(file_path):
     try:
@@ -105,19 +105,19 @@ And inside ``configuration_callback()`` implement the response to the configuration request
     except Exception as e:
         print(f"[WARN] Could not load unsupported list: {e}")
         return []
-
-
+
+
 unsupported_models = load_unsupported_models(os.path.dirname(__file__) + "/unsupported_models.txt")
-
-
+
+
 # Signal handler
 def signal_handler(sig, frame):
     print("\nExiting")
     MLModelNode.terminate()
     global running
     running = False
-
-
+
+
 # User Callback implementation
 # Inputs: ml_model_metadata, app_requirements, hw_constraints, ml_model_baseline, hw_baseline, carbonfootprint_baseline
 # Outputs: node_status, ml_model
@@ -129,11 +129,11 @@ And inside ``configuration_callback()`` implement the response to the configuration request
                   carbonfootprint_baseline,
                   node_status,
                   ml_model):
-
+
     # Callback implementation here
-
+
     print(f"Received Task: {ml_model_metadata.task_id().problem_id()}, {ml_model_metadata.task_id().iteration_id()}")
-
+
     try:
         chosen_model = None
         # Model restriction after various outputs
@@ -147,31 +147,32 @@ And inside ``configuration_callback()`` implement the response to the configuration request
         except json.JSONDecodeError:
             print("[WARN] In model_provider node extra_data JSON is not valid.")
             extra_data_dict = {}
-
+
         if "type" in extra_data_dict:
             type = extra_data_dict["type"]
-
+
         if "model_restrains" in extra_data_dict:
             restrained_models = extra_data_dict["model_restrains"]
-
+
         if "model_selected" in extra_data_dict:
             chosen_model = extra_data_dict["model_selected"]
             print("Model already selected: ", chosen_model)
-
+
         problem_short_description = extra_data_dict["problem_short_description"]
-
+
         metadata = ml_model_metadata.ml_model_metadata()[0]
-
+
         if chosen_model is None:
             print(f"Problem short description: {problem_short_description}")
-
-            # Choose model with the RAG based on the goal selected and the knowledge of the graph.
+
+            # Build the whitelist and force the RAG to pick ONLY from it
+            allowed = get_allowed_models_for_problem(metadata)  # Metadata is the goal name
             chosen_model = answer_question(
-                f"Task {metadata} with problem description: {problem_short_description}?"
-            )
-
+                f"Task {metadata} with problem description: {problem_short_description}?",
+                allowed_models=allowed
+            )
         print(f"ML Model chosen: {chosen_model}")
-
+
         # Generate model code and keywords
         onnx_path = model(chosen_model)  # TODO - Further development needed
         ml_model.model(chosen_model)
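
The hunk above is the heart of this change: instead of letting the RAG name any model, the node now builds a whitelist from the knowledge graph and passes it to `answer_question`. The real implementations live in `rag.rag_backend`, so what follows is only a minimal sketch of the pattern under stated assumptions: the graph-lookup stub, the `llm` parameter, and the clamp-to-first-entry fallback are illustrative, not the project's code.

```python
# Sketch only: constrain a free-text RAG answer to a graph-derived whitelist.
def get_allowed_models_for_problem(goal, graph_lookup):
    # graph_lookup stands in for the rdftool knowledge-graph query
    return sorted({str(row[0]) for row in graph_lookup(goal)})

def answer_question(question, allowed_models=None, llm=lambda q: ""):
    candidate = llm(question)
    if allowed_models and candidate not in allowed_models:
        # Assumed fallback: clamp an off-list answer back onto the whitelist
        candidate = allowed_models[0]
    return candidate

if __name__ == "__main__":
    lookup = lambda goal: [("ResNet50",), ("MobileNetV2",)]  # fake graph rows
    allowed = get_allowed_models_for_problem("image_classification", lookup)
    # The stub LLM answers off-list, so the sketch clamps to "MobileNetV2"
    print(answer_question("Task image_classification?", allowed_models=allowed,
                          llm=lambda q: "SomeOffListModel"))
```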
@@ -180,7 +181,7 @@ And inside ``configuration_callback()`` implement the response to the configuration request
         extra_data = {"unsupported_models": unsupported_models}
         encoded_data = json.dumps(extra_data).encode("utf-8")
         ml_model.extra_data(encoded_data)
-
+
     except Exception as e:
         print(f"Failed to determine ML model for task {ml_model_metadata.task_id()}: {e}.")
         ml_model.model("Error")
@@ -189,18 +190,18 @@ And inside ``configuration_callback()`` implement the response to the configuration request
         error_info = {"error": error_message}
         encoded_error = json.dumps(error_info).encode("utf-8")
         ml_model.extra_data(encoded_error)
-
-
+
+
 # User Configuration Callback implementation
 # Inputs: req
 # Outputs: res
 def configuration_callback(req, res):
-
+
     # Callback for configuration implementation here
     if 'model_from_goal' in req.configuration():
         res.node_id(req.node_id())
         res.transaction_id(req.transaction_id())
-
+
         try:
             text = req.configuration()[len("model_from_goal, "):]
             parts = text.split(',')
@@ -211,24 +212,24 @@ And inside ``configuration_callback()`` implement the response to the configuration request
             else:
                 goal = text.strip()
                 models = get_models_for_problem(goal)
-
+
             sorted_models = ', '.join(sorted([str(m[0]) for m in models]))
-
+
             if not sorted_models:
                 res.success(False)
                 res.err_code(1)  # 0: No error || 1: Error
             else:
                 res.success(True)
                 res.err_code(0)  # 0: No error || 1: Error
-
+
             print(f"Models for {goal}: {sorted_models}")  # debug
             res.configuration(json.dumps(dict(models=sorted_models)))
-
+
         except Exception as e:
             print(f"Error getting model from goal from request: {e}")
             res.success(False)
             res.err_code(1)
-
+
     else:
         res.node_id(req.node_id())
         res.transaction_id(req.transaction_id())
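
Note the reply shape in this branch: `res.configuration(json.dumps(dict(models=sorted_models)))` sends one JSON object whose `models` value is a single comma-separated string, not a JSON array. A hedged client-side sketch of decoding such a reply (the model names are invented):

```python
import json

raw = '{"models": "MobileNetV2, ResNet50"}'  # shape produced by the callback above
models = [m.strip() for m in json.loads(raw)["models"].split(",") if m.strip()]
print(models)  # ['MobileNetV2', 'ResNet50']
```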
@@ -237,8 +238,8 @@ And inside ``configuration_callback()`` implement the response to the configuration request
         res.success(False)
         res.err_code(1)  # 0: No error || 1: Error
         print(error_msg)
-
-
+
+
 # Main workflow routine
 def run():
     start_time = time.time()
@@ -255,19 +256,19 @@ And inside ``configuration_callback()`` implement the response to the configuration request
     global running
     running = True
     node.spin()
-
-
+
+
 # Call main in program execution
 if __name__ == '__main__':
     signal.signal(signal.SIGINT, signal_handler)
-
+
     """ Python does not process signals async if
     the main thread is blocked (spin()) so, run
     user workflow in another thread """
     runner = threading.Thread(target=run)
     runner.start()
-
+
     while running:
         time.sleep(1)
-
+
     runner.join()
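
For reference, `task_callback` reads its input `extra_data` as UTF-8 encoded JSON, probing for the keys `type`, `model_restrains`, `model_selected`, and `problem_short_description`, and publishes `{"unsupported_models": [...]}` (or `{"error": ...}`) the same way. A small round-trip sketch of that wire format; every value below is invented for illustration:

```python
import json

# Request-side payload (key names from task_callback; values are made up)
incoming = json.dumps({
    "type": "classification",
    "model_restrains": ["VGG16"],
    "problem_short_description": "Classify rooftop solar panels in aerial images",
}).encode("utf-8")

# What the callback effectively does on receipt
extra_data_dict = json.loads(incoming.decode("utf-8"))
print(extra_data_dict.get("model_selected"))  # None: no model preselected
print(extra_data_dict["problem_short_description"])

# Reply-side payload, mirroring ml_model.extra_data(encoded_data)
outgoing = json.dumps({"unsupported_models": ["SomeHugeModel"]}).encode("utf-8")
```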