diff --git a/lightgbmlss/distributions/BernsteinFlow.py b/lightgbmlss/distributions/BernsteinFlow.py
new file mode 100644
index 0000000..489b936
--- /dev/null
+++ b/lightgbmlss/distributions/BernsteinFlow.py
@@ -0,0 +1,299 @@
+import math
+
+import torch
+import torch.nn as nn
+from torch.distributions import identity_transform, SigmoidTransform, SoftplusTransform
+from pyro.distributions import Uniform
+from pyro.distributions.torch_transform import TransformModule
+from .flow_utils import NormalizingFlowClass
+from ..utils import identity_fn
+
+
+class BernsteinQuantileTransform(TransformModule):
+    """
+    Bernstein polynomial quantile transform.
+
+    This transform uses Bernstein polynomials to parameterize a quantile function
+    (inverse CDF). The Bernstein polynomial of degree n is defined as
+
+        Q(u) = sum_{k=0}^{n} beta_k * B_{n,k}(u),
+
+    where B_{n,k}(u) = C(n,k) * u^k * (1-u)^{n-k} are the Bernstein basis functions,
+    C(n,k) is the binomial coefficient, and the beta_k are learnable parameters that
+    represent quantile values at specific points.
+
+    The monotonicity constraint is enforced by requiring beta_k <= beta_{k+1}.
+    """
+
+    domain = torch.distributions.constraints.unit_interval
+    codomain = torch.distributions.constraints.real
+    bijective = True
+    sign = +1
+
+    def __init__(self, degree, support_bounds=(-5.0, 5.0)):
+        # TransformModule combines Transform with nn.Module, so nn.Parameter and
+        # register_buffer are available below.
+        super().__init__()
+        self.degree = degree
+        self.support_bounds = support_bounds
+
+        # Initialize parameters with sorted values to ensure monotonicity
+        init_values = torch.linspace(support_bounds[0], support_bounds[1], degree + 1)
+        self.raw_betas = nn.Parameter(init_values)
+
+        # Precompute binomial coefficients
+        self.register_buffer("binomial_coeffs", self._compute_binomial_coefficients(degree))
+
+    def _compute_binomial_coefficients(self, n):
+        """Compute the binomial coefficients C(n,k) for k = 0, ..., n."""
+        # math.comb is exact and part of the standard library (Python >= 3.8),
+        # which avoids an optional scipy dependency and a manual fallback.
+        return torch.tensor([float(math.comb(n, k)) for k in range(n + 1)])
+
+    @property
+    def betas(self):
+        """Monotone coefficients: cumulative sum of the first raw value and
+        softplus-transformed increments, so beta_k <= beta_{k+1} by construction.
+        Indexing over the last dimension keeps this valid for batched coefficients."""
+        deltas = torch.nn.functional.softplus(self.raw_betas[..., 1:] - self.raw_betas[..., :-1])
+        return torch.cumsum(torch.cat([self.raw_betas[..., :1], deltas], dim=-1), dim=-1)
+
+    def _bernstein_basis(self, u, k):
+        """Compute the k-th Bernstein basis polynomial of degree n at u."""
+        n = self.degree
+        # B_{n,k}(u) = C(n,k) * u^k * (1-u)^{n-k}; the k == 0 and k == n cases are
+        # handled separately to avoid evaluating 0**0 at the boundaries.
+        if k == 0:
+            return self.binomial_coeffs[k] * torch.pow(1 - u, n - k)
+        elif k == n:
+            return self.binomial_coeffs[k] * torch.pow(u, k)
+        else:
+            return self.binomial_coeffs[k] * torch.pow(u, k) * torch.pow(1 - u, n - k)
+
+    def _bernstein_polynomial(self, u):
+        """Evaluate the Bernstein polynomial quantile function at u."""
+        u = torch.clamp(u, 1e-7, 1 - 1e-7)  # Avoid boundary issues
+
+        result = torch.zeros_like(u)
+        betas = self.betas
+
+        for k in range(self.degree + 1):
+            basis = self._bernstein_basis(u, k)
+            result = result + betas[..., k] * basis
+
+        return result
+
+    def _bernstein_derivative(self, u):
+        """Compute the derivative of the Bernstein polynomial (needed for the Jacobian).
+
+        Uses the identity d/du B_{n,k}(u) = n * [B_{n-1,k-1}(u) - B_{n-1,k}(u)],
+        which collapses to Q'(u) = n * sum_{k=0}^{n-1} (beta_{k+1} - beta_k) * B_{n-1,k}(u).
+        """
+        u = torch.clamp(u, 1e-7, 1 - 1e-7)
+
+        if self.degree == 0:
+            return torch.zeros_like(u)
+
+        n = self.degree
+        betas = self.betas
+        result = torch.zeros_like(u)
+
+        for k in range(n):
+            basis = self._bernstein_basis_degree(u, k, n - 1)
+            result = result + n * (betas[..., k + 1] - betas[..., k]) * basis
+
+        # Clamp to keep the derivative strictly positive for the log-Jacobian
+        return torch.clamp(result, min=1e-7)
+
+    def _bernstein_basis_degree(self, u, k, degree):
+        """Compute the k-th Bernstein basis polynomial of the given degree at u."""
+        if degree < k or k < 0:
+            return torch.zeros_like(u)
+
+        binomial_coeff = float(math.comb(degree, k))
+
+        if k == 0:
+            return binomial_coeff * torch.pow(1 - u, degree - k)
+        elif k == degree:
+            return binomial_coeff * torch.pow(u, k)
+        else:
+            return binomial_coeff * torch.pow(u, k) * torch.pow(1 - u, degree - k)
+
+    def _call(self, x):
+        """Forward transform: map u in [0, 1] to the target support via Q(u).
+
+        Implemented as _call (not __call__) so that Transform's dispatch and
+        caching machinery keeps working.
+        """
+        return self._bernstein_polynomial(x)
+
+    def _inverse(self, y):
+        """Inverse transform: find u such that Q(u) = y.
+
+        No analytical inverse exists, so the inversion is numerical; bisection
+        is used for robustness since Q is monotone.
+        """
+        return self._numerical_inverse(y)
+
+    def _numerical_inverse(self, y, max_iter=50, tol=1e-6):
+        """Numerical inverse via bisection."""
+        # Clamp y to the support bounds
+        y = torch.clamp(y, self.support_bounds[0] + 1e-6, self.support_bounds[1] - 1e-6)
+
+        batch_shape = y.shape
+        y_flat = y.flatten()
+
+        # Initialize bracketing bounds within the open unit interval
+        lower = torch.zeros_like(y_flat) + 1e-7
+        upper = torch.ones_like(y_flat) - 1e-7
+
+        for _ in range(max_iter):
+            mid = (lower + upper) / 2
+            f_mid = self._bernstein_polynomial(mid.reshape(batch_shape)).flatten()
+
+            # Shrink the bracket on the side that still contains the root
+            mask = f_mid < y_flat
+            lower = torch.where(mask, mid, lower)
+            upper = torch.where(~mask, mid, upper)
+
+            # Check convergence
+            if torch.max(upper - lower) < tol:
+                break
+
+        result = (lower + upper) / 2
+        return result.reshape(batch_shape)
+
+    def log_abs_det_jacobian(self, x, y):
+        """Compute the log absolute determinant of the Jacobian, i.e. log Q'(x)."""
+        derivative = self._bernstein_derivative(x)
+        return torch.log(derivative)
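+
+
+# Illustrative sketch, not part of the library API: a quick self-check of the
+# transform in isolation. The helper name below is hypothetical and exists only
+# for demonstration; the bisection inverse should recover u from y = Q(u) up to
+# its tolerance of 1e-6.
+def _bernstein_round_trip_check(degree: int = 6) -> torch.Tensor:
+    transform = BernsteinQuantileTransform(degree=degree)
+    u = torch.linspace(0.05, 0.95, 9)
+    y = transform(u)                # forward map: quantile function Q(u)
+    u_back = transform._inverse(y)  # numerical inverse via bisection
+    return (u - u_back).abs().max()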
+
+
+class BernsteinFlow(NormalizingFlowClass):
+    """
+    Bernstein Flow class.
+
+    The Bernstein flow is a normalizing flow based on Bernstein polynomial quantile
+    functions. It uses Bernstein polynomials as basis functions to construct flexible,
+    monotonic transformations that naturally preserve the ordering required for valid
+    probability distributions.
+
+    Key features:
+    - Shape-constrained modeling with natural monotonicity preservation
+    - Interpretable parameters (each coefficient represents a quantile value)
+    - Computational efficiency: evaluation reduces to a simple polynomial
+    - Flexibility to approximate any monotonic quantile function
+
+    Parameters
+    ----------
+    target_support : str
+        The target support. Options are:
+        - "real": [-inf, inf]
+        - "positive": [0, inf]
+        - "positive_integer": [0, 1, 2, 3, ...]
+        - "unit_interval": [0, 1]
+    degree : int
+        The degree of the Bernstein polynomial. A higher degree provides more
+        flexibility but requires more parameters. Typical values: 5-15.
+    bound : float
+        The support bound for the distribution. The quantile function maps
+        [0, 1] to approximately [-bound, bound].
+    stabilization : str
+        Stabilization method for the Gradient and Hessian. Options are "None", "MAD" or "L2".
+    loss_fn : str
+        Loss function. Options are "nll" (negative log-likelihood) or "crps"
+        (continuous ranked probability score). Note that if "crps" is used, the Hessian
+        is set to 1, as the current CRPS version is not twice differentiable.
+    """
+
+    def __init__(self,
+                 target_support: str = "real",
+                 degree: int = 8,
+                 bound: float = 5.0,
+                 stabilization: str = "None",
+                 loss_fn: str = "nll"
+                 ):
+
+        # Input validation
+        if not isinstance(target_support, str):
+            raise ValueError("target_support must be a string.")
+
+        # Specify Target Transform
+        transforms = {
+            "real": (identity_transform, False),
+            "positive": (SoftplusTransform(), False),
+            "positive_integer": (SoftplusTransform(), True),
+            "unit_interval": (SigmoidTransform(), False)
+        }
+
+        if target_support in transforms:
+            target_transform, discrete = transforms[target_support]
+        else:
+            raise ValueError(
+                "Invalid target_support. Options are 'real', 'positive', 'positive_integer', or 'unit_interval'.")
+
+        # Check if degree is valid
+        if not isinstance(degree, int):
+            raise ValueError("degree must be an integer.")
+        if degree <= 0:
+            raise ValueError("degree must be a positive integer > 0.")
+        if degree > 20:
+            raise ValueError("degree should be <= 20 for numerical stability.")
+
+        # Check if bound is valid
+        if not isinstance(bound, float):
+            bound = float(bound)
+        if bound <= 0:
+            raise ValueError("bound must be positive.")
+
+        # Number of parameters (degree + 1 coefficients)
+        n_params = degree + 1
+
+        # Check if stabilization method is valid
+        if not isinstance(stabilization, str):
+            raise ValueError("stabilization must be a string.")
+        if stabilization not in ["None", "MAD", "L2"]:
+            raise ValueError("Invalid stabilization method. Options are 'None', 'MAD' or 'L2'.")
+
+        # Check if loss function is valid
+        if not isinstance(loss_fn, str):
+            raise ValueError("loss_fn must be a string.")
+        if loss_fn not in ["nll", "crps"]:
+            raise ValueError("Invalid loss_fn. Options are 'nll' or 'crps'.")
+
+        # Specify parameter dictionary
+        param_dict = {f"beta_{i}": identity_fn for i in range(n_params)}
+        torch.distributions.Distribution.set_default_validate_args(False)
+
+        # Support bounds for the transform
+        support_bounds = (-bound, bound)
+
+        # Specify Normalizing Flow Class. The base distribution is Uniform(0, 1),
+        # matching the unit-interval domain of the quantile transform.
+        super().__init__(base_dist=Uniform,
+                         flow_transform=BernsteinQuantileTransform,
+                         degree=degree,
+                         support_bounds=support_bounds,
+                         n_dist_param=n_params,
+                         param_dict=param_dict,
+                         distribution_arg_names=list(param_dict.keys()),
+                         target_transform=target_transform,
+                         discrete=discrete,
+                         univariate=True,
+                         stabilization=stabilization,
+                         loss_fn=loss_fn
+                         )
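+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, mirroring how the flow is exercised in the tests
+    # below; it assumes the package's LightGBMLSS model wrapper and its
+    # calculate_start_values helper. Illustration only, not a fitted model.
+    import numpy as np
+    from lightgbmlss.model import LightGBMLSS
+
+    np.random.seed(42)
+    y = np.random.randn(500)
+
+    model = LightGBMLSS(BernsteinFlow(degree=8, bound=5.0))
+    loss, start_values = model.dist.calculate_start_values(y, max_iter=50)
+    print("Initial NLL:", loss)
+    print("Start values:", start_values)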
""" - # Split parameters into list - if self.order == "quadratic": - params_list = torch.split( - params, [self.count_bins, self.count_bins, self.count_bins - 1], - dim=1) - elif self.order == "linear": - params_list = torch.split( - params, [self.count_bins, self.count_bins, self.count_bins - 1, self.count_bins], - dim=1) + # Split parameters based on flow type + if hasattr(self, 'degree') and self.degree is not None: + # Bernstein flow: degree+1 parameters (beta coefficients) + params_list = [params] # All parameters as single tensor for Bernstein + else: + # Spline flow: split according to order + if self.order == "quadratic": + params_list = torch.split( + params, [self.count_bins, self.count_bins, self.count_bins - 1], + dim=1) + elif self.order == "linear": + params_list = torch.split( + params, [self.count_bins, self.count_bins, self.count_bins - 1, self.count_bins], + dim=1) # Replace parameters - for param, new_value in zip(flow_dist.transforms[0].parameters(), params_list): - param.data = new_value - - # Get parameters (including require_grad=True) - params_list = list(flow_dist.transforms[0].parameters()) + if hasattr(self, 'degree') and self.degree is not None: + # For Bernstein flow, replace the raw_betas parameter + flow_dist.transforms[0].raw_betas.data = params.squeeze() + params_list = [flow_dist.transforms[0].raw_betas] + else: + # For Spline flow, replace multiple parameters + for param, new_value in zip(flow_dist.transforms[0].parameters(), params_list): + param.data = new_value + params_list = list(flow_dist.transforms[0].parameters()) return params_list, flow_dist diff --git a/tests/test_distributions/test_bernstein_flow.py b/tests/test_distributions/test_bernstein_flow.py new file mode 100644 index 0000000..219acdf --- /dev/null +++ b/tests/test_distributions/test_bernstein_flow.py @@ -0,0 +1,181 @@ +from ..utils import BaseTestClass +import pytest + + +class TestBernsteinFlowClass(BaseTestClass): + """Test class for BernsteinFlow distribution""" + + def test_init_bernstein(self): + """Test initialization parameters specific to BernsteinFlow""" + from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow + + # Test valid initialization + dist = BernsteinFlow() + assert dist.degree == 8 + assert dist.support_bounds == (-5.0, 5.0) + assert dist.n_dist_param == 9 # degree + 1 + + # Test target_support validation + with pytest.raises(ValueError, match="target_support must be a string."): + BernsteinFlow(target_support=1) + with pytest.raises(ValueError, match="Invalid target_support."): + BernsteinFlow(target_support="invalid_target_support") + + # Test degree validation + with pytest.raises(ValueError, match="degree must be an integer."): + BernsteinFlow(degree=1.0) + with pytest.raises(ValueError, match="degree must be a positive integer > 0."): + BernsteinFlow(degree=0) + with pytest.raises(ValueError, match="degree should be <= 20 for numerical stability."): + BernsteinFlow(degree=25) + + # Test bound validation + with pytest.raises(ValueError, match="bound must be positive."): + BernsteinFlow(bound=-1.0) + + # Test stabilization validation + with pytest.raises(ValueError, match="stabilization must be a string."): + BernsteinFlow(stabilization=1) + with pytest.raises(ValueError, match="Invalid stabilization method."): + BernsteinFlow(stabilization="invalid_stabilization") + + # Test loss_fn validation + with pytest.raises(ValueError, match="loss_fn must be a string."): + BernsteinFlow(loss_fn=1) + with pytest.raises(ValueError, match="Invalid loss_fn."): + 
+
+    def test_parameter_dictionary(self):
+        """Test that the parameter dictionary is correctly set up."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow
+
+        dist = BernsteinFlow(degree=5)
+        assert isinstance(dist.param_dict, dict)
+        assert len(dist.param_dict) == 6  # degree + 1
+        assert all(f"beta_{i}" in dist.param_dict for i in range(6))
+        assert all(callable(func) for func in dist.param_dict.values())
+
+    def test_different_degrees(self):
+        """Test BernsteinFlow with different polynomial degrees."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow
+
+        degrees = [3, 5, 10, 15]
+        for degree in degrees:
+            dist = BernsteinFlow(degree=degree)
+            assert dist.degree == degree
+            assert dist.n_dist_param == degree + 1
+            assert len(dist.param_dict) == degree + 1
+
+    def test_target_supports(self):
+        """Test the different target supports."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow
+
+        supports = ["real", "positive", "positive_integer", "unit_interval"]
+        for support in supports:
+            dist = BernsteinFlow(target_support=support)
+            assert dist.target_transform is not None
+
+    def test_transform_properties(self):
+        """Test that the BernsteinQuantileTransform is properly initialized."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinQuantileTransform
+        import torch
+
+        # Create the transform directly
+        transform = BernsteinQuantileTransform(degree=5, support_bounds=(-3.0, 3.0))
+        assert transform.degree == 5
+        assert transform.support_bounds == (-3.0, 3.0)
+        assert transform.raw_betas.shape[0] == 6  # degree + 1
+
+        # Test that the binomial coefficients are computed
+        assert transform.binomial_coeffs is not None
+        assert len(transform.binomial_coeffs) == 6
+
+        # Test the monotonicity property
+        betas = transform.betas
+        assert torch.all(betas[1:] >= betas[:-1])  # Should be non-decreasing
+
+    def test_bernstein_basis_functions(self):
+        """Test the Bernstein basis function computation."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinQuantileTransform
+        import torch
+
+        transform = BernsteinQuantileTransform(degree=3)
+        u = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])
+
+        # The basis functions form a partition of unity: they sum to 1 at every u
+        total = torch.zeros_like(u)
+        for k in range(4):  # degree + 1
+            basis_k = transform._bernstein_basis(u, k)
+            total += basis_k
+
+        # Should sum to 1 (with some numerical tolerance)
+        assert torch.allclose(total, torch.ones_like(u), atol=1e-6)
+
+    def test_transform_forward_inverse(self):
+        """Test the forward and inverse transform."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinQuantileTransform
+        import torch
+
+        transform = BernsteinQuantileTransform(degree=5, support_bounds=(-2.0, 2.0))
+
+        # Test points in the unit interval
+        u = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9])
+
+        # Forward transform
+        y = transform(u)
+        assert y.shape == u.shape
+        assert torch.all(y >= -2.0) and torch.all(y <= 2.0)  # Should lie within the support bounds
+
+        # Inverse transform (approximate due to the numerical method)
+        u_reconstructed = transform._inverse(y)
+        assert u_reconstructed.shape == u.shape
+        assert torch.allclose(u, u_reconstructed, atol=1e-3)  # Allow some numerical error
+
+    def test_jacobian_computation(self):
+        """Test the log absolute determinant of the Jacobian."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinQuantileTransform
+        import torch
+
+        transform = BernsteinQuantileTransform(degree=5)
+
+        u = torch.tensor([0.1, 0.3, 0.5, 0.7, 0.9])
+        y = transform(u)
+
+        # Should be able to compute the Jacobian
+        log_det_jac = transform.log_abs_det_jacobian(u, y)
+        assert log_det_jac.shape == u.shape
+        assert torch.all(torch.isfinite(log_det_jac))  # No NaN or inf values
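+
+    def test_jacobian_matches_autograd(self):
+        """Added sketch: the analytic log-Jacobian should agree with autograd.
+
+        Since Q acts element-wise, d(sum Q(u))/du_i equals Q'(u_i), so one
+        backward pass through the forward map yields a reference derivative.
+        """
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinQuantileTransform
+        import torch
+
+        transform = BernsteinQuantileTransform(degree=5)
+        u = torch.tensor([0.2, 0.5, 0.8], requires_grad=True)
+        y = transform(u)
+
+        (grad,) = torch.autograd.grad(y.sum(), u)
+        log_det_jac = transform.log_abs_det_jacobian(u, y)
+        assert torch.allclose(log_det_jac, torch.log(grad), atol=1e-4)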
+
+    def test_defaults(self):
+        """Test that the default values match the expected behavior."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow
+
+        dist = BernsteinFlow()
+        assert isinstance(dist.univariate, bool)
+        assert dist.univariate is True
+        assert isinstance(dist.discrete, bool)
+        assert dist.discrete is False
+
+    def test_distribution_class_integration(self):
+        """Test integration with the LightGBMLSS model class."""
+        from lightgbmlss.distributions.BernsteinFlow import BernsteinFlow
+        from lightgbmlss.model import LightGBMLSS
+        import numpy as np
+
+        # Should be able to create a model with BernsteinFlow
+        dist = BernsteinFlow(degree=5)
+        model = LightGBMLSS(dist)
+
+        # Test that the model is properly initialized
+        assert model.dist.n_dist_param == 6
+        assert model.dist.param_dict is not None
+
+        # Test with synthetic data
+        np.random.seed(42)
+        y = np.random.randn(100)
+
+        # Should be able to compute starting values
+        start_values = model.dist.calculate_start_values(y, max_iter=10)
+        assert len(start_values) == 2  # loss and start_values
+        assert len(start_values[1]) == 6  # Should match n_dist_param
\ No newline at end of file
diff --git a/tests/utils.py b/tests/utils.py
index 9baa071..289ec70 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -87,6 +87,7 @@ def get_distribution_classes(univariate: bool = True,
     # Remove specific distributions
     distns_remove = [
         "SplineFlow",
+        "BernsteinFlow",
        "Expectile",
         "Mixture"
     ]
@@ -170,12 +171,15 @@ def get_distribution_classes(univariate: bool = True,
         return multivar_distns
 
     elif flow:
-        distribution_name = "SplineFlow"
-        module = importlib.import_module(f"lightgbmlss.distributions.{distribution_name}")
-        # Get the class dynamically from the module
-        distribution_class = [getattr(module, distribution_name)]
-
-        return distribution_class
+        flow_names = ["SplineFlow", "BernsteinFlow"]
+        distribution_classes = []
+        for distribution_name in flow_names:
+            module = importlib.import_module(f"lightgbmlss.distributions.{distribution_name}")
+            # Get the class dynamically from the module
+            distribution_class = getattr(module, distribution_name)
+            distribution_classes.append(distribution_class)
+
+        return distribution_classes
 
     elif expectile:
         distribution_name = "Expectile"