@@ -53,9 +53,12 @@ def __init__(
5353
5454 self .X : Optional [np .ndarray ] = None
5555 self .y : Optional [np .ndarray ] = None
56+ self .X_all : Optional [np .ndarray ] = None # All samples (valid + invalid)
57+ self .y_valid : Optional [np .ndarray ] = None # Validity labels (0/1)
5658 self .param_names : List [str ] = []
5759 self .param_encodings : Dict [str , Dict [str , int ]] = {}
5860 self .model = None
61+ self .validity_model = None
5962
6063 self ._training_time : float = 0
6164 self ._collection_time : float = 0
@@ -145,8 +148,10 @@ def collect_samples_grid(
145148 grid_points = [grid_points [i ] for i in indices ]
146149
147150 n_samples = len (grid_points )
148- X_list = []
151+ X_valid_list = []
149152 y_list = []
153+ X_all_list = []
154+ validity_list = []
150155
151156 if verbose :
152157 print (f"Collecting { n_samples } samples..." )
@@ -169,29 +174,42 @@ def collect_samples_grid(
169174 # Evaluate function (use pure_objective_function to get raw value)
170175 try :
171176 score = self .function .pure_objective_function (params )
172- # Skip NaN values (can happen with invalid hyperparameter combos)
177+
178+ # Track all samples for validity model
179+ X_all_list .append (x_row )
180+
173181 if np .isnan (score ):
174- continue
175- X_list .append (x_row )
176- y_list .append (score )
182+ # Invalid combination
183+ validity_list .append (0 )
184+ else :
185+ # Valid combination
186+ validity_list .append (1 )
187+ X_valid_list .append (x_row )
188+ y_list .append (score )
177189
178190 if verbose and (i + 1 ) % 100 == 0 :
179191 print (f" Collected { len (y_list )} /{ n_samples } valid samples" )
180192 except Exception as e :
193+ # Treat exceptions as invalid
194+ X_all_list .append (x_row )
195+ validity_list .append (0 )
181196 if verbose :
182197 print (f" Error at sample { i } : { e } " )
183198
184- self .X = np .array (X_list , dtype = np .float32 )
199+ self .X = np .array (X_valid_list , dtype = np .float32 )
185200 self .y = np .array (y_list , dtype = np .float32 )
201+ self .X_all = np .array (X_all_list , dtype = np .float32 )
202+ self .y_valid = np .array (validity_list , dtype = np .int32 )
186203
187204 self ._collection_time = time .time () - start_time
188205
206+ n_valid = len (self .y )
207+ n_invalid = len (self .y_valid ) - n_valid
208+
189209 if verbose :
190- n_valid = len (self .y )
191- n_skipped = n_samples - n_valid
192210 print (f"Collected { n_valid } valid samples in { self ._collection_time :.1f} s" )
193- if n_skipped > 0 :
194- print (f" Skipped { n_skipped } samples (NaN or errors )" )
211+ if n_invalid > 0 :
212+ print (f" Invalid samples: { n_invalid } (will train validity model )" )
195213 if n_valid > 0 :
196214 print (f" y range: [{ self .y .min ():.4f} , { self .y .max ():.4f} ]" )
197215
@@ -205,6 +223,8 @@ def train(
205223 ):
206224 """Train an MLP regressor on collected samples.
207225
226+ Also trains a validity classifier if invalid samples were found.
227+
208228 Parameters
209229 ----------
210230 hidden_layer_sizes : tuple
@@ -225,11 +245,13 @@ def train(
225245
226246 start_time = time .time ()
227247
228- # Normalize inputs
248+ # Normalize inputs for regression model
229249 self .scaler_X = StandardScaler ()
230250 X_scaled = self .scaler_X .fit_transform (self .X )
231251
232- # Train MLP
252+ # Train regression MLP
253+ if verbose :
254+ print ("Training regression model..." )
233255 self .model = MLPRegressor (
234256 hidden_layer_sizes = hidden_layer_sizes ,
235257 max_iter = max_iter ,
@@ -240,17 +262,43 @@ def train(
240262 )
241263 self .model .fit (X_scaled , self .y )
242264
243- self ._training_time = time .time () - start_time
244-
245- # Evaluate on training data
265+ # Evaluate regression on training data
246266 y_pred = self .model .predict (X_scaled )
247267 mse = np .mean ((self .y - y_pred ) ** 2 )
248268 r2 = 1 - mse / np .var (self .y )
249269
270+ # Train validity classifier if there are invalid samples
271+ n_invalid = np .sum (self .y_valid == 0 )
272+ if n_invalid > 0 :
273+ if verbose :
274+ print ("\n Training validity classifier (DecisionTree)..." )
275+
276+ from sklearn .tree import DecisionTreeClassifier
277+
278+ # Decision tree doesn't need scaling, but we keep scaler for API consistency
279+ self .scaler_X_validity = None
280+
281+ self .validity_model = DecisionTreeClassifier (
282+ max_depth = 10 ,
283+ min_samples_leaf = 5 ,
284+ random_state = 42 ,
285+ )
286+ self .validity_model .fit (self .X_all , self .y_valid )
287+
288+ # Evaluate validity classifier
289+ validity_pred = self .validity_model .predict (self .X_all )
290+ validity_acc = np .mean (validity_pred == self .y_valid )
291+
292+ if verbose :
293+ print (f" Validity classifier accuracy: { validity_acc :.4f} " )
294+ print (f" Tree depth: { self .validity_model .get_depth ()} " )
295+
296+ self ._training_time = time .time () - start_time
297+
250298 if verbose :
251299 print (f"\n Training completed in { self ._training_time :.1f} s" )
252- print (f" MSE: { mse :.6f} " )
253- print (f" R2: { r2 :.4f} " )
300+ print (f" Regression MSE: { mse :.6f} " )
301+ print (f" Regression R2: { r2 :.4f} " )
254302
255303 def export (
256304 self ,
@@ -294,7 +342,7 @@ def export(
294342 ("mlp" , self .model ),
295343 ])
296344
297- # Convert to ONNX
345+ # Convert regression model to ONNX
298346 n_features = self .X .shape [1 ]
299347 initial_type = [("input" , FloatTensorType ([None , n_features ]))]
300348 onnx_model = convert_sklearn (pipeline , initial_types = initial_type )
@@ -303,12 +351,33 @@ def export(
303351 with open (output_path , "wb" ) as f :
304352 f .write (onnx_model .SerializeToString ())
305353
354+ # Export validity model if it exists
355+ has_validity_model = self .validity_model is not None
356+ if has_validity_model :
357+ validity_path = output_path .with_suffix (".validity.onnx" )
358+
359+ # DecisionTree doesn't need a scaler pipeline
360+ onnx_validity = convert_sklearn (
361+ self .validity_model ,
362+ initial_types = initial_type ,
363+ options = {id (self .validity_model ): {"zipmap" : False }},
364+ )
365+
366+ with open (validity_path , "wb" ) as f :
367+ f .write (onnx_validity .SerializeToString ())
368+
369+ if verbose :
370+ print (f"Exported validity model to: { validity_path } " )
371+
306372 # Save metadata
373+ n_invalid = int (np .sum (self .y_valid == 0 ))
307374 metadata = {
308375 "function_name" : getattr (self .function , "_name_" , self .function .__class__ .__name__ ),
309376 "param_names" : self .param_names ,
310377 "param_encodings" : self .param_encodings ,
311378 "n_samples" : len (self .y ),
379+ "n_invalid_samples" : n_invalid ,
380+ "has_validity_model" : has_validity_model ,
312381 "y_range" : [float (self .y .min ()), float (self .y .max ())],
313382 "training_time" : self ._training_time ,
314383 "collection_time" : self ._collection_time ,
0 commit comments