Skip to content

Commit 99bfdb7

Browse files
committed
[ENH] Add default ColorsGenerator and improve condition handling
Introduce `Field(default_factory=ColorsGenerator)` to `StructuralFrame` for streamlined initialization. Convert condition number to numpy array in optimization logic for compatibility. Minor adjustments made for improved model validation and notes added for future enhancements.
1 parent 0fb74e5 commit 99bfdb7

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

gempy/core/color_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class ColorsGenerator:
3535
)
3636
_index: int = 0
3737

38-
def __init__(self):
38+
def __post_init__(self):
3939
self.regenerate_color_palette()
4040

4141
@staticmethod

gempy/core/data/structural_frame.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import warnings
55
from dataclasses import dataclass
6-
from pydantic import model_validator, computed_field, ValidationError
6+
from pydantic import model_validator, computed_field, ValidationError, Field
77
from pydantic.functional_validators import ModelWrapValidatorHandler
88
from typing import Generator, Union
99

@@ -32,11 +32,12 @@ class StructuralFrame:
3232
"""
3333

3434
structural_groups: list[StructuralGroup]
35+
color_generator: ColorsGenerator = Field(default_factory=ColorsGenerator)
3536
# ? Should I create some sort of structural options class? For example, the masking descriptor and faults relations pointer
3637
is_dirty: bool = True
3738

3839
# region Constructor
39-
40+
#
4041
def __init__(self, structural_groups: list[StructuralGroup], color_gen: ColorsGenerator):
4142
self.structural_groups = structural_groups # ? This maybe could be optional
4243
self.color_generator = color_gen

gempy/modules/optimize_nuggets/_ops.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ def run_optimization(lr, max_epochs, min_impr, model, nugget, patience, target_c
4343
if _has_converged(cur_cond, prev_cond, target_cond_num, epoch, min_impr, patience):
4444
break
4545
prev_cond = cur_cond
46-
46+
47+
# Condition number to numpy
48+
model.interpolation_options.kernel_options.condition_number = model.interpolation_options.kernel_options.condition_number.detach().numpy()
4749
return nugget
4850

4951

test/test_private/test_terranigma/test_nuggets/test_nugget_effect_optimization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@ def test_optimize_nugget_effect():
6161
engine_config=gp.data.GemPyEngineConfig(
6262
backend=gp.data.AvailableBackends.PYTORCH,
6363
),
64-
validate_serialization=False
64+
validate_serialization=True
6565
)
66+
67+
# TODO: Save model
6668

6769
print(f"Final cond number: {geo_model.interpolation_options.kernel_options.condition_number}")
6870
nugget_effect = geo_model.taped_interpolation_input.surface_points.nugget_effect_scalar.detach().numpy()

0 commit comments

Comments
 (0)