Skip to content

Commit 4c555f6

Browse files
committed
[ENH] Implement .env configuration setup for engine
The commit introduces a .env configuration setup system for GemPy engine and cleans up compute_API.py. Default settings for the application are now sourced from a .env file, enhancing customizable functionality while maintaining the application's default behavior. The .env.example file is added for sample settings. In addition, the requirement, python-dotenv, has been added to requirements.txt. Minor refactoring and formatting changes have been applied to compute_API.py and gempy_engine_config.py for clean and efficient code.
1 parent 4efbdac commit 4c555f6

File tree

2 files changed

+9
-12
lines changed

2 files changed

+9
-12
lines changed

gempy/API/compute_API.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,7 @@ def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConf
2727
Returns:
2828
Solutions: The computed geological model.
2929
"""
30-
engine_config = engine_config or GemPyEngineConfig(
31-
backend=AvailableBackends.numpy,
32-
use_gpu=False,
33-
)
30+
engine_config = engine_config or GemPyEngineConfig(use_gpu=False)
3431

3532
match engine_config.backend:
3633
case AvailableBackends.numpy | AvailableBackends.PYTORCH:
@@ -43,7 +40,7 @@ def compute_model(gempy_model: GeoModel, engine_config: Optional[GemPyEngineConf
4340

4441
# TODO: To decide what to do with this.
4542
interpolation_input = gempy_model.interpolation_input_copy
46-
gempy_model.taped_interpolation_input = interpolation_input # * This is used for gradient tape
43+
gempy_model.taped_interpolation_input = interpolation_input # * This is used for gradient tape
4744

4845
gempy_model.solutions = gempy_engine.compute_model(
4946
interpolation_input=interpolation_input,
@@ -84,8 +81,7 @@ def compute_model_at(gempy_model: GeoModel, at: np.ndarray,
8481
return sol.raw_arrays.custom
8582

8683

87-
88-
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
84+
def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig, max_epochs: int = 10,
8985
convergence_criteria: float = 1e5):
9086
if engine_config.backend != AvailableBackends.PYTORCH:
9187
raise ValueError(f'Only PyTorch backend is supported for optimization. Received {engine_config.backend}')
@@ -112,15 +108,15 @@ def optimize_and_compute(geo_model: GeoModel, engine_config: GemPyEngineConfig,
112108

113109
# Optimization loop
114110
geo_model.interpolation_options.kernel_options.optimizing_condition_number = True
115-
111+
116112
def _check_convergence_criterion(conditional_number: float, condition_number_old: float, conditional_number_target: float = 1e5):
117113
reached_conditional_target = conditional_number < conditional_number_target
118114
if reached_conditional_target == False and epoch > 10:
119115
condition_number_change = torch.abs(conditional_number - condition_number_old) / condition_number_old
120116
if condition_number_change < 0.01:
121117
reached_conditional_target = True
122118
return reached_conditional_target
123-
119+
124120
previous_condition_number = 0
125121
for epoch in range(max_epochs):
126122
optimizer.zero_grad()
@@ -149,11 +145,11 @@ def _check_convergence_criterion(conditional_number: float, condition_number_old
149145
mask = torch.ones_like(nugget_effect_scalar.grad)
150146
mask[indices] = 0
151147
nugget_effect_scalar.grad *= mask
152-
148+
153149
# Update the vector
154150
optimizer.step()
155151
nugget_effect_scalar.data = nugget_effect_scalar.data.clamp_(min=1e-7) # Replace negative values with 0
156-
152+
157153
# optimizer.zero_grad()
158154
# Monitor progress
159155
if epoch % 1 == 0:
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from dataclasses import dataclass
22
from typing import Optional
33

4+
from gempy_engine import config
45
from gempy_engine.config import AvailableBackends
56

67

78
@dataclass
89
class GemPyEngineConfig:
9-
backend: AvailableBackends = AvailableBackends.numpy # ? This can be grabbed from gempy.config file?
10+
backend: AvailableBackends = config.DEFAULT_BACKEND # ? This can be grabbed from gempy.config file?
1011
use_gpu: bool = False
1112
dtype: Optional[str] = None #: The data type used in the engine. If None, the default data type of the backend is used.
1213

0 commit comments

Comments
 (0)