
Commit 70dab4d

use decision tree to rule out invalid parameter combinations
1 parent 6842749 commit 70dab4d

8 files changed (+191 -27 lines)

src/surfaces/_surrogates/README.md

Lines changed: 32 additions & 0 deletions
@@ -223,6 +223,38 @@ Each ONNX model has an accompanying `.meta.json` file:
 }
 ```
 
+## Validity Model
+
+Some hyperparameter combinations are invalid (e.g., `n_neighbors > dataset_size` in KNN). The real function returns `NaN` for these cases.
+
+The surrogate system handles this by training a **validity classifier** alongside the regression model:
+
+1. During training, both valid and invalid samples are collected
+2. A binary classifier learns to predict validity
+3. During inference, validity is checked first
+4. Invalid combinations return `NaN`, just like the real function
+
+```python
+# Surrogate correctly returns NaN for invalid combinations
+func = KNeighborsClassifierFunction(use_surrogate=True)
+
+# Valid: returns score
+result = func({'n_neighbors': 5, 'cv': 5, 'dataset': digits_data, ...})
+# 0.9560
+
+# Invalid: returns NaN (n_neighbors too large for dataset)
+result = func({'n_neighbors': 140, 'cv': 5, 'dataset': iris_data, ...})
+# nan
+```
+
+Files for a function with a validity model:
+```
+models/
+├── k_neighbors_classifier.onnx            # Regression model
+├── k_neighbors_classifier.validity.onnx   # Validity classifier
+└── k_neighbors_classifier.onnx.meta.json  # Metadata (has_validity_model: true)
+```
+
 ## Limitations
 
 1. **Interpolation only**: Surrogates work best within the training search space
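The commit title's choice of a decision tree is a natural fit here: constraints like `n_neighbors > dataset_size` are axis-aligned thresholds, which are exactly the decision boundaries a tree expresses natively. A toy illustration of that idea (synthetic data and feature layout assumed, not the project's real parameter encoding):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text

# Toy constraint: a sample is valid iff n_neighbors <= dataset_size
rng = np.random.default_rng(42)
n_neighbors = rng.integers(1, 200, size=1000)
dataset_size = rng.choice([150, 1797], size=1000)  # iris-sized vs digits-sized
X = np.column_stack([n_neighbors, dataset_size]).astype(np.float32)
y = (n_neighbors <= dataset_size).astype(int)  # 1 = valid, 0 = invalid

tree = DecisionTreeClassifier(max_depth=10, min_samples_leaf=5, random_state=42)
tree.fit(X, y)
print(export_text(tree, feature_names=["n_neighbors", "dataset_size"]))
# The learned splits land near n_neighbors = 150, recovering the constraint
```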

src/surfaces/_surrogates/_surrogate_loader.py

Lines changed: 61 additions & 2 deletions
@@ -44,13 +44,15 @@ def __init__(
         self.metadata_path = metadata_path or self.model_path.with_suffix(
             self.model_path.suffix + ".meta.json"
         )
+        self.validity_model_path = self.model_path.with_suffix(".validity.onnx")
 
         self._session = None
+        self._validity_session = None
         self._metadata = None
 
     @property
     def session(self):
-        """Lazy-load ONNX runtime session."""
+        """Lazy-load ONNX runtime session for regression model."""
         if self._session is None:
             try:
                 import onnxruntime as ort
@@ -66,6 +68,35 @@ def session(self):
                 )
         return self._session
 
+    @property
+    def validity_session(self):
+        """Lazy-load ONNX runtime session for validity model."""
+        if self._validity_session is None:
+            if not self.has_validity_model:
+                return None
+
+            try:
+                import onnxruntime as ort
+            except ImportError:
+                raise ImportError(
+                    "onnxruntime is required for surrogate models. "
+                    "Install it with: pip install onnxruntime"
+                )
+
+            self._validity_session = ort.InferenceSession(
+                str(self.validity_model_path),
+                providers=["CPUExecutionProvider"],
+            )
+        return self._validity_session
+
+    @property
+    def has_validity_model(self) -> bool:
+        """Check if a validity model exists."""
+        return (
+            self.metadata.get("has_validity_model", False)
+            and self.validity_model_path.exists()
+        )
+
     @property
     def metadata(self) -> Dict[str, Any]:
         """Load metadata from JSON file."""
@@ -114,6 +145,30 @@ def _encode_params(self, params: Dict[str, Any]) -> np.ndarray:
 
         return np.array([values], dtype=np.float32)
 
+    def is_valid(self, params: Dict[str, Any]) -> bool:
+        """Check if parameter combination is valid.
+
+        Parameters
+        ----------
+        params : dict
+            Parameter dictionary.
+
+        Returns
+        -------
+        bool
+            True if valid, False if invalid (would return NaN).
+        """
+        if not self.has_validity_model:
+            return True  # No validity model, assume all valid
+
+        input_array = self._encode_params(params)
+        input_name = self.validity_session.get_inputs()[0].name
+        output = self.validity_session.run(None, {input_name: input_array})
+
+        # Output is class label (0=invalid, 1=valid)
+        predicted_class = int(output[0][0])
+        return predicted_class == 1
+
     def predict(self, params: Dict[str, Any]) -> float:
         """Run inference on the surrogate model.
 
@@ -125,8 +180,12 @@ def predict(self, params: Dict[str, Any]) -> float:
         Returns
         -------
         float
-            Predicted objective value.
+            Predicted objective value, or NaN if parameters are invalid.
         """
+        # Check validity first
+        if not self.is_valid(params):
+            return float("nan")
+
         input_array = self._encode_params(params)
         input_name = self.session.get_inputs()[0].name
         output = self.session.run(None, {input_name: input_array})
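Taken together, `is_valid` and `predict` make inference a two-stage ONNX pass: the validity classifier runs first, and the regressor only runs for combinations predicted valid. A standalone sketch of that flow using `onnxruntime` directly (file names assumed from the README's `models/` layout, input already encoded to the model's float32 feature layout):

```python
import numpy as np
import onnxruntime as ort

reg_sess = ort.InferenceSession(
    "models/k_neighbors_classifier.onnx", providers=["CPUExecutionProvider"]
)
val_sess = ort.InferenceSession(
    "models/k_neighbors_classifier.validity.onnx", providers=["CPUExecutionProvider"]
)

def predict(x_row: np.ndarray) -> float:
    """x_row: encoded parameters, shape (1, n_features), dtype float32."""
    # Stage 1: validity classifier emits a class label (0=invalid, 1=valid)
    name = val_sess.get_inputs()[0].name
    label = int(val_sess.run(None, {name: x_row})[0][0])
    if label == 0:
        return float("nan")  # mirror the real function's behavior
    # Stage 2: regression model predicts the objective value
    name = reg_sess.get_inputs()[0].name
    return float(reg_sess.run(None, {name: x_row})[0][0])
```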

src/surfaces/_surrogates/_surrogate_trainer.py

Lines changed: 87 additions & 18 deletions
@@ -53,9 +53,12 @@ def __init__(
 
         self.X: Optional[np.ndarray] = None
         self.y: Optional[np.ndarray] = None
+        self.X_all: Optional[np.ndarray] = None  # All samples (valid + invalid)
+        self.y_valid: Optional[np.ndarray] = None  # Validity labels (0/1)
         self.param_names: List[str] = []
         self.param_encodings: Dict[str, Dict[str, int]] = {}
         self.model = None
+        self.validity_model = None
 
         self._training_time: float = 0
         self._collection_time: float = 0
@@ -145,8 +148,10 @@ def collect_samples_grid(
             grid_points = [grid_points[i] for i in indices]
 
         n_samples = len(grid_points)
-        X_list = []
+        X_valid_list = []
         y_list = []
+        X_all_list = []
+        validity_list = []
 
         if verbose:
             print(f"Collecting {n_samples} samples...")
@@ -169,29 +174,42 @@ def collect_samples_grid(
             # Evaluate function (use pure_objective_function to get raw value)
             try:
                 score = self.function.pure_objective_function(params)
-                # Skip NaN values (can happen with invalid hyperparameter combos)
+
+                # Track all samples for validity model
+                X_all_list.append(x_row)
+
                 if np.isnan(score):
-                    continue
-                X_list.append(x_row)
-                y_list.append(score)
+                    # Invalid combination
+                    validity_list.append(0)
+                else:
+                    # Valid combination
+                    validity_list.append(1)
+                    X_valid_list.append(x_row)
+                    y_list.append(score)
 
                 if verbose and (i + 1) % 100 == 0:
                     print(f" Collected {len(y_list)}/{n_samples} valid samples")
             except Exception as e:
+                # Treat exceptions as invalid
+                X_all_list.append(x_row)
+                validity_list.append(0)
                 if verbose:
                     print(f" Error at sample {i}: {e}")
 
-        self.X = np.array(X_list, dtype=np.float32)
+        self.X = np.array(X_valid_list, dtype=np.float32)
         self.y = np.array(y_list, dtype=np.float32)
+        self.X_all = np.array(X_all_list, dtype=np.float32)
+        self.y_valid = np.array(validity_list, dtype=np.int32)
 
         self._collection_time = time.time() - start_time
 
+        n_valid = len(self.y)
+        n_invalid = len(self.y_valid) - n_valid
+
         if verbose:
-            n_valid = len(self.y)
-            n_skipped = n_samples - n_valid
             print(f"Collected {n_valid} valid samples in {self._collection_time:.1f}s")
-            if n_skipped > 0:
-                print(f" Skipped {n_skipped} samples (NaN or errors)")
+            if n_invalid > 0:
+                print(f" Invalid samples: {n_invalid} (will train validity model)")
             if n_valid > 0:
                 print(f" y range: [{self.y.min():.4f}, {self.y.max():.4f}]")
@@ -205,6 +223,8 @@ def train(
     ):
         """Train an MLP regressor on collected samples.
 
+        Also trains a validity classifier if invalid samples were found.
+
         Parameters
         ----------
         hidden_layer_sizes : tuple
@@ -225,11 +245,13 @@ def train(
 
         start_time = time.time()
 
-        # Normalize inputs
+        # Normalize inputs for regression model
        self.scaler_X = StandardScaler()
         X_scaled = self.scaler_X.fit_transform(self.X)
 
-        # Train MLP
+        # Train regression MLP
+        if verbose:
+            print("Training regression model...")
         self.model = MLPRegressor(
             hidden_layer_sizes=hidden_layer_sizes,
             max_iter=max_iter,
@@ -240,17 +262,43 @@ def train(
         )
         self.model.fit(X_scaled, self.y)
 
-        self._training_time = time.time() - start_time
-
-        # Evaluate on training data
+        # Evaluate regression on training data
         y_pred = self.model.predict(X_scaled)
         mse = np.mean((self.y - y_pred) ** 2)
         r2 = 1 - mse / np.var(self.y)
 
+        # Train validity classifier if there are invalid samples
+        n_invalid = np.sum(self.y_valid == 0)
+        if n_invalid > 0:
+            if verbose:
+                print("\nTraining validity classifier (DecisionTree)...")
+
+            from sklearn.tree import DecisionTreeClassifier
+
+            # Decision tree doesn't need scaling, but we keep scaler for API consistency
+            self.scaler_X_validity = None
+
+            self.validity_model = DecisionTreeClassifier(
+                max_depth=10,
+                min_samples_leaf=5,
+                random_state=42,
+            )
+            self.validity_model.fit(self.X_all, self.y_valid)
+
+            # Evaluate validity classifier
+            validity_pred = self.validity_model.predict(self.X_all)
+            validity_acc = np.mean(validity_pred == self.y_valid)
+
+            if verbose:
+                print(f" Validity classifier accuracy: {validity_acc:.4f}")
+                print(f" Tree depth: {self.validity_model.get_depth()}")
+
+        self._training_time = time.time() - start_time
+
         if verbose:
             print(f"\nTraining completed in {self._training_time:.1f}s")
-            print(f" MSE: {mse:.6f}")
-            print(f" R2: {r2:.4f}")
+            print(f" Regression MSE: {mse:.6f}")
+            print(f" Regression R2: {r2:.4f}")
 
     def export(
         self,
@@ -294,7 +342,7 @@ def export(
             ("mlp", self.model),
         ])
 
-        # Convert to ONNX
+        # Convert regression model to ONNX
         n_features = self.X.shape[1]
         initial_type = [("input", FloatTensorType([None, n_features]))]
         onnx_model = convert_sklearn(pipeline, initial_types=initial_type)
@@ -303,12 +351,33 @@
         with open(output_path, "wb") as f:
             f.write(onnx_model.SerializeToString())
 
+        # Export validity model if it exists
+        has_validity_model = self.validity_model is not None
+        if has_validity_model:
+            validity_path = output_path.with_suffix(".validity.onnx")
+
+            # DecisionTree doesn't need a scaler pipeline
+            onnx_validity = convert_sklearn(
+                self.validity_model,
+                initial_types=initial_type,
+                options={id(self.validity_model): {"zipmap": False}},
+            )
+
+            with open(validity_path, "wb") as f:
+                f.write(onnx_validity.SerializeToString())
+
+            if verbose:
+                print(f"Exported validity model to: {validity_path}")
+
         # Save metadata
+        n_invalid = int(np.sum(self.y_valid == 0))
         metadata = {
             "function_name": getattr(self.function, "_name_", self.function.__class__.__name__),
             "param_names": self.param_names,
             "param_encodings": self.param_encodings,
             "n_samples": len(self.y),
+            "n_invalid_samples": n_invalid,
+            "has_validity_model": has_validity_model,
             "y_range": [float(self.y.min()), float(self.y.max())],
             "training_time": self._training_time,
             "collection_time": self._collection_time,
Binary file not shown (0 bytes changed).

src/surfaces/_surrogates/models/gradient_boosting_regressor.onnx.meta.json

Lines changed: 6 additions & 4 deletions
@@ -12,10 +12,12 @@
     }
   },
   "n_samples": 1000,
+  "n_invalid_samples": 0,
+  "has_validity_model": false,
   "y_range": [
-    -0.18430274724960327,
-    0.4645007848739624
+    -0.188394695520401,
+    0.46403050422668457
   ],
-  "training_time": 0.9513494968414307,
-  "collection_time": 571.8018696308136
+  "training_time": 0.9096114635467529,
+  "collection_time": 581.1047446727753
 }
Binary file not shown (0 bytes changed).

src/surfaces/_surrogates/models/k_neighbors_classifier.onnx.meta.json

Lines changed: 5 additions & 3 deletions
@@ -19,11 +19,13 @@
       "iris_data": 2
     }
   },
-  "n_samples": 866,
+  "n_samples": 868,
+  "n_invalid_samples": 132,
+  "has_validity_model": true,
   "y_range": [
     0.39886364340782166,
     0.9802631735801697
   ],
-  "training_time": 1.0907559394836426,
-  "collection_time": 40.40663194656372
+  "training_time": 0.5762979984283447,
+  "collection_time": 41.40630507469177
 }
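The `has_validity_model` flag in this metadata is what drives loading: the loader opens the companion `.validity.onnx` only when the flag is set and the file actually exists. A minimal sketch of that check (path assumed from the README's `models/` layout):

```python
import json
from pathlib import Path

meta_path = Path("models/k_neighbors_classifier.onnx.meta.json")
meta = json.loads(meta_path.read_text())

# Same rule as the loader's has_validity_model property:
# metadata flag AND companion file present on disk
validity_path = meta_path.with_name("k_neighbors_classifier.validity.onnx")
if meta.get("has_validity_model", False) and validity_path.exists():
    print(f"validity model trained on {meta['n_invalid_samples']} invalid samples")
```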
Binary file not shown.
