Skip to content

Commit fbb09e7

Browse files
authored
Optimize prob() and cdf() methods (#17)
1 parent 2d93e64 commit fbb09e7

File tree

1 file changed

+56
-43
lines changed

1 file changed

+56
-43
lines changed

binned_cdf/binned_logit_cdf.py

Lines changed: 56 additions & 43 deletions
Original file line numberDiff line numberDiff line change
def prob(self, value: torch.Tensor) -> torch.Tensor:
    """Return the probability density of ``value`` under the binned distribution.

    The density inside a bin is its probability mass divided by its width;
    it is 0 outside the closed support ``[bin_edges[0], bin_edges[-1]]``.

    Args:
        value: Points at which to evaluate the density, of shape
            ``(*sample_shape, *batch_shape)`` (sample dims may be absent).

    Returns:
        Tensor of densities with shape ``(*sample_shape, *batch_shape)``.
    """
    value = value.to(dtype=self.logits.dtype, device=self.logits.device)

    # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs
    # with batched distributions).
    if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
        value = value.expand(self.batch_shape)

    # Binary search for the containing bin. ``right=True`` is essential: bins
    # are half-open [edge[i], edge[i+1]), so a value exactly on an interior
    # edge must land in the bin whose LEFT edge it is. With the default
    # right=False, searchsorted returns the edge's own index and the "- 1"
    # would assign the value to the bin below — an off-by-one at every edge.
    # NOTE(review): with batched (>1-D) bin_edges, searchsorted requires the
    # input's leading dims to match the edges' — confirm sample dims are
    # supported in the batched case.
    bin_indices = torch.searchsorted(self.bin_edges, value, right=True) - 1

    # Clamp to [0, num_bins - 1] so out-of-range values index a valid bin;
    # their density is zeroed below.
    bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)

    # bin_widths has shape (num_bins,), so plain fancy indexing suffices.
    bin_widths_selected = self.bin_widths[bin_indices]  # (*sample_shape, *batch_shape)

    # bin_probs has shape (*batch_shape, num_bins): prepend singleton sample
    # dims, expand, and gather along the bin dimension.
    num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
    bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
    bin_probs_for_gather = bin_probs_for_gather.expand(
        *bin_indices.shape, -1
    )  # (*sample_shape, *batch_shape, num_bins)
    bin_probs_selected = torch.gather(
        bin_probs_for_gather, dim=-1, index=bin_indices.unsqueeze(-1)
    ).squeeze(-1)

    # PDF = probability mass / bin width for the containing bin.
    pdf = bin_probs_selected / bin_widths_selected

    # Zero the density outside the support. Without this mask the clamp above
    # would report the edge bins' density for values below the lower bound or
    # above the upper bound, whereas a bounded distribution has density 0 there.
    in_support = (value >= self.bin_edges[..., 0]) & (value <= self.bin_edges[..., -1])
    return torch.where(in_support, pdf, torch.zeros_like(pdf))
255259

256260
def cdf(self, value: torch.Tensor) -> torch.Tensor:
    """Compute the cumulative distribution function at the given values.

    The CDF is the total probability mass of all bins whose centers lie at or
    below ``value``, so it is a step function that jumps at each bin center.

    Args:
        value: Points at which to evaluate the CDF, of shape
            ``(*sample_shape, *batch_shape)`` (sample dims may be absent).

    Returns:
        Tensor of CDF values with shape ``(*sample_shape, *batch_shape)``.
    """
    value = value.to(dtype=self.logits.dtype, device=self.logits.device)

    # Lift bare / scalar inputs up to the batch shape when the distribution is batched.
    if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
        value = value.expand(self.batch_shape)

    # Binary search: right=True counts how many bin centers are <= value.
    active_counts = torch.searchsorted(self.bin_centers, value, right=True)
    active_counts = active_counts.clamp(0, self.num_bins)  # (*sample_shape, *batch_shape)

    # Running total of bin masses with a leading zero, so a count of 0 maps to CDF 0.
    zero_column = torch.zeros(
        *self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device
    )
    running_mass = torch.cat(
        [zero_column, self.bin_probs.cumsum(dim=-1)], dim=-1
    )  # (*batch_shape, num_bins + 1)

    # Broadcast the running totals over the sample dims and pick each count's entry.
    sample_dims = active_counts.ndim - len(self.batch_shape)
    running_mass = running_mass.view((1,) * sample_dims + running_mass.shape)
    running_mass = running_mass.expand(*active_counts.shape, -1)
    picked = running_mass.gather(dim=-1, index=active_counts.unsqueeze(-1))

    return picked.squeeze(-1)
291304

292305
def icdf(self, value: torch.Tensor) -> torch.Tensor:
293306
"""Compute the inverse CDF, i.e., the quantile function, at the given values.

0 commit comments

Comments
 (0)