Conversation
gpleiss left a comment
Looks mostly good to me; see one comment.
Comment addressed

This should be ready to merge!

I could use this; @gpleiss @bendavidsteel do you need any help getting this merged?

I believe we're just waiting on approval from @gpleiss

@gpleiss @bendavidsteel how about merging this? GPOR would be enormously useful for our research.
7cfabd1 to fca1feb
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
    Returns:
        Probabilities between jitter and 1-jitter
    """
    return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter
If x is a GPU tensor, then this line will trigger a device error as torch.tensor(2.0) is always a CPU tensor.
Suggested change:
- return 0.5 * (1.0 + torch.erf(x / torch.sqrt(torch.tensor(2.0)))) * (1 - 2 * jitter) + jitter
+ return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) * (1 - 2 * jitter) + jitter
I changed it to `torch.tensor(2.0, device=x.device)` to keep torch's speed and fix the device error.
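For reference, the clamping behavior of the formula can be checked in pure Python. This is only an illustrative sketch, with `math.erf` substituted for `torch.erf`; it is not the PR's actual tensor code:

```python
import math

def inv_probit(x: float, jitter: float = 1e-3) -> float:
    """Pure-Python sketch of the inverse-probit formula.

    In the PyTorch version, the constant 2.0 must either live on x's
    device (torch.tensor(2.0, device=x.device)) or be a plain Python
    float (math.sqrt(2.0)), which broadcasts without a device transfer.
    """
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0))) * (1 - 2 * jitter) + jitter

# midpoint maps to 0.5; tails saturate at jitter and 1 - jitter
assert abs(inv_probit(0.0) - 0.5) < 1e-12
assert abs(inv_probit(-50.0) - 1e-3) < 1e-12
assert abs(inv_probit(50.0) - (1 - 1e-3)) < 1e-12
```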
scaled_edges_left = scaled_edges_left.reshape(1, -1)
scaled_edges_right = scaled_edges_right.reshape(1, -1)
Will these two lines work in batch settings where the batch shape is non-empty?
Fixed + added test to confirm
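A minimal sketch of why the batch question matters, with hypothetical shapes and variable names (not the PR's actual code): a hard-coded `reshape(1, -1)` only lines up with unbatched 2-D inputs, whereas unsqueezing the input broadcasts against the trailing edge dimension for any batch shape:

```python
import torch

# Hypothetical setup: batch shape (2, 3), 5 sample points, 4 bin edges.
bin_edges = torch.linspace(-1.0, 1.0, 4)  # shape (4,)
f = torch.randn(2, 3, 5)                  # batched function values

# reshape(1, -1) fixes the edges to a 2-D shape, which misaligns against a
# 3-D batched input. Unsqueezing the *input* instead adds a trailing edge
# dimension, so broadcasting works for any non-empty batch shape:
diff = f.unsqueeze(-1) - bin_edges        # shape (2, 3, 5, 4)
assert diff.shape == (2, 3, 5, 4)
```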
super().__init__()

self.num_bins = len(bin_edges) + 1
self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))
Suggested change:
- self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges, requires_grad=False))
+ self.register_buffer("bin_edges", bin_edges)
nit: I think it makes more sense to register this as a buffer instead since we won't update the bin edges?
On the flip side, does it make sense to set requires_grad=True so that we learn the bin edges during model fitting? (Some packages choose to do so; see here.) IIUC, we only learn sigma here but the bin edges are fixed. I am wondering if this could limit the expressiveness of the likelihood.
I changed the code to allow for learnable edges but default to fixed
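A sketch of what "learnable edges but default to fixed" could look like, assuming a hypothetical `learn_edges` constructor flag and a stand-in class name (not the PR's actual class):

```python
import torch

class OrdinalHead(torch.nn.Module):
    """Hypothetical stand-in: bin edges fixed by default, optionally learnable."""

    def __init__(self, bin_edges: torch.Tensor, learn_edges: bool = False):
        super().__init__()
        if learn_edges:
            # trainable: shows up in parameters() and is updated by the optimizer
            self.register_parameter("bin_edges", torch.nn.Parameter(bin_edges))
        else:
            # fixed: excluded from parameters(), but still moves with .to(device)
            # and is saved in the state_dict
            self.register_buffer("bin_edges", bin_edges)

fixed = OrdinalHead(torch.linspace(-1.0, 1.0, 4))
learned = OrdinalHead(torch.linspace(-1.0, 1.0, 4), learn_edges=True)
assert len(list(fixed.parameters())) == 0
assert len(list(learned.parameters())) == 1
```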
from .likelihood import _OneDimensionalLikelihood


def inv_probit(x, jitter=1e-3):
Suggested change:
- def inv_probit(x, jitter=1e-3):
+ def inv_probit(x: Tensor, jitter: float = 1e-3):
Let's annotate these variables.
def _set_sigma(self, value: Tensor) -> None:
    if not torch.is_tensor(value):
        value = torch.as_tensor(value).to(self.raw_sigma)
    self.initialize(raw_sigma=self.raw_sigma_constraint.inverse_transform(value))
nit: value is already annotated as a Tensor, so we could drop the if-statement here. Also, maybe we could merge this method into the sigma setter above?
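A sketch of the merged setter pattern, with softplus standing in for the actual raw_sigma_constraint and a hypothetical class name (so this is illustrative, not GPyTorch's real constraint API):

```python
import torch

class SigmaModule(torch.nn.Module):
    """Hypothetical stand-in: one property setter replaces a separate _set_sigma."""

    def __init__(self):
        super().__init__()
        self.raw_sigma = torch.nn.Parameter(torch.zeros(1))

    @property
    def sigma(self) -> torch.Tensor:
        # softplus stands in for raw_sigma_constraint.transform
        return torch.nn.functional.softplus(self.raw_sigma)

    @sigma.setter
    def sigma(self, value) -> None:
        # torch.as_tensor handles floats and tensors alike, so no if-statement
        value = torch.as_tensor(value).to(self.raw_sigma)
        # inverse softplus: value + log(1 - exp(-value)) stands in for
        # raw_sigma_constraint.inverse_transform
        with torch.no_grad():
            self.raw_sigma.copy_(value + torch.log(-torch.expm1(-value)))

m = SigmaModule()
m.sigma = 2.0
assert torch.allclose(m.sigma, torch.tensor(2.0), atol=1e-4)
```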
I've merged the latest main into this PR. Multiple users have expressed interest in the ordinal likelihood implementation, so it would be great to get this merged. It's been a while since this PR was opened; @bendavidsteel, apologies for the delay on our end! I left some additional comments. I am wondering if you still have the capacity to work on this PR.
- Fix GPU device error in inv_probit by passing device to torch.tensor
- Add type annotations to inv_probit signature
- Add learn_edges parameter to control whether bin edges are learnable
- Merge _set_sigma into sigma setter, remove redundant method
- Fix forward to support non-empty batch shapes (unsqueeze instead of reshape)
- Add tests for batched likelihood and GPU device

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Addressed comments @kayween
I needed an ordinal likelihood for some of my own work, and saw Issue #2534, so I thought I'd make this contribution!
Tested it a bit, seems to work pretty well. It's pretty much just using the same idea as in GPflow so nothing new.
Tests and docs provided.