Commit 5061f77

[BUG] Make sure torch is optional
1 parent 6b27207 commit 5061f77

File tree: 2 files changed (+15, -7 lines)

gempy/modules/optimize_nuggets/_ops.py

Lines changed: 7 additions & 6 deletions
```diff
@@ -1,10 +1,10 @@
-import torch
-
 import gempy_engine
+from gempy.optional_dependencies import require_torch
 from gempy_engine.core.data.continue_epoch import ContinueEpoch
 
 
 def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_cond_num):
+    torch = require_torch()
     opt = torch.optim.Adam(
         params=[
             nugget,
```
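
This hunk replaces the module-level `import torch` with a call-time lookup, so `_ops.py` imports cleanly even when torch is absent. A minimal, self-contained sketch of the pattern (the helper mirrors the one added to `gempy/optional_dependencies.py` below; `run_optimization_sketch` and its arguments are illustrative, not gempy's API):

```python
def require_torch():
    # Defer the torch import until a torch-dependent code path actually runs.
    try:
        import torch
    except ImportError:
        raise ImportError("The torch package is required to run this function.")
    return torch


def run_optimization_sketch(nugget, lr=0.01):
    torch = require_torch()  # resolved only when optimization is invoked
    return torch.optim.Adam(params=[nugget], lr=lr)
```
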
```diff
@@ -32,7 +32,7 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_c
         _apply_outlier_gradients(tensor=nugget, mask=mask_sp)
 
         # Step & clamp safely
-        opt.step()
+        opt.step()
         with torch.no_grad():
             nugget.clamp_(min=1e-7)
 
```
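
Context for this hunk: the optimizer step is followed by an in-place clamp that keeps the nugget strictly positive, and the clamp runs under `torch.no_grad()` so autograd does not track the mutation. A hedged sketch of the idiom, assuming torch is installed (the loss here is a stand-in; the real objective targets the kernel's condition number):

```python
import torch

nugget = torch.rand(10, requires_grad=True)
opt = torch.optim.Adam(params=[nugget], lr=0.01)

loss = (nugget ** 2).sum()   # stand-in objective for illustration only
loss.backward()
opt.step()
with torch.no_grad():        # keep the in-place mutation out of the graph
    nugget.clamp_(min=1e-7)  # nugget must stay strictly positive
```
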
```diff
@@ -47,14 +47,14 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_c
     return nugget
 
 
-def _mask_iqr(grads, multiplier: float = 1.5) -> torch.BoolTensor:
+def _mask_iqr(grads, multiplier: float = 1.5) -> "torch.BoolTensor":
     q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
     thresh = q3 + multiplier * (q3 - q1)
     return grads > thresh
 
 def _apply_outlier_gradients(
-        tensor: torch.Tensor,
-        mask: torch.BoolTensor,
+        tensor: "torch.Tensor",
+        mask: "torch.BoolTensor",
         amplification: float = 1.0,
 ):
     # wrap in no_grad if you prefer, but .grad modifications are fine
```
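
Quoting the annotations turns `torch.Tensor` and `torch.BoolTensor` into plain strings that Python never evaluates at import time, so these signatures no longer require torch to be importable. The mask itself is Tukey's upper fence, q3 + 1.5 * IQR; a small demonstration, assuming torch is installed:

```python
import torch

# Gradients above q3 + 1.5 * (q3 - q1) are flagged as outliers.
grads = torch.tensor([0.10, 0.12, 0.15, 0.20, 5.00])
q1, q3 = grads.quantile(0.25), grads.quantile(0.75)
mask = grads > q3 + 1.5 * (q3 - q1)
print(mask)  # tensor([False, False, False, False,  True])
```
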
```diff
@@ -65,6 +65,7 @@ def _apply_outlier_gradients(
 
 def _gradient_masking(nugget, focus=0.01):
     """Old way of avoiding exploding gradients."""
+    torch = require_torch()
     grads = nugget.grad.abs()
     k = int(grads.numel() * focus)
     top_vals, top_idx = torch.topk(grads, k, largest=True)
```
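
`_gradient_masking` gets the same treatment: it now fetches torch on entry. Repeated `require_torch()` calls stay cheap because Python caches imported modules in `sys.modules`. For reference, `torch.topk` returns the `k` largest values and their indices, which is how the `focus` fraction of gradients is selected; a short illustration, assuming torch is installed:

```python
import torch

grads = torch.tensor([0.30, 0.01, 0.90, 0.05, 0.70])
k = int(grads.numel() * 0.4)          # focus=0.4 on 5 gradients -> k=2
top_vals, top_idx = torch.topk(grads, k, largest=True)
print(top_vals)  # tensor([0.9000, 0.7000])
print(top_idx)   # tensor([2, 4])
```
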

gempy/optional_dependencies.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -65,4 +65,11 @@ def require_zlib():
         import zlib
     except ImportError:
         raise ImportError("The zlib package is required to run this function.")
-    return zlib
+    return zlib
+
+def require_torch():
+    try:
+        import torch
+    except ImportError:
+        raise ImportError("The torch package is required to run this function.")
+    return torch
```
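
The new helper follows the same try/except shape as the surrounding `require_*` functions. A usage sketch, assuming gempy and torch are installed (`needs_torch` is illustrative):

```python
from gempy.optional_dependencies import require_torch

def needs_torch():
    torch = require_torch()  # raises a descriptive ImportError if torch is missing
    return torch.zeros(3)

# Calling require_torch() in every entry point costs almost nothing: after the
# first successful import, `import torch` is a sys.modules cache hit.
```
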
