88import logging
99import sys
1010import time
11+ import traceback
1112from typing import Dict , List , Tuple , Any
1213
1314import numpy as np
@@ -49,6 +50,17 @@ def evaluate(program_path: str) -> Dict[str, float]:
4950 Returns:
5051 Dictionary of metric name to score
5152 """
53+ # First perform basic validation
54+ stage1_result = evaluate_stage1 (program_path )
55+
56+ # If validation fails, return early
57+ if stage1_result ["correctness" ] < 0.8 :
58+ return {
59+ "correctness" : stage1_result ["correctness" ],
60+ "rank_quality" : 0.0 ,
61+ "time_efficiency" : 0.0
62+ }
63+
5264 # Import the program
5365 try :
5466 spec = importlib .util .spec_from_file_location ("program_module" , program_path )
@@ -97,14 +109,42 @@ def evaluate(program_path: str) -> Dict[str, float]:
97109
98110def evaluate_stage1 (program_path : str ) -> Dict [str , float ]:
99111 """
100- First stage of evaluation: test correctness
112+ First stage of evaluation: basic validation and test correctness
101113
102114 Args:
103115 program_path: Path to the program file
104116
105117 Returns:
106118 Dictionary of metric name to score
107119 """
120+ # First, perform static code analysis and basic validation
121+ try :
122+ with open (program_path , 'r' ) as f :
123+ code_content = f .read ()
124+
125+ # Basic syntax check
126+ try :
127+ compile (code_content , program_path , 'exec' )
128+ except SyntaxError as e :
129+ logger .error (f"Syntax error in program: { str (e )} " )
130+ return {"correctness" : 0.0 }
131+
132+ # Check for common issues
133+ if "TensorDecomposition" not in code_content :
134+ logger .error ("Program does not contain 'TensorDecomposition' class" )
135+ return {"correctness" : 0.0 }
136+
137+ # Look for variable reference issues (e.g., 'u_factors' being used before definition)
138+ if "u_factors" in code_content and "_initialize_decomposition" in code_content :
139+ # Very basic check - not exhaustive but catches simple issues
140+ if code_content .find ("u_factors" ) < code_content .find ("_initialize_decomposition" ):
141+ if "def _initialize_decomposition" in code_content :
142+ logger .error ("Possible reference to 'u_factors' before initialization" )
143+ return {"correctness" : 0.0 }
144+ except Exception as e :
145+ logger .error (f"Error during static code validation: { str (e )} " )
146+ return {"correctness" : 0.0 }
147+
108148 # Import the program
109149 try :
110150 spec = importlib .util .spec_from_file_location ("program_module" , program_path )
@@ -113,20 +153,50 @@ def evaluate_stage1(program_path: str) -> Dict[str, float]:
113153
114154 module = importlib .util .module_from_spec (spec )
115155 sys .modules ["program_module" ] = module
116- spec .loader .exec_module (module )
117156
157+ # Use a safety wrapper to catch any import-time errors
158+ try :
159+ spec .loader .exec_module (module )
160+ except Exception as e :
161+ logger .error (f"Error during module execution: { str (e )} " )
162+ traceback .print_exc ()
163+ return {"correctness" : 0.0 }
164+
165+ # Check for the required class
118166 if not hasattr (module , "TensorDecomposition" ):
119- raise AttributeError (f"Program does not contain a 'TensorDecomposition' class" )
167+ logger .error ("Program does not contain a 'TensorDecomposition' class" )
168+ return {"correctness" : 0.0 }
120169
170+ # Check basic class structure
121171 TensorDecomposition = module .TensorDecomposition
172+ required_methods = ["__init__" , "optimize" , "_initialize_decomposition" ]
173+ for method in required_methods :
174+ if not hasattr (TensorDecomposition , method ):
175+ logger .error (f"TensorDecomposition class missing required method: { method } " )
176+ return {"correctness" : 0.0 }
177+
178+ # Try to instantiate the class with minimal parameters
179+ try :
180+ test_instance = TensorDecomposition (target_shape = (2 , 2 , 2 ), rank = 7 )
181+ except Exception as e :
182+ logger .error (f"Failed to instantiate TensorDecomposition: { str (e )} " )
183+ traceback .print_exc ()
184+ return {"correctness" : 0.0 }
185+
122186 except Exception as e :
123187 logger .error (f"Error importing program: { str (e )} " )
188+ traceback .print_exc ()
124189 return {"correctness" : 0.0 }
125190
126- # Test correctness
127- correctness_score = evaluate_correctness (TensorDecomposition )
128-
129- return {"correctness" : correctness_score }
191+ # If we get here, the basic validation passed
192+ # Now perform a simple correctness test
193+ try :
194+ correctness_score = evaluate_correctness (TensorDecomposition )
195+ return {"correctness" : correctness_score }
196+ except Exception as e :
197+ logger .error (f"Error in correctness evaluation: { str (e )} " )
198+ traceback .print_exc ()
199+ return {"correctness" : 0.0 }
130200
131201
132202def evaluate_correctness (TensorDecomposition ) -> float :
0 commit comments