Skip to content

Commit fc3415f

Browse files
committed
adding load_state_dict() for surrogate
1 parent 4d11d46 commit fc3415f

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

deeprootgen/calibration/model_versioning.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,11 @@ def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
418418
theta (torch.Tensor):
419419
The batch of parameter vectors.
420420
x (torch.Tensor):
421-
The batch tensor.
421+
The batch tensor data.
422422
423423
Returns:
424424
torch.Tensor:
425-
The graph embedding.
425+
The normalising flow.
426426
"""
427427
x = self.encode(x)
428428
x = self.npe(theta, x)
@@ -479,6 +479,7 @@ def load_data(k: str) -> Any:
479479
self.model = SingleTaskVariationalGPModel(inducing_points).double()
480480
self.likelihood = gpytorch.likelihoods.GaussianLikelihood().double()
481481

482+
self.model.load_state_dict(self.state_dict)
482483
self.model.eval()
483484
self.X_scaler = load_data("X_scaler")
484485
self.Y_scaler = load_data("Y_scaler")
@@ -525,12 +526,13 @@ def predict(
525526
lower, upper = predictions.confidence_region()
526527
lower, upper = lower.detach().cpu().numpy(), upper.detach().cpu().numpy()
527528

528-
mean = self.Y_scaler.inverse_transform(mean.reshape(-1, 1)).flatten()
529-
lower = self.Y_scaler.inverse_transform(lower.reshape(-1, 1)).flatten()
530-
upper = self.Y_scaler.inverse_transform(upper.reshape(-1, 1)).flatten()
529+
if context.model_config["surrogate_type"] == "cost_emulator":
530+
mean = self.Y_scaler.inverse_transform(mean.reshape(-1, 1)).flatten()
531+
lower = self.Y_scaler.inverse_transform(lower.reshape(-1, 1)).flatten()
532+
upper = self.Y_scaler.inverse_transform(upper.reshape(-1, 1)).flatten()
531533

532-
df = pd.DataFrame(
533-
{"discrepancy": mean, "lower_bound": lower, "upper_bound": upper}
534-
)
534+
df = pd.DataFrame(
535+
{"discrepancy": mean, "lower_bound": lower, "upper_bound": upper}
536+
)
535537

536-
return df
538+
return df

0 commit comments

Comments
 (0)