Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion gempy_probability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,24 @@
from .modules import likelihoods
from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference

from ._version import __version__
"""
Module initialisation for GemPy Probability
"""
import sys

# * Assert at least python 3.10
assert sys.version_info[0] >= 3 and sys.version_info[1] >= 10, "GemPy Probability requires Python 3.10 or higher"

# Import version, with fallback if not generated yet
try:
from ._version import __version__
except ImportError:
__version__ = "unknown"

# =================== CORE ===================
# Import your core modules here

# =================== API ===================

if __name__ == '__main__':
pass
14 changes: 7 additions & 7 deletions gempy_probability/api/model_runner/_pyro_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from ...core.samplers_data import NUTSConfig


def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData:
def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, n_samples: int, plot_trace: bool = False) -> az.InferenceData:
predictive = Predictive(
model=prob_model,
num_samples=n_samples
Expand All @@ -24,13 +24,13 @@ def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
if plot_trace:
az.plot_trace(data.prior)
plt.show()

return data

def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace:bool=False,
run_posterior_predictive:bool=False) -> az.InferenceData:


def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel,
y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace: bool = False,
run_posterior_predictive: bool = False) -> az.InferenceData:
nuts_kernel = NUTS(
prob_model,
step_size=config.step_size,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pyro
import torch
from pyro.distributions import Distribution
from typing import Callable, Dict
from typing import Callable, Dict, Optional

import gempy as gp
from gempy_engine.core.backend_tensor import BackendTensor
Expand All @@ -17,7 +17,7 @@ def make_gempy_pyro_model(
[Dict[str, Distribution], gp.data.GeoModel],
gp.data.InterpolationInput
],
likelihood_fn: Callable[[gp.data.Solutions], Distribution],
likelihood_fn: Optional[Callable[[gp.data.Solutions], Distribution]],
obs_name: str = "obs"
) -> GemPyPyroModel:
"""
Expand Down Expand Up @@ -115,6 +115,9 @@ def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor):
)

# 4) Wrap in likelihood & observe
if likelihood_fn is None:
return

lik_dist = likelihood_fn(simulated)
pyro.sample(obs_name, lik_dist, obs=obs_data)

Expand Down
94 changes: 94 additions & 0 deletions gempy_probability/modules/plot/plot_gempy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
from typing import Callable, Optional

import gempy as gp
import gempy_viewer as gpv
import numpy as np
from gempy_viewer.modules.plot_2d.visualization_2d import Plot2D


def plot_gempy(
geo_model, # gp.data.GeoModel - avoiding import
n_samples: int,
samples: np.ndarray,
update_model_fn: Callable,
gempy_plot: Plot2D,
plot_kwargs: Optional[dict] = None
):
"""
General function to plot GemPy models with uncertainty from prior/posterior samples.

Parameters
----------
geo_model : gp.data.GeoModel
The geological model to update and plot
n_samples : int
Number of samples to plot
samples : np.ndarray
Array of sample values to iterate through
update_model_fn : Callable
Function that takes (geo_model, sample_value, sample_idx) and updates the model.
Should return None and modify geo_model in place.
gempy_plot : Plot2D
GemPy Plot2D object containing the figure and section data to plot on
plot_kwargs : dict, optional
Additional plotting kwargs for boundaries, surface points, etc.

Examples
--------
>>> def update_model_fn(geo_model, sample_value, sample_idx):
... # Transform sample value to world coordinates
... xyz = np.zeros((1, 3))
... xyz[0, 2] = sample_value
... world_coord = geo_model.input_transform.apply_inverse(xyz)
... # Modify surface point
... gp.modify_surface_points(geo_model, slice=0, Z=world_coord[0, 2])
>>>
>>> p2d = gpv.plot_2d(geo_model, show_lith=False, show_data=False, show=False)
>>> samples = prior_inference_data.prior['$\\mu_{top}$'].values[0, :]
>>> plot_gempy(geo_model, n_samples=50, samples=samples,
... update_model_fn=update_model_fn, gempy_plot=p2d)
"""
# Import here to avoid circular dependencies and to make gempy optional
import gempy as gp
from gempy_viewer.API._plot_2d_sections_api import plot_sections
from gempy_viewer.core.data_to_show import DataToShow

plot_kwargs = plot_kwargs or {}

# Iterate through samples
for i in np.linspace(0, n_samples - 1, n_samples).astype(int):
# Update model using the provided function
update_model_fn(geo_model, samples[i], i)

# Compute the model
gp.compute_model(gempy_model=geo_model)

# Plot the updated model
default_plot_kwargs = {
'kwargs_boundaries' : {
"linewidth": 0.5,
"alpha" : 0.1,
},
'kwargs_surface_points': {
'alpha': 0.1
},
'kwargs_orientations' : {
'alpha': 0.02,
}
}
# Merge with user-provided kwargs (user kwargs override defaults)
final_plot_kwargs = {**default_plot_kwargs, **plot_kwargs}

plot_sections(
gempy_model=geo_model,
sections_data=gempy_plot.section_data_list,
data_to_show=DataToShow(
n_axis=1,
show_data=True,
show_surfaces=True,
show_lith=False
),
**final_plot_kwargs
)

gempy_plot.fig.show()