diff --git a/gempy_probability/__init__.py b/gempy_probability/__init__.py index a695542..f0baa37 100644 --- a/gempy_probability/__init__.py +++ b/gempy_probability/__init__.py @@ -2,4 +2,24 @@ from .modules import likelihoods from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference -from ._version import __version__ +""" +Module initialisation for GemPy Probability +""" +import sys + +# * Assert at least python 3.10 +assert sys.version_info[0] >= 3 and sys.version_info[1] >= 10, "GemPy Probability requires Python 3.10 or higher" + +# Import version, with fallback if not generated yet +try: + from ._version import __version__ +except ImportError: + __version__ = "unknown" + +# =================== CORE =================== +# Import your core modules here + +# =================== API =================== + +if __name__ == '__main__': + pass diff --git a/gempy_probability/api/model_runner/_pyro_runner.py b/gempy_probability/api/model_runner/_pyro_runner.py index 1133957..9841856 100644 --- a/gempy_probability/api/model_runner/_pyro_runner.py +++ b/gempy_probability/api/model_runner/_pyro_runner.py @@ -11,8 +11,8 @@ from ...core.samplers_data import NUTSConfig -def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, - y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData: +def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, + y_obs_list: torch.Tensor, n_samples: int, plot_trace: bool = False) -> az.InferenceData: predictive = Predictive( model=prob_model, num_samples=n_samples @@ -24,13 +24,13 @@ def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, if plot_trace: az.plot_trace(data.prior) plt.show() - + return data -def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, - y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace:bool=False, - run_posterior_predictive:bool=False) -> az.InferenceData: - + +def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, + y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace: bool = False, + run_posterior_predictive: bool = False) -> az.InferenceData: nuts_kernel = NUTS( prob_model, step_size=config.step_size, diff --git a/gempy_probability/modules/model_definition/prob_model_factory.py b/gempy_probability/modules/model_definition/prob_model_factory.py index eb99915..e1b414a 100644 --- a/gempy_probability/modules/model_definition/prob_model_factory.py +++ b/gempy_probability/modules/model_definition/prob_model_factory.py @@ -1,7 +1,7 @@ import pyro import torch from pyro.distributions import Distribution -from typing import Callable, Dict +from typing import Callable, Dict, Optional import gempy as gp from gempy_engine.core.backend_tensor import BackendTensor @@ -17,7 +17,7 @@ def make_gempy_pyro_model( [Dict[str, Distribution], gp.data.GeoModel], gp.data.InterpolationInput ], - likelihood_fn: Callable[[gp.data.Solutions], Distribution], + likelihood_fn: Optional[Callable[[gp.data.Solutions], Distribution]], obs_name: str = "obs" ) -> GemPyPyroModel: """ @@ -115,6 +115,9 @@ def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor): ) # 4) Wrap in likelihood & observe + if likelihood_fn is None: + return + lik_dist = likelihood_fn(simulated) pyro.sample(obs_name, lik_dist, obs=obs_data) diff --git a/gempy_probability/modules/plot/plot_gempy.py b/gempy_probability/modules/plot/plot_gempy.py new file mode 100644 index 0000000..19027c1 --- /dev/null +++ b/gempy_probability/modules/plot/plot_gempy.py @@ -0,0 +1,94 @@ +from typing import Callable, Optional + +import gempy as gp +import gempy_viewer as gpv +import numpy as np +from gempy_viewer.modules.plot_2d.visualization_2d import Plot2D + + +def plot_gempy( + geo_model, # gp.data.GeoModel - avoiding import + n_samples: int, + samples: np.ndarray, + update_model_fn: Callable, + gempy_plot: Plot2D, + plot_kwargs: Optional[dict] = None +): + """ + General function to plot GemPy models with uncertainty from prior/posterior samples. + + Parameters + ---------- + geo_model : gp.data.GeoModel + The geological model to update and plot + n_samples : int + Number of samples to plot + samples : np.ndarray + Array of sample values to iterate through + update_model_fn : Callable + Function that takes (geo_model, sample_value, sample_idx) and updates the model. + Should return None and modify geo_model in place. + gempy_plot : Plot2D + GemPy Plot2D object containing the figure and section data to plot on + plot_kwargs : dict, optional + Additional plotting kwargs for boundaries, surface points, etc. + + Examples + -------- + >>> def update_model_fn(geo_model, sample_value, sample_idx): + ... # Transform sample value to world coordinates + ... xyz = np.zeros((1, 3)) + ... xyz[0, 2] = sample_value + ... world_coord = geo_model.input_transform.apply_inverse(xyz) + ... # Modify surface point + ... gp.modify_surface_points(geo_model, slice=0, Z=world_coord[0, 2]) + >>> + >>> p2d = gpv.plot_2d(geo_model, show_lith=False, show_data=False, show=False) + >>> samples = prior_inference_data.prior['$\\mu_{top}$'].values[0, :] + >>> plot_gempy(geo_model, n_samples=50, samples=samples, + ... update_model_fn=update_model_fn, gempy_plot=p2d) + """ + # Import here to avoid circular dependencies and to make gempy optional + import gempy as gp + from gempy_viewer.API._plot_2d_sections_api import plot_sections + from gempy_viewer.core.data_to_show import DataToShow + + plot_kwargs = plot_kwargs or {} + + # Iterate through samples + for i in np.linspace(0, n_samples - 1, n_samples).astype(int): + # Update model using the provided function + update_model_fn(geo_model, samples[i], i) + + # Compute the model + gp.compute_model(gempy_model=geo_model) + + # Plot the updated model + default_plot_kwargs = { + 'kwargs_boundaries' : { + "linewidth": 0.5, + "alpha" : 0.1, + }, + 'kwargs_surface_points': { + 'alpha': 0.1 + }, + 'kwargs_orientations' : { + 'alpha': 0.02, + } + } + # Merge with user-provided kwargs (user kwargs override defaults) + final_plot_kwargs = {**default_plot_kwargs, **plot_kwargs} + + plot_sections( + gempy_model=geo_model, + sections_data=gempy_plot.section_data_list, + data_to_show=DataToShow( + n_axis=1, + show_data=True, + show_surfaces=True, + show_lith=False + ), + **final_plot_kwargs + ) + + gempy_plot.fig.show()