@@ -89,9 +89,26 @@ And inside ``configuration_callback()`` implement the response to the configurat
import json
import os

from neo4j import GraphDatabase

from rdftool.ModelONNXCodebase import model
from rdftool.rdfCode import load_graph, get_models_for_problem, get_models_for_problem_and_tag
from rag.rag_backend import answer_question
96+
# Neo4j config/driver for local checks (used by _model_has_goal).
#
# Connection settings are read from the environment when present; the
# previous literal values are kept as fallbacks so local development
# keeps working unchanged.
# SECURITY NOTE(review): a real password was committed here as the
# fallback — rotate it and rely on the environment variable in any
# shared deployment.
NEO4J_URI = os.getenv("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.getenv("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD", "12345678")
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))
102+
def _model_has_goal(neo4j_driver, model_name: str, goal: str) -> bool:
    """Return True iff *model_name* is linked to *goal* in the Neo4j graph.

    Performs a case-insensitive match on the Problem name reached through
    the model's HAS_PROBLEM relationship, using a parameterized Cypher
    query (no string interpolation of user input).

    Args:
        neo4j_driver: an open neo4j ``Driver`` (anything exposing
            ``session()`` returning a context manager with ``run()``).
        model_name: the Model node's ``name`` property to check.
        goal: the Problem name the model must be linked to.
    """
    query = """
    MATCH (m:Model {name: $model})-[:HAS_PROBLEM]->(p:Problem)
    WHERE toLower(p.name) = toLower($goal)
    RETURN COUNT(*) AS cnt
    """
    with neo4j_driver.session() as session:
        record = session.run(query, model=model_name, goal=goal).single()
    # single() yields None when the query produced no record at all.
    if record is None:
        return False
    return record["cnt"] > 0
95112
96113 # Whether to go on spinning or interrupt
97114 running = False
@@ -160,17 +177,55 @@ And inside ``configuration_callback()`` implement the response to the configurat
160177
161178 problem_short_description = extra_data_dict[" problem_short_description" ]
162179
163- metadata = ml_model_metadata.ml_model_metadata()[0 ]
180+ goal = ml_model_metadata.ml_model_metadata()[0 ] # goal selected by metadata node
181+ print (f " Problem short description: { problem_short_description} " )
182+ print (f " Selected goal (metadata): { goal} " )
183+
184+ # Build strictly goal-scoped allowed list (names only)
185+ goal_models = get_models_for_problem(goal) # [(model_name, downloads), ...]
186+ allowed_names = [name for (name, _) in goal_models]
187+ print (f " [INFO] { len (allowed_names)} candidates for goal ' { goal} ' " )
188+ if not allowed_names:
189+ raise Exception (" No candidates in graph for the selected goal" )
164190
165- if chosen_model is None :
166- print (f " Problem short description: { problem_short_description} " )
191+ # Track models to avoid repeats across outputs
192+ restrained_models = []
193+ if extra_data_bytes:
194+ try :
195+ if " model_restrains" in extra_data_dict:
196+ restrained_models = list (set (extra_data_dict[" model_restrains" ]))
197+ except Exception :
198+ pass
167199
168- # Build the whitelist and force the RAG to pick ONLY from it
169- allowed = get_allowed_models_for_problem(metadata) # Metadata is the goal name
170- chosen_model = answer_question(
171- f " Task { metadata} with problem description: { problem_short_description} ? " ,
172- allowed_models = allowed
200+ # Try up to 10 candidates, skipping misfits transparently
201+ chosen_model = None
202+ for _ in range (10 ):
203+ remaining = [n for n in allowed_names if n not in restrained_models]
204+ if not remaining:
205+ break
206+
207+ candidate = answer_question(
208+ f " Task { goal} with problem description: { problem_short_description} ? " ,
209+ allowed_models = remaining
173210 )
211+
212+ if not candidate or candidate.strip().lower() == " none" :
213+ # mark and try again
214+ if candidate:
215+ restrained_models.append(candidate)
216+ continue
217+
218+ # Final safety: ensure candidate really belongs to goal
219+ if not _model_has_goal(neo4j_driver, candidate, goal):
220+ print (f " [GUARD] Dropping { candidate} : not linked to goal { goal} " )
221+ restrained_models.append(candidate)
222+ continue
223+
224+ chosen_model = candidate
225+ break
226+
227+ if not chosen_model:
228+ raise Exception (" No suitable model after screening candidates" )
174229 print (f " ML Model chosen: { chosen_model} " )
175230
176231 # Generate model code and keywords
@@ -183,11 +238,11 @@ And inside ``configuration_callback()`` implement the response to the configurat
183238 ml_model.extra_data(encoded_data)
184239
185240 except Exception as e:
186- print (f " Failed to determine ML model for task { ml_model_metadata.task_id()} : { e} . " )
187- ml_model.model(" Error " )
188- ml_model.model_path(" Error " )
189- error_message = " Failed to obtain ML model for task: " + str (e)
190- error_info = {" error" : error_message}
241+ print (f " [WARN] No suitable model found for task { ml_model_metadata.task_id()} : { e} " )
242+ ml_model.model(" NO_MODEL " )
243+ ml_model.model_path(" N/A " )
244+ error_message = " No suitable model found for the given problem. "
245+ error_info = {" error_code " : " NO_MODEL " , " error" : error_message}
191246 encoded_error = json.dumps(error_info).encode(" utf-8" )
192247 ml_model.extra_data(encoded_error)
193248
0 commit comments