feat: support for censored likelihoods #91
John-Curcio wants to merge 5 commits into StatMixedML:master
Conversation
@John-Curcio Thanks for opening the PR and your effort! I'll need some time to review it.
Pull Request Overview
This PR adds support for fitting censored data by introducing a `CensoredMixin` class that extends univariate distributions to handle interval-censored observations. The mixin overrides `objective_fn` and `metric_fn` to compute the likelihood via cumulative distribution functions (CDFs) over the censored intervals; a sketch of that computation follows the summary below.
- Adds `CensoredMixin` class with censored likelihood computation
- Implements `CensoredLogNormal` and `CensoredWeibull` distribution classes
- Adds comprehensive test coverage for censored data functionality
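To make the censored likelihood concrete, here is a minimal standalone sketch of the computation described above. It assumes a `torch.distributions` object and per-observation bounds `low`/`hi` (names taken from the diff further down), with `low == hi` marking exact observations and `hi = +inf` expressing right-censoring; this is an illustration, not the PR's exact code:

```python
import torch
from torch.distributions import LogNormal

def censored_nll(dist, low, hi):
    """Negative log-likelihood with interval-censored observations."""
    exact = low == hi
    # Censored rows contribute the interval mass F(hi) - F(low)
    mass = dist.cdf(hi) - dist.cdf(low)
    # Exact rows contribute the log-density at the observed value
    log_density = dist.log_prob(low)
    return -(torch.log(mass[~exact]).sum() + log_density[exact].sum())

dist = LogNormal(torch.zeros(3), torch.ones(3))
low = torch.tensor([1.0, 0.5, 2.0])
hi = torch.tensor([1.0, 1.5, 2.0])  # rows 0 and 2 exact; row 1 interval-censored
print(censored_nll(dist, low, hi))
```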
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| xgboostlss/distributions/censored_utils.py | Core CensoredMixin class implementing censored likelihood functions |
| xgboostlss/distributions/Weibull.py | Adds CensoredWeibull class inheriting from CensoredMixin and Weibull |
| xgboostlss/distributions/LogNormal.py | Adds CensoredLogNormal class inheriting from CensoredMixin and LogNormal |
| tests/utils.py | Extends test data generation to support censored data scenarios |
| tests/test_distribution_utils/test_censored_utils.py | Test suite validating censored distribution functionality |
```python
mass = cdf_hi - cdf_low
log_density = dist.log_prob(low)
censored_inds = low != hi
loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
```
The log density is computed using only the lower bound, but this should only be used for exact observations (non-censored data). For censored intervals where low != hi, this log_density value is incorrect and shouldn't contribute to the loss.
Suggested change:
```diff
- loss = -torch.sum(torch.log(mass[censored_inds])) - torch.sum(log_density[~censored_inds])
+ exact_inds = (low == hi)
+ log_density = dist.log_prob(low[exact_inds])
+ loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)
```
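A tiny sanity check of the suggested fix on toy data (a scalar-parameter Normal with hypothetical values), showing that censored rows contribute interval mass and exact rows contribute density:

```python
import torch
from torch.distributions import Normal

dist = Normal(0.0, 1.0)
low = torch.tensor([0.0, -1.0])
hi = torch.tensor([0.0, 1.0])   # row 0 exact, row 1 interval-censored

exact_inds = low == hi
mass = dist.cdf(hi) - dist.cdf(low)
log_density = dist.log_prob(low[exact_inds])
loss = -torch.sum(torch.log(mass[~exact_inds])) - torch.sum(log_density)
# censored term: -log P(-1 < Y < 1) = -log(0.6827...) ~ 0.38
# exact term:    -log f(0) = 0.5 * log(2*pi)        ~ 0.92
```

One caveat: with per-row distribution parameters, as in the mixin, `dist.log_prob(low[exact_inds])` only lines up if the distribution's batch is indexed the same way; computing `log_prob` on the full vector and masking afterwards, as the original code does, avoids that pitfall.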
```python
return super().objective_fn(predt, data)
if data.get_weight().size == 0:
    # initialize weights as ones with correct shape
    weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
```
Creating a tensor just to get its dtype and then converting back to numpy is inefficient. Consider using weights = np.ones((lower.shape[0], 1), dtype=lower.dtype) directly.
Suggested change:
```diff
- weights = torch.ones((lower.shape[0], 1), dtype=torch.as_tensor(lower).dtype).numpy()
+ weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
```
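A quick check (with a hypothetical array) that the NumPy one-liner keeps the label dtype without the tensor round-trip:

```python
import numpy as np

lower = np.asarray([1.0, 2.5, 3.0], dtype=np.float32)
weights = np.ones((lower.shape[0], 1), dtype=lower.dtype)
assert weights.dtype == np.float32 and weights.shape == (3, 1)
```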
```python
predt, labels, *rest = gen_test_data(model, weights=False, censored=False)
dmat = rest[-1]
name_c, loss_c = model.dist.metric_fn(predt, dmat)
underlying_cls = model.dist.__class__.__mro__[2]
```
Using hardcoded index [2] in the MRO (Method Resolution Order) is fragile and could break if the inheritance hierarchy changes. Consider using a more explicit approach like checking class names or using hasattr to find the base distribution class.
Suggested change:
```diff
- underlying_cls = model.dist.__class__.__mro__[2]
+ # Find the first base class in the MRO that is not a censored distribution and not 'object'
+ underlying_cls = next(
+     cls for cls in model.dist.__class__.__mro__
+     if cls is not model.dist.__class__ and not cls.__name__.startswith("Censored") and cls is not object
+ )
```
```python
dist = self.distribution(**dict(zip(self.distribution_arg_names, params_transformed)))
# compute cdf bounds: convert lower & upper once to tensor with correct dtype
low = torch.as_tensor(lower, dtype=params_transformed[0].dtype).reshape(-1, 1)
hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
```
Inconsistent spacing: 'hi =' has two spaces before the equals sign while 'low =' on the previous line has one space. This should be consistent.
Suggested change:
```diff
- hi  = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
+ hi = torch.as_tensor(upper, dtype=params_transformed[0].dtype).reshape(-1, 1)
```
@claude Review this PR
Claude encountered an error (View job). I'll analyze this and get back to you.
@John-Curcio Can you please look into these?
- `CensoredMixin` class to extend fitting of univariate distributions to censored data
- `CensoredLogNormal` and `CensoredWeibull`

`xgb.DMatrix` already has `label_lower_bound` and `label_upper_bound` for the user to specify right-, left-, or interval-censored data. This PR adds a `CensoredMixin` class which simply overrides `objective_fn` and `metric_fn` to accommodate censored data. So to fit a LogNormal distribution to such a dataset, just use `CensoredLogNormal` instead of `LogNormal`. I've added `CensoredLogNormal` and `CensoredWeibull`. I'm happy to further update docs/add examples.
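For anyone wanting to try this, a minimal usage sketch based on the description above; the data is synthetic, and the training call follows the usual XGBoostLSS pattern (treat it as an assumption, not code tested against this branch):

```python
import numpy as np
import xgboost as xgb
from xgboostlss.model import XGBoostLSS
from xgboostlss.distributions.LogNormal import CensoredLogNormal

rng = np.random.default_rng(0)
X = rng.random((100, 4))
y = rng.lognormal(size=100)

# Interval bounds: exact rows have lower == upper; right-censored rows
# get an upper bound of +inf (XGBoost's convention for AFT-style labels).
y_lower = y.copy()
y_upper = y.copy()
y_upper[:20] = np.inf

dtrain = xgb.DMatrix(X, label=y_lower)
dtrain.set_float_info("label_lower_bound", y_lower)
dtrain.set_float_info("label_upper_bound", y_upper)

model = XGBoostLSS(CensoredLogNormal())
model.train({"eta": 0.1}, dtrain, num_boost_round=50)
```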