diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py
index 36a3a09..590cd19 100644
--- a/python/nutpie/compile_pymc.py
+++ b/python/nutpie/compile_pymc.py
@@ -111,6 +111,8 @@ class CompiledPyMCModel(CompiledModel):
     _n_dim: int
     _shapes: dict[str, tuple[int, ...]]
     _coords: Optional[dict[str, Any]]
+    log_likelihood_names: list[str]
+    log_likelihood_shapes: list[tuple[int, ...]]
 
     @property
     def n_dim(self):
@@ -220,6 +222,7 @@ def _compile_pymc_model_numba(
     model: "pm.Model",
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
     var_names: Iterable[str] | None = None,
+    compute_log_likelihood: bool = False,
     **kwargs,
 ) -> CompiledPyMCModel:
     if find_spec("numba") is None:
@@ -238,6 +241,7 @@ def _compile_pymc_model_numba(
         expand_fn_pt,
         initial_point_fn,
         shape_info,
+        log_likelihood_info,
     ) = _make_functions(
         model,
         mode="NUMBA",
@@ -245,6 +249,7 @@ def _compile_pymc_model_numba(
         join_expanded=True,
         pymc_initial_point_fn=pymc_initial_point_fn,
         var_names=var_names,
+        compute_log_likelihood=compute_log_likelihood,
     )
 
     expand_fn = expand_fn_pt.vm.jit_fn
@@ -254,6 +259,9 @@ def _compile_pymc_model_numba(
     shared_vars = {}
     seen = set()
     for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
+        # Skip RNG variables that don't have names
+        if val.name is None:
+            continue
         if val.name in shared_data and val not in seen:
             raise ValueError(f"Shared variables must have unique names: {val.name}")
         shared_data[val.name] = np.array(val.get_value(), order="C", copy=True)
@@ -265,10 +273,16 @@ def _compile_pymc_model_numba(
 
     user_data = make_user_data(shared_vars, shared_data)
 
-    logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
+    logp_shared_names = [
+        var.name for var in logp_fn_pt.get_shared() if var.name is not None
+    ]
     logp_numba_raw, c_sig = _make_c_logp_func(
-        n_dim, logp_fn, user_data, logp_shared_names, shared_data
+        n_dim, logp_fn, user_data, logp_shared_names, shared_data, logp_fn_pt
     )
+
+    # Filter out compute_log_likelihood from kwargs for numba compilation
+    numba_kwargs = {k: v for k, v in kwargs.items() if k != "compute_log_likelihood"}
+
     with warnings.catch_warnings():
         warnings.filterwarnings(
             "ignore",
@@ -276,11 +290,19 @@ def _compile_pymc_model_numba(
             category=numba.NumbaWarning,  # type: ignore
         )
 
-        logp_numba = numba.cfunc(c_sig, **kwargs)(logp_numba_raw)
+        logp_numba = numba.cfunc(c_sig, **numba_kwargs)(logp_numba_raw)
 
-    expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
+    expand_shared_names = [
+        var.name for var in expand_fn_pt.get_shared() if var.name is not None
+    ]
     expand_numba_raw, c_sig_expand = _make_c_expand_func(
-        n_dim, n_expanded, expand_fn, user_data, expand_shared_names, shared_data
+        n_dim,
+        n_expanded,
+        expand_fn,
+        user_data,
+        expand_shared_names,
+        shared_data,
+        expand_fn_pt,
     )
     with warnings.catch_warnings():
         warnings.filterwarnings(
@@ -289,10 +311,12 @@ def _compile_pymc_model_numba(
             category=numba.NumbaWarning,  # type: ignore
         )
 
-        expand_numba = numba.cfunc(c_sig_expand, **kwargs)(expand_numba_raw)
+        expand_numba = numba.cfunc(c_sig_expand, **numba_kwargs)(expand_numba_raw)
 
     dims, coords = _prepare_dims_and_coords(model, shape_info)
 
+    log_likelihood_names, log_likelihood_shapes = log_likelihood_info
+
     return CompiledPyMCModel(
         _n_dim=n_dim,
         dims=dims,
@@ -307,6 +331,8 @@ def _compile_pymc_model_numba(
         shape_info=shape_info,
         logp_func=logp_fn_pt,
         expand_func=expand_fn_pt,
+        log_likelihood_names=log_likelihood_names,
+        log_likelihood_shapes=log_likelihood_shapes,
     )
 
 
@@ -341,6 +367,7 @@ def _compile_pymc_model_jax(
     gradient_backend=None,
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
     var_names: Iterable[str] | None = None,
+    compute_log_likelihood: bool = False,
     **kwargs,
 ):
     if find_spec("jax") is None:
@@ -364,6 +391,7 @@ def _compile_pymc_model_jax(
         expand_fn_pt,
         initial_point_fn,
         shape_info,
+        log_likelihood_info,
     ) = _make_functions(
         model,
         mode="JAX",
@@ -371,18 +399,33 @@ def _compile_pymc_model_jax(
         join_expanded=False,
         pymc_initial_point_fn=pymc_initial_point_fn,
         var_names=var_names,
+        compute_log_likelihood=compute_log_likelihood,
     )
 
     logp_fn = logp_fn_pt.vm.jit_fn
     expand_fn = expand_fn_pt.vm.jit_fn
 
-    logp_shared_names = [var.name for var in logp_fn_pt.get_shared()]
-    expand_shared_names = [var.name for var in expand_fn_pt.get_shared()]
+    logp_shared_names = [
+        var.name for var in logp_fn_pt.get_shared() if var.name is not None
+    ]
+    expand_shared_names = [
+        var.name for var in expand_fn_pt.get_shared() if var.name is not None
+    ]
+
+    logp_rng_count = len([var for var in logp_fn_pt.get_shared() if var.name is None])
+    expand_rng_count = len(
+        [var for var in expand_fn_pt.get_shared() if var.name is None]
+    )
 
     if gradient_backend == "jax":
         orig_logp_fn = logp_fn._fun
 
         def logp_fn_jax_grad(x, *shared):
+            if len(shared) < logp_rng_count:
+                rng_dummies = [
+                    {"jax_state": jax.random.key(i)} for i in range(logp_rng_count)
+                ]
+                shared = shared + tuple(rng_dummies)
             return jax.value_and_grad(lambda x: orig_logp_fn(x, *shared)[0])(x)
 
         # static_argnums = list(range(1, len(logp_shared_names) + 1))
@@ -399,6 +442,9 @@ def logp_fn_jax_grad(x, *shared):
     shared_vars = {}
     seen = set()
     for val in [*logp_fn_pt.get_shared(), *expand_fn_pt.get_shared()]:
+        # Skip RNG variables that don't have names
+        if val.name is None:
+            continue
         if val.name in shared_data and val not in seen:
             raise ValueError(f"Shared variables must have unique names: {val.name}")
         shared_data[val.name] = jax.numpy.asarray(val.get_value())
@@ -407,7 +453,9 @@
 
     def make_logp_func():
         def logp(_x, **shared):
-            logp, grad = logp_fn(_x, *[shared[name] for name in logp_shared_names])
+            named_shared = [shared[name] for name in logp_shared_names]
+            rng_dummies = [{"jax_state": jax.random.key(0)}] * logp_rng_count
+            logp, grad = logp_fn(_x, *(named_shared + rng_dummies))
             return float(logp), np.asarray(grad, dtype="float64", order="C")
 
         return logp
@@ -419,7 +467,12 @@
     def make_expand_func(seed1, seed2, chain):
         # TODO handle seeds
         def expand(_x, **shared):
-            values = expand_fn(_x, *[shared[name] for name in expand_shared_names])
+            named_shared = [shared[name] for name in expand_shared_names]
+            rng_dummies = [
+                {"jax_state": jax.random.key(i)} for i in range(expand_rng_count)
+            ]
+            all_outputs = expand_fn(_x, *(named_shared + rng_dummies))
+            values = all_outputs[: len(names)]
             return {
                 name: np.asarray(val, order="C", dtype=dtype).ravel()
                 for name, val, dtype in zip(names, values, dtypes, strict=True)
@@ -428,6 +481,7 @@ def expand(_x, **shared):
         return expand
 
     dims, coords = _prepare_dims_and_coords(model, shape_info)
+    log_likelihood_names, log_likelihood_shapes = log_likelihood_info
 
     return from_pyfunc(
         ndim=n_dim,
@@ -441,6 +495,8 @@ def expand(_x, **shared):
         dims=dims,
         coords=coords,
         raw_logp_fn=orig_logp_fn,
+        log_likelihood_names=log_likelihood_names,
+        log_likelihood_shapes=log_likelihood_shapes,
     )
 
 
@@ -457,6 +513,7 @@ def compile_pymc_model(
     ] = "support_point",
     var_names: Iterable[str] | None = None,
     freeze_model: bool | None = None,
+    compute_log_likelihood: bool = False,
     **kwargs,
 ) -> CompiledModel:
     """Compile necessary functions for sampling a pymc model.
@@ -485,6 +542,9 @@ def compile_pymc_model(
     freeze_model : bool | None
         Freeze all dimensions and shared variables to treat them as compile
         time constants.
+    compute_log_likelihood : bool
+        Whether to compute element-wise log-likelihood values for observed variables.
+        When True, enables population of the log_likelihood group in ArviZ InferenceData.
     Returns
     -------
     compiled_model : CompiledPyMCModel
@@ -529,10 +589,26 @@ def compile_pymc_model(
     if backend.lower() == "numba":
         if gradient_backend == "jax":
             raise ValueError("Gradient backend cannot be jax when using numba backend")
+        # WORKAROUND: Disable log likelihood computation for numba backend due to
+        # PyTensor vectorization creating input_bc_patterns that must be literals.
+        # This is a known limitation documented in CLAUDE.md
+        if compute_log_likelihood:
+            import warnings
+
+            warnings.warn(
+                "compute_log_likelihood=True is not supported with numba backend due to "
+                "PyTensor vectorization issues. Please use the JAX backend for log-likelihood "
+                "computation: backend='jax', gradient_backend='jax'",
+                UserWarning,
+                stacklevel=2,
+            )
+            # Disable log likelihood computation and continue with numba
+            compute_log_likelihood = False
         return _compile_pymc_model_numba(
             model=model,
             pymc_initial_point_fn=initial_point_fn,
             var_names=var_names,
+            compute_log_likelihood=compute_log_likelihood,
             **kwargs,
         )
     elif backend.lower() == "jax":
@@ -541,6 +617,7 @@ def compile_pymc_model(
             gradient_backend=gradient_backend,
             pymc_initial_point_fn=initial_point_fn,
             var_names=var_names,
+            compute_log_likelihood=compute_log_likelihood,
             **kwargs,
         )
     else:
@@ -595,6 +672,7 @@ def _make_functions(
     join_expanded: bool,
     pymc_initial_point_fn: Callable[[SeedType], dict[str, np.ndarray]],
     var_names: Iterable[str] | None = None,
+    compute_log_likelihood: bool = False,
 ) -> tuple[
     int,
     int,
@@ -602,6 +680,7 @@ def _make_functions(
     Callable,
     Callable,
     tuple[list[str], list[slice], list[tuple[int, ...]]],
+    tuple[list[str], list[tuple[int, ...]]],
 ]:
     """
     Compile functions required by nuts-rs from a given PyMC model.
@@ -623,6 +702,8 @@ def _make_functions(
         pymc.initial_point.make_initial_point_fn
     var_names:
         Names of variables to store in the trace. Defaults to all variables.
+    compute_log_likelihood: bool
+        Whether to compute element-wise log-likelihood values for observed variables.
 
     Returns
     -------
@@ -646,10 +727,14 @@ def _make_functions(
         names of the variables, the second list contains the slices that
         correspond to the variables in the flat array, and the third list
         contains the shapes of the variables.
+    log_likelihood_data: tuple of lists
+        Tuple containing log-likelihood variable information. The first list contains
+        the names of the log-likelihood variables, the second list contains their shapes.
""" import pytensor import pytensor.tensor as pt from pymc.pytensorf import compile as compile_pymc + import pymc as pm shapes = _compute_shapes(model) @@ -748,6 +833,39 @@ def _make_functions( num_expanded = count + log_likelihood_names = [] + log_likelihood_shapes = [] + log_likelihood_vars = [] + + if compute_log_likelihood: + for obs in model.observed_RVs: + obs_name = obs.name + log_lik_name = f"log_likelihood_{obs_name}" + + # Get element-wise log-likelihood + log_lik_expr = pm.logp(obs, obs) + log_lik_expr.name = log_lik_name + + # Get the shape of the observed variable + # For observed variables, we need to get the shape from the observed data + obs_shape = ( + tuple(obs.shape.eval()) if hasattr(obs.shape, "eval") else obs.shape + ) + + log_likelihood_vars.append(log_lik_expr) + log_likelihood_names.append(log_lik_name) + log_likelihood_shapes.append(obs_shape) + + all_names.append(log_lik_name) + all_shapes.append(obs_shape) + length = prod(obs_shape) + all_slices.append(slice(count, count + length)) + count += length + + num_expanded = count + + remaining_rvs.extend(log_likelihood_vars) + if join_expanded: allvars = [pt.concatenate([joined, *[var.ravel() for var in remaining_rvs]])] else: @@ -767,19 +885,74 @@ def _make_functions( expand_fn_pt, initial_point_fn, (all_names, all_slices, all_shapes), + (log_likelihood_names, log_likelihood_shapes), ) -def make_extraction_fn(inner, shared_data, shared_vars, record_dtype): +def make_extraction_fn(inner, shared_data, shared_vars, record_dtype, pytensor_fn=None): import numba from numba import literal_unroll from numba.cpython.unsafe.tuple import alloca_once, tuple_setitem if not shared_vars: - - @numba.njit(inline="always") - def extract_shared(x, user_data_): - return inner(x) + # Check if we have a PyTensor function to get the count of RNG shared variables + if pytensor_fn is not None: + all_shared = pytensor_fn.get_shared() + rng_count = len([var for var in all_shared if var.name is None]) + else: + rng_count = 0 + + if rng_count > 0: + # Have RNG shared variables but no named ones + # Create a wrapper that provides dummy values for RNG variables + def make_extract_with_rng_dummy(inner_fn, num_rng): + if num_rng == 1: + + @numba.njit(inline="always") + def extract_shared(x, user_data_): + rng = 0 # Dummy value - RNG state is managed elsewhere + return inner_fn(x, rng) + + return extract_shared + elif num_rng == 2: + + @numba.njit(inline="always") + def extract_shared(x, user_data_): + rng1, rng2 = 0, 0 # Dummy values + return inner_fn(x, rng1, rng2) + + return extract_shared + elif num_rng == 3: + + @numba.njit(inline="always") + def extract_shared(x, user_data_): + rng1, rng2, rng3 = 0, 0, 0 # Dummy values + return inner_fn(x, rng1, rng2, rng3) + + return extract_shared + elif num_rng == 4: + + @numba.njit(inline="always") + def extract_shared(x, user_data_): + rng1, rng2, rng3, rng4 = 0, 0, 0, 0 # Dummy values + return inner_fn(x, rng1, rng2, rng3, rng4) + + return extract_shared + else: + # Fallback for other counts - try 2 RNGs as default + @numba.njit(inline="always") + def extract_shared(x, user_data_): + rng1, rng2 = 0, 0 # Dummy values + return inner_fn(x, rng1, rng2) + + return extract_shared + + extract_shared = make_extract_with_rng_dummy(inner, rng_count) + else: + # No shared variables expected at all + @numba.njit(inline="always") + def extract_shared(x, user_data_): + return inner(x) return extract_shared @@ -867,10 +1040,14 @@ def extract_shared(x, user_data_): return extract_shared -def _make_c_logp_func(n_dim, 
+def _make_c_logp_func(
+    n_dim, logp_fn, user_data, shared_logp, shared_data, logp_fn_pt=None
+):
     import numba
 
-    extract = make_extraction_fn(logp_fn, shared_data, shared_logp, user_data.dtype)
+    extract = make_extraction_fn(
+        logp_fn, shared_data, shared_logp, user_data.dtype, logp_fn_pt
+    )
 
     c_sig = numba.types.int64(
         numba.types.uint64,
@@ -907,11 +1084,13 @@ def logp_numba(dim, x_, out_, logp_, user_data_):
 
 
 def _make_c_expand_func(
-    n_dim, n_expanded, expand_fn, user_data, shared_vars, shared_data
+    n_dim, n_expanded, expand_fn, user_data, shared_vars, shared_data, expand_fn_pt=None
 ):
     import numba
 
-    extract = make_extraction_fn(expand_fn, shared_data, shared_vars, user_data.dtype)
+    extract = make_extraction_fn(
+        expand_fn, shared_data, shared_vars, user_data.dtype, expand_fn_pt
+    )
 
     c_sig = numba.types.int64(
         numba.types.uint64,
@@ -931,7 +1110,13 @@ def expand_numba(dim, expanded, x_, out_, user_data_):
             x = numba.carray(x_, (n_dim,))
             out = numba.carray(out_, (n_expanded,))
 
-            (values,) = extract(x, user_data_)
+            result = extract(x, user_data_)
+            if isinstance(result, tuple):
+                # Take only the first element (the actual expanded variables)
+                # The remaining elements are RNG state which we don't need for output
+                values = result[0]
+            else:
+                values = result
             out[...] = values
 
         except Exception:  # noqa: BLE001
diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py
index 618feea..f7bd032 100644
--- a/python/nutpie/compiled_pyfunc.py
+++ b/python/nutpie/compiled_pyfunc.py
@@ -1,5 +1,5 @@
 import dataclasses
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from typing import Any, Callable
 
@@ -22,6 +22,8 @@ class PyFuncModel(CompiledModel):
     _coords: dict[str, Any]
     _raw_logp_fn: Callable | None
     _transform_adapt_args: dict | None = None
+    log_likelihood_names: list[str] = field(default_factory=list)
+    log_likelihood_shapes: list[tuple[int, ...]] = field(default_factory=list)
 
     @property
     def shapes(self) -> dict[str, tuple[int, ...]]:
@@ -104,6 +106,8 @@ def from_pyfunc(
     make_initial_point_fn: Callable[[SeedType], np.ndarray] | None = None,
     make_transform_adapter=None,
     raw_logp_fn=None,
+    log_likelihood_names: list[str] | None = None,
+    log_likelihood_shapes: list[tuple[int, ...]] | None = None,
 ):
     variables = []
     for name, shape, dtype in zip(
@@ -124,6 +128,10 @@ def from_pyfunc(
         dims = {}
     if shared_data is None:
         shared_data = {}
+    if log_likelihood_names is None:
+        log_likelihood_names = []
+    if log_likelihood_shapes is None:
+        log_likelihood_shapes = []
 
     return PyFuncModel(
         _n_dim=ndim,
@@ -135,4 +143,6 @@ def from_pyfunc(
         _variables=variables,
         _shared_data=shared_data,
         _raw_logp_fn=raw_logp_fn,
+        log_likelihood_names=log_likelihood_names,
+        log_likelihood_shapes=log_likelihood_shapes,
     )
diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py
index 0655173..7356f91 100644
--- a/python/nutpie/sample.py
+++ b/python/nutpie/sample.py
@@ -54,14 +54,19 @@ def benchmark_logp(self, point, num_evals, cores):
     return pd.concat(times)
 
 
-def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
+def _trace_to_arviz(traces, n_tune, shapes, log_likelihood_names=None, **kwargs):
     n_chains = len(traces)
 
     data_dict = {}
     data_dict_tune = {}
+    log_likelihood_dict = {}
+    log_likelihood_dict_tune = {}
 
     stats_dict = {}
     stats_dict_tune = {}
+    if log_likelihood_names is None:
+        log_likelihood_names = []
+
     draw_batches = []
     stats_batches = []
     for draws, stats in traces:
@@ -85,8 +90,13 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
                 (len(chunk),) + shapes[name]
             )
 
-            data_dict[name] = data[:, n_tune:]
-            data_dict_tune[name] = data[:, :n_tune]
+            # Separate log-likelihood variables from main data
+            if name in log_likelihood_names:
+                log_likelihood_dict[name] = data[:, n_tune:]
+                log_likelihood_dict_tune[name] = data[:, :n_tune]
+            else:
+                data_dict[name] = data[:, n_tune:]
+                data_dict_tune[name] = data[:, :n_tune]
 
     for name, col in zip(table_stats.column_names, table_stats.columns):
         if name in ["chain", "draw", "divergence_message"]:
@@ -116,12 +126,19 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
         stats_dict[name] = data[:, n_tune:]
         stats_dict_tune[name] = data[:, :n_tune]
 
+    # Pass log_likelihood data to arviz.from_dict if present
+    kwargs_with_log_likelihood = kwargs.copy()
+    if log_likelihood_dict:
+        kwargs_with_log_likelihood["log_likelihood"] = log_likelihood_dict
+    if log_likelihood_dict_tune:
+        kwargs_with_log_likelihood["warmup_log_likelihood"] = log_likelihood_dict_tune
+
     return arviz.from_dict(
         data_dict,
         sample_stats=stats_dict,
         warmup_posterior=data_dict_tune,
         warmup_sample_stats=stats_dict_tune,
-        **kwargs,
+        **kwargs_with_log_likelihood,
     )
 
 
@@ -478,10 +495,12 @@ def _extract(self, results):
         if self._return_raw_trace:
             return results
         else:
+            log_likelihood_names = getattr(self._compiled_model, "log_likelihood_names", [])
             return _trace_to_arviz(
                 results,
                 self._settings.num_tune,
                 self._compiled_model.shapes,
+                log_likelihood_names=log_likelihood_names,
                 dims=dims,
                 coords={
                     name: pd.Index(vals)
diff --git a/tests/test_log_likelihood_integration.py b/tests/test_log_likelihood_integration.py
new file mode 100644
index 0000000..ee83812
--- /dev/null
+++ b/tests/test_log_likelihood_integration.py
@@ -0,0 +1,212 @@
+"""
+Integration tests for log-likelihood calculation functionality.
+""" + +from importlib.util import find_spec +import pytest +import numpy as np + +if find_spec("pymc") is None: + pytest.skip("Skip pymc tests", allow_module_level=True) + +import pymc as pm +import nutpie + + +@pytest.mark.pymc +def test_log_likelihood_compilation_numba_disabled(): + """Test that log-likelihood is properly disabled for numba backend with warning.""" + np.random.seed(42) + observed_data = np.random.normal(0, 1, 10) + + with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + pm.Normal("y", mu=mu, sigma=1, observed=observed_data) + + with pytest.warns( + UserWarning, + match="compute_log_likelihood=True is not supported with numba backend", + ): + compiled_model = nutpie.compile_pymc_model( + model, backend="numba", compute_log_likelihood=True + ) + + # Should be disabled despite being requested + assert hasattr(compiled_model, "log_likelihood_names") + assert hasattr(compiled_model, "log_likelihood_shapes") + assert compiled_model.log_likelihood_names == [] + assert compiled_model.log_likelihood_shapes == [] + + +@pytest.mark.pymc +def test_log_likelihood_compilation_disabled(): + """Test that compilation works with compute_log_likelihood=False.""" + np.random.seed(42) + observed_data = np.random.normal(0, 1, 10) + + with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=1) + pm.Normal("y", mu=mu, sigma=1, observed=observed_data) + + compiled_model = nutpie.compile_pymc_model( + model, backend="numba", compute_log_likelihood=False + ) + assert compiled_model.log_likelihood_names == [] + assert compiled_model.log_likelihood_shapes == [] + + +@pytest.mark.pymc +def test_log_likelihood_basic_sampling_jax(): + """Test basic sampling with log-likelihood calculation using JAX backend.""" + np.random.seed(42) + observed_data = np.random.normal(2.0, 1.0, 20) + + with pm.Model() as model: + mu = pm.Normal("mu", mu=0, sigma=5) + sigma = pm.HalfNormal("sigma", sigma=2) + pm.Normal("y", mu=mu, sigma=sigma, observed=observed_data) + + compiled_model = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax", compute_log_likelihood=True + ) + trace = nutpie.sample(compiled_model, draws=10, tune=10, chains=1, cores=1) + + assert hasattr(trace, "log_likelihood"), ( + "log_likelihood group missing from InferenceData" + ) + assert "log_likelihood_y" in trace.log_likelihood.data_vars, ( + "log_likelihood_y not in log_likelihood group" + ) + + log_lik = trace.log_likelihood["log_likelihood_y"] + assert log_lik.shape == (1, 10, 20), ( + f"Expected shape (1, 10, 20), got {log_lik.shape}" + ) + assert np.all(np.isfinite(log_lik.values)), ( + "Log-likelihood values contain non-finite values" + ) + assert not np.all(log_lik.values == 0), "Log-likelihood values are all zero" + assert np.all(log_lik.values <= 0), "Log-likelihood values should be non-positive" + + +@pytest.mark.pymc +def test_log_likelihood_multiple_observed_jax(): + """Test log-likelihood calculation with multiple observed variables using JAX backend.""" + np.random.seed(42) + n_obs1, n_obs2 = 15, 10 + observed_data1 = np.random.normal(1.0, 0.5, n_obs1) + observed_data2 = np.random.normal(-1.0, 1.0, n_obs2) + + with pm.Model() as model: + mu1 = pm.Normal("mu1", mu=0, sigma=2) + mu2 = pm.Normal("mu2", mu=0, sigma=2) + pm.Normal("y1", mu=mu1, sigma=0.5, observed=observed_data1) + pm.Normal("y2", mu=mu2, sigma=1.0, observed=observed_data2) + + compiled_model = nutpie.compile_pymc_model( + model, backend="jax", gradient_backend="jax", compute_log_likelihood=True + ) + assert 
+    assert "log_likelihood_y1" in compiled_model.log_likelihood_names
+    assert "log_likelihood_y2" in compiled_model.log_likelihood_names
+
+    y1_idx = compiled_model.log_likelihood_names.index("log_likelihood_y1")
+    y2_idx = compiled_model.log_likelihood_names.index("log_likelihood_y2")
+
+    assert compiled_model.log_likelihood_shapes[y1_idx] == (n_obs1,)
+    assert compiled_model.log_likelihood_shapes[y2_idx] == (n_obs2,)
+
+    trace = nutpie.sample(compiled_model, draws=8, tune=8, chains=1, cores=1)
+    assert "log_likelihood_y1" in trace.log_likelihood.data_vars
+    assert "log_likelihood_y2" in trace.log_likelihood.data_vars
+
+    log_lik1 = trace.log_likelihood["log_likelihood_y1"]
+    log_lik2 = trace.log_likelihood["log_likelihood_y2"]
+
+    assert log_lik1.shape == (1, 8, n_obs1)
+    assert log_lik2.shape == (1, 8, n_obs2)
+
+    assert np.all(np.isfinite(log_lik1.values))
+    assert np.all(np.isfinite(log_lik2.values))
+    assert np.all(log_lik1.values <= 0)
+    assert np.all(log_lik2.values <= 0)
+
+
+@pytest.mark.pymc
+def test_log_likelihood_scalar_observed_jax():
+    """Test log-likelihood calculation with scalar observed variable using JAX backend."""
+    np.random.seed(42)
+    observed_value = 3.5
+
+    with pm.Model() as model:
+        mu = pm.Normal("mu", mu=0, sigma=2)
+        pm.Normal("y", mu=mu, sigma=1, observed=observed_value)
+
+    compiled_model = nutpie.compile_pymc_model(
+        model, backend="jax", gradient_backend="jax", compute_log_likelihood=True
+    )
+    assert "log_likelihood_y" in compiled_model.log_likelihood_names
+    y_idx = compiled_model.log_likelihood_names.index("log_likelihood_y")
+    assert compiled_model.log_likelihood_shapes[y_idx] == ()
+    trace = nutpie.sample(compiled_model, draws=6, tune=6, chains=1, cores=1)
+
+    assert "log_likelihood_y" in trace.log_likelihood.data_vars
+    log_lik = trace.log_likelihood["log_likelihood_y"]
+
+    assert log_lik.shape == (1, 6), f"Expected shape (1, 6), got {log_lik.shape}"
+    assert np.all(np.isfinite(log_lik.values))
+    assert np.all(log_lik.values <= 0)
+
+
+@pytest.mark.pymc
+def test_log_likelihood_backward_compatibility():
+    """Test that existing code without compute_log_likelihood still works."""
+    np.random.seed(42)
+    observed_data = np.random.normal(0, 1, 5)
+
+    with pm.Model() as model:
+        mu = pm.Normal("mu", mu=0, sigma=1)
+        pm.Normal("y", mu=mu, sigma=1, observed=observed_data)
+
+    compiled_model = nutpie.compile_pymc_model(model, backend="numba")
+    assert compiled_model.log_likelihood_names == []
+    assert compiled_model.log_likelihood_shapes == []
+
+    trace = nutpie.sample(compiled_model, draws=5, tune=5, chains=1, cores=1)
+    if hasattr(trace, "log_likelihood"):
+        assert len(trace.log_likelihood.data_vars) == 0
+
+    assert "mu" in trace.posterior.data_vars
+
+
+@pytest.mark.pymc
+def test_log_likelihood_numba_sampling_without_log_lik():
+    """Test that numba backend works correctly when log-likelihood is disabled."""
+    np.random.seed(42)
+    observed_data = np.random.normal(2.0, 1.0, 20)
+
+    with pm.Model() as model:
+        mu = pm.Normal("mu", mu=0, sigma=5)
+        sigma = pm.HalfNormal("sigma", sigma=2)
+        pm.Normal("y", mu=mu, sigma=sigma, observed=observed_data)
+
+    # Test with explicit compute_log_likelihood=True (should be disabled with warning)
+    with pytest.warns(
+        UserWarning,
+        match="compute_log_likelihood=True is not supported with numba backend",
+    ):
+        compiled_model = nutpie.compile_pymc_model(
+            model, backend="numba", compute_log_likelihood=True
+        )
+
+    # Verify log-likelihood is disabled
+    assert compiled_model.log_likelihood_names == []
+    assert compiled_model.log_likelihood_shapes == []
+
+    # Sampling should still work correctly
+    trace = nutpie.sample(compiled_model, draws=10, tune=10, chains=1, cores=1)
+    assert (
+        not hasattr(trace, "log_likelihood") or len(trace.log_likelihood.data_vars) == 0
+    )
+    assert "mu" in trace.posterior.data_vars
+    assert "sigma" in trace.posterior.data_vars