Skip to content

Commit d4419c3

Browse files
authored
[FIX] Use timestamp for cache key instead of input parameters (#23)
### TL;DR Improved caching mechanism by using a timestamp-based approach for scalar field interpolation. ### What changed? - Modified the cache key generation in `interpolate_scalar_field` to use a timestamp instead of multiple parameters - Added a `start_computation_ts` field to the `TempInterpolationValues` class to store the computation start time - Updated the `compute_model` function to set the timestamp at the beginning of computation - Simplified error handling by raising a ValueError when cache is corrupted instead of attempting to recalculate weights ### How to test? 1. Run a model computation and verify that caching works correctly 2. Run multiple computations in sequence to ensure the timestamp-based caching properly identifies different computation runs 3. Test cache invalidation by modifying input data between runs ### Why make this change? The previous caching mechanism relied on multiple parameters which could lead to unnecessary cache misses or complex hash generation. Using a timestamp simplifies the caching logic while still providing a unique identifier for each computation run. This approach is more efficient and reduces the risk of cache-related bugs.
2 parents f5cb64c + 67aa07f commit d4419c3

File tree

9 files changed

+102
-57
lines changed

9 files changed

+102
-57
lines changed

gempy_engine/API/interp_single/_interp_scalar_field.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import warnings
2+
13
from typing import Tuple, Optional
24

35
import numpy as np
@@ -29,15 +31,17 @@ def interpolate_scalar_field(solver_input: SolverInput, options: InterpolationOp
2931
key=weights_key,
3032
look_in_disk= not options.cache_mode == InterpolationOptions.CacheMode.IN_MEMORY_CACHE
3133
)
32-
weights_hash = generate_cache_key(
33-
name="",
34-
parameters={
35-
"surface_points": solver_input.sp_internal,
36-
"orientations" : solver_input.ori_internal,
37-
"fault_internal": solver_input._fault_internal.fault_values_on_sp,
38-
"kernel_options": options.kernel_options
39-
}
40-
)
34+
ts = options.temp_interpolation_values.start_computation_ts
35+
if ts == -1:
36+
warnings.warn("No start computation timestamp found. No caching.")
37+
weights_cached = None
38+
else:
39+
weights_hash = generate_cache_key(
40+
name="",
41+
parameters={
42+
"ts": ts
43+
}
44+
)
4145
case InterpolationOptions.CacheMode.CLEAR_CACHE:
4246
WeightCache.initialize_cache_dir()
4347
weights_cached = None
@@ -54,9 +58,7 @@ def interpolate_scalar_field(solver_input: SolverInput, options: InterpolationOp
5458
weights_key=weights_key,
5559
weights_hash=weights_hash
5660
)
57-
5861
case _ if weights_cached["hash"] != weights_hash:
59-
solver_input.weights_x0 = weights_cached["weights"]
6062
weights = _solve_and_store_weights(
6163
solver_input=solver_input,
6264
kernel_options=options.kernel_options,

gempy_engine/API/model/model_api.py

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import time
12
import copy
23
from typing import List, Optional
34

@@ -15,63 +16,70 @@
1516
from ...core.data.input_data_descriptor import InputDataDescriptor
1617
from ...core.data.interpolation_input import InterpolationInput
1718
from ...core.utils import gempy_profiler_decorator
18-
19-
19+
from ...modules.weights_cache.weights_cache_interface import WeightCache
2020

2121

2222
@gempy_profiler_decorator
2323
def compute_model(interpolation_input: InterpolationInput, options: InterpolationOptions,
2424
data_descriptor: InputDataDescriptor, *, geophysics_input: Optional[GeophysicsInput] = None) -> Solutions:
2525

26-
# ! If we inline this it seems the deepcopy does not work
27-
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and NOT_MAKE_INPUT_DEEP_COPY is False:
28-
interpolation_input = copy.deepcopy(interpolation_input)
26+
try:
27+
WeightCache.initialize_cache_dir()
28+
options.temp_interpolation_values.start_computation_ts = int(time.time())
2929

30-
# Check input is valid
31-
_check_input_validity(interpolation_input, options, data_descriptor)
32-
33-
output: list[OctreeLevel] = interpolate_n_octree_levels(
34-
interpolation_input=interpolation_input,
35-
options=options,
36-
data_descriptor=data_descriptor
37-
)
38-
# region Geophysics
39-
# ---------------------
40-
# TODO: [x] Gravity
41-
# TODO: [ ] Magnetics
42-
43-
if geophysics_input is not None:
44-
first_level_last_field: InterpOutput = output[0].outputs_centers[-1]
45-
gravity = compute_gravity(
46-
geophysics_input=geophysics_input,
47-
root_ouput=first_level_last_field
48-
)
49-
else:
50-
gravity = None
51-
52-
# endregion
30+
# ! If we inline this it seems the deepcopy does not work
31+
if BackendTensor.engine_backend is not AvailableBackends.PYTORCH and NOT_MAKE_INPUT_DEEP_COPY is False:
32+
interpolation_input = copy.deepcopy(interpolation_input)
33+
34+
# Check input is valid
35+
_check_input_validity(interpolation_input, options, data_descriptor)
5336

54-
meshes: Optional[list[DualContouringMesh]] = None
55-
if options.mesh_extraction:
56-
if interpolation_input.grid.octree_grid is None:
57-
raise ValueError("Octree grid must be defined to extract the mesh")
58-
59-
meshes: list[DualContouringMesh] = dual_contouring_multi_scalar(
60-
data_descriptor=data_descriptor,
37+
output: list[OctreeLevel] = interpolate_n_octree_levels(
6138
interpolation_input=interpolation_input,
6239
options=options,
63-
octree_list=output[:options.number_octree_levels_surface]
40+
data_descriptor=data_descriptor
6441
)
42+
# region Geophysics
43+
# ---------------------
44+
# TODO: [x] Gravity
45+
# TODO: [ ] Magnetics
46+
47+
if geophysics_input is not None:
48+
first_level_last_field: InterpOutput = output[0].outputs_centers[-1]
49+
gravity = compute_gravity(
50+
geophysics_input=geophysics_input,
51+
root_ouput=first_level_last_field
52+
)
53+
else:
54+
gravity = None
55+
56+
# endregion
6557

66-
solutions = Solutions(
67-
octrees_output=output,
68-
dc_meshes=meshes,
69-
fw_gravity=gravity,
70-
block_solution_type=options.block_solutions_type
71-
)
58+
meshes: Optional[list[DualContouringMesh]] = None
59+
if options.mesh_extraction:
60+
if interpolation_input.grid.octree_grid is None:
61+
raise ValueError("Octree grid must be defined to extract the mesh")
62+
63+
meshes: list[DualContouringMesh] = dual_contouring_multi_scalar(
64+
data_descriptor=data_descriptor,
65+
interpolation_input=interpolation_input,
66+
options=options,
67+
octree_list=output[:options.number_octree_levels_surface]
68+
)
69+
70+
solutions = Solutions(
71+
octrees_output=output,
72+
dc_meshes=meshes,
73+
fw_gravity=gravity,
74+
block_solution_type=options.block_solutions_type
75+
)
7276

73-
if options.debug:
74-
solutions.debug_input_data["stack_interpolation_input"] = interpolation_input
77+
if options.debug:
78+
solutions.debug_input_data["stack_interpolation_input"] = interpolation_input
79+
except Exception as e:
80+
raise e
81+
finally:
82+
options.temp_interpolation_values.start_computation_ts = -1
7583

7684
return solutions
7785

gempy_engine/core/data/options/temp_interpolation_values.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
@dataclass
55
class TempInterpolationValues:
66
current_octree_level: int = 0 # * Make this a read only property
7+
start_computation_ts: int = -1

gempy_engine/modules/evaluator/generic_evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,10 @@ def _eval_on(
7272
eval_kernel = yield_evaluation_kernel(
7373
solver_input, options.kernel_options, slice_array=slice_array
7474
)
75-
scalar_field = (eval_kernel.T @ weights).reshape(-1)
75+
try:
76+
scalar_field = (eval_kernel.T @ weights).reshape(-1)
77+
except ValueError:
78+
pass
7679

7780
gx_field: Optional[np.ndarray] = None
7881
gy_field: Optional[np.ndarray] = None

tests/fixtures/complex_geometries.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def one_fault_model():
8383
number_dimensions=3,
8484
kernel_function=AvailableKernelFunctions.exponential)
8585

86+
options.cache_mode = InterpolationOptions.CacheMode.NO_CACHE
8687
# endregion
8788

8889
return interpolation_input, input_data_descriptor, options

tests/fixtures/simple_geometries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def unconformity() -> Tuple[InterpolationInput, InterpolationOptions, InputDataD
4949
options = InterpolationOptions.from_args(range_, c_o, uni_degree=1, i_res=i_r, gi_res=gi_r,
5050
number_dimensions=3,
5151
kernel_function=AvailableKernelFunctions.cubic)
52-
52+
options.cache_mode = InterpolationOptions.CacheMode.NO_CACHE
5353
resolution = [2, 2, 2]
5454
extent = [0, 1000, 0, 1000, 0, 1000]
5555

tests/fixtures/simple_models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from gempy_engine.modules.data_preprocess._input_preparation import surface_points_preprocess, \
3434
orientations_preprocess
3535

36+
3637
np.set_printoptions(precision=3, linewidth=200)
3738

3839
dir_name = os.path.dirname(__file__)
@@ -136,6 +137,9 @@ def _gen():
136137

137138

138139
def simple_model_interpolation_input_factory():
140+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
141+
WeightCache.initialize_cache_dir()
142+
139143
resolution = [2, 2, 3]
140144
extent = [0.25, .75, 0.25, .75, 0.25, .75]
141145

@@ -165,6 +169,7 @@ def simple_model_interpolation_input_factory():
165169
ori_i = Orientations(dip_positions, dip_gradients, nugget_effect_grad)
166170
interpolation_options = InterpolationOptions.from_args(range_, co, 0, number_dimensions=3,
167171
kernel_function=AvailableKernelFunctions.cubic)
172+
interpolation_options.cache_mode = InterpolationOptions.CacheMode.NO_CACHE
168173
ids = np.array([1, 2])
169174
interpolation_input = InterpolationInput(spi, ori_i, grid_0_centers, ids)
170175
tensor_struct = TensorsStructure(np.array([7]))
@@ -189,6 +194,10 @@ def simple_model_3_layers(simple_grid_3d_octree) -> Tuple[InterpolationInput, In
189194

190195

191196
def _gen_simple_model_3_layers(simple_grid_3d_octree):
197+
198+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
199+
WeightCache.initialize_cache_dir()
200+
192201
grid_0_centers = dataclasses.replace(simple_grid_3d_octree)
193202
np.set_printoptions(precision=3, linewidth=200)
194203
dip_positions = np.array([
@@ -232,6 +241,8 @@ def _gen_simple_model_3_layers(simple_grid_3d_octree):
232241

233242
@pytest.fixture(scope="session")
234243
def simple_model_3_layers_high_res(simple_grid_3d_more_points_grid) -> Tuple[InterpolationInput, InterpolationOptions, InputDataDescriptor]:
244+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
245+
WeightCache.initialize_cache_dir()
235246
grid_0_centers = dataclasses.replace(simple_grid_3d_more_points_grid)
236247

237248
np.set_printoptions(precision=3, linewidth=200)
@@ -269,6 +280,7 @@ def simple_model_3_layers_high_res(simple_grid_3d_more_points_grid) -> Tuple[Int
269280

270281
interpolation_options = InterpolationOptions.from_args(range_, co, 0,
271282
number_dimensions=3, kernel_function=AvailableKernelFunctions.cubic)
283+
interpolation_options.cache_mode = InterpolationOptions.CacheMode.NO_CACHE
272284

273285
ids = np.array([1, 2, 3, 4])
274286

tests/test_common/test_api/test_faults/test_one_fault.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@ def test_one_fault_model_thickness(one_fault_model, n_oct_levels=2):
107107

108108

109109
def test_one_fault_model_finite_fault(one_fault_model, n_oct_levels=4):
110+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
111+
WeightCache.initialize_cache_dir()
112+
110113
interpolation_input: InterpolationInput
111114
structure: InputDataDescriptor
112115
options: InterpolationOptions
@@ -171,6 +174,9 @@ def test_one_fault_model_finite_fault(one_fault_model, n_oct_levels=4):
171174

172175

173176
def test_implicit_ellipsoid_projection_on_fault(one_fault_model):
177+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
178+
WeightCache.initialize_cache_dir()
179+
174180
interpolation_input: InterpolationInput
175181
structure: InputDataDescriptor
176182
options: InterpolationOptions

tests/test_common/test_integrations/test_interpolate_model.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212

1313

1414
def test_interpolate_model(simple_model_interpolation_input, n_oct_levels=3):
15+
16+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
17+
WeightCache.initialize_cache_dir()
1518
"""Kernel function Cubic"""
1619
interpolation_input, options, structure = simple_model_interpolation_input
1720
print(interpolation_input)
@@ -29,6 +32,9 @@ def test_interpolate_model(simple_model_interpolation_input, n_oct_levels=3):
2932

3033
@pytest.mark.skipif(TEST_SPEED.value <= 1, reason="Global test speed below this test value.")
3134
def test_interpolate_model_no_octtree(simple_model_3_layers_high_res, n_oct_levels=2):
35+
36+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
37+
WeightCache.initialize_cache_dir()
3238
interpolation_input, options, structure = simple_model_3_layers_high_res
3339
print(interpolation_input)
3440

@@ -44,6 +50,9 @@ def test_interpolate_model_no_octtree(simple_model_3_layers_high_res, n_oct_leve
4450

4551

4652
def test_interpolate_model_several_surfaces(simple_model_3_layers, n_oct_levels=3):
53+
54+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
55+
WeightCache.initialize_cache_dir()
4756
interpolation_input, options, structure = simple_model_3_layers
4857
print(interpolation_input)
4958

@@ -63,6 +72,9 @@ def test_interpolate_model_several_surfaces(simple_model_3_layers, n_oct_levels=
6372

6473

6574
def test_interpolate_model_unconformity(unconformity, n_oct_levels=4):
75+
from gempy_engine.modules.weights_cache.weights_cache_interface import WeightCache
76+
WeightCache.initialize_cache_dir()
77+
6678
interpolation_input, options, structure = unconformity
6779
print(interpolation_input)
6880

0 commit comments

Comments
 (0)