@@ -223,35 +223,39 @@ def prob(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
-
-        # Prepend singleton dimensions for sample_shape to bin_edges, bin_widths, and probs.
-        # For all of them, the resulting shape will be: (*sample_shape, *batch_shape, num_bins)
-        bin_edges_left = self.bin_edges[..., :-1]  # shape: (*batch_shape, num_bins)
-        bin_edges_right = self.bin_edges[..., 1:]  # shape: (*batch_shape, num_bins)
-        bin_edges_left = bin_edges_left.view((1,) * num_sample_dims + bin_edges_left.shape)
-        bin_edges_right = bin_edges_right.view((1,) * num_sample_dims + bin_edges_right.shape)
-        bin_widths = self.bin_widths.view((1,) * num_sample_dims + self.bin_widths.shape)
-        probs = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
-
-        # Add bin dimension to value for broadcasting.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
-
-        # Check which bin each value falls into. Result shape: (*sample_shape, *batch_shape, num_bins).
-        in_bin = ((value_expanded >= bin_edges_left) & (value_expanded < bin_edges_right)).to(self.logits.dtype)
-
-        # Handle right edge case (include bound_up in last bin).
-        at_right_edge = torch.isclose(
-            value_expanded, torch.tensor(self.bound_up, dtype=self.logits.dtype, device=self.logits.device)
-        )
-        in_bin[..., -1] = torch.max(in_bin[..., -1], at_right_edge[..., -1])
-
-        # PDF = (probability mass / bin width) for the containing bin.
-        pdf_per_bin = probs / bin_widths  # shape: (*sample_shape, *batch_shape, num_bins)
-
-        # Sum over bins is the same as selecting the bin, as there is only one bin active per value.
-        return torch.sum(in_bin * pdf_per_bin, dim=-1)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
+
+        # Use binary search to find which bin each value belongs to. torch.searchsorted returns the index at
+        # which value would be inserted to keep the edges sorted; with right=True, a value lying exactly on an
+        # edge falls into the bin on its right, matching [edge[i], edge[i+1]). Subtracting 1 gives the bin index.
+        bin_indices = torch.searchsorted(self.bin_edges, value, right=True) - 1  # shape: (*sample_shape, *batch_shape)
+
+        # Clamp to the valid range [0, num_bins - 1] to handle edge cases:
+        # - values below bound_low would give bin_idx = -1
+        # - values at bound_up would give bin_idx = num_bins
+        bin_indices = torch.clamp(bin_indices, 0, self.num_bins - 1)
+
+        # Gather the bin widths and probabilities for the selected bins.
+        # For bin_widths of shape (num_bins,), we can index directly.
+        bin_widths_selected = self.bin_widths[bin_indices]  # shape: (*sample_shape, *batch_shape)
+
+        # For bin_probs of shape (*batch_shape, num_bins), we need to gather along the last dimension.
+        # Prepend sample dimensions to bin_probs and expand to match the shape of bin_indices.
+        num_sample_dims = len(bin_indices.shape) - len(self.batch_shape)
+        bin_probs_for_gather = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        bin_probs_for_gather = bin_probs_for_gather.expand(
+            *bin_indices.shape, -1
+        )  # shape: (*sample_shape, *batch_shape, num_bins)
+
+        # Gather the selected bin probabilities.
+        bin_indices_for_gather = bin_indices.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        bin_probs_selected = torch.gather(bin_probs_for_gather, dim=-1, index=bin_indices_for_gather)
+        bin_probs_selected = bin_probs_selected.squeeze(-1)
+
+        # Compute PDF = probability mass / bin width.
+        return bin_probs_selected / bin_widths_selected
 
     def cdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute cumulative distribution function at given values.
@@ -269,25 +273,34 @@ def cdf(self, value: torch.Tensor) -> torch.Tensor:
 
         value = value.to(dtype=self.logits.dtype, device=self.logits.device)
 
-        # Determine number of sample dimensions (dimensions before batch_shape).
-        num_sample_dims = len(value.shape) - len(self.batch_shape)
+        # Explicitly broadcast value to batch_shape if needed (e.g., scalar inputs with batched distributions).
+        if len(self.batch_shape) > 0 and value.ndim < len(self.batch_shape):
+            value = value.expand(self.batch_shape)
 
-        # Prepend singleton dimensions for sample_shape to bin_centers.
-        # bin_centers: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        bin_centers_expanded = self.bin_centers.view((1,) * num_sample_dims + self.bin_centers.shape)
+        # Use binary search to find how many bin centers are <= value.
+        # torch.searchsorted with right=True returns the number of elements <= value.
+        num_bins_active = torch.searchsorted(self.bin_centers, value, right=True)
 
-        # Prepend singleton dimensions for sample_shape to probs.
-        # probs: (*batch_shape, num_bins) -> (*sample_shape, *batch_shape, num_bins)
-        probs_expanded = self.bin_probs.view((1,) * num_sample_dims + self.bin_probs.shape)
+        # Clamp to the valid range [0, num_bins].
+        num_bins_active = torch.clamp(num_bins_active, 0, self.num_bins)  # shape: (*sample_shape, *batch_shape)
 
-        # Add the bin dimension to the input which is used for comparing with the bin centers.
-        value_expanded = value.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        # Compute the cumulative sum of the bin probabilities.
+        # Prepend a 0 for the case where no bins are active.
+        num_sample_dims = len(num_bins_active.shape) - len(self.batch_shape)
+        cumsum_probs = torch.cumsum(self.bin_probs, dim=-1)  # shape: (*batch_shape, num_bins)
+        cumsum_probs = torch.cat(
+            [torch.zeros(*self.batch_shape, 1, dtype=self.logits.dtype, device=self.logits.device), cumsum_probs],
+            dim=-1,
+        )  # shape: (*batch_shape, num_bins + 1)
 
-        # Mask for bins with centers <= value.
-        mask = bin_centers_expanded <= value_expanded  # shape: (*sample_shape, *batch_shape, num_bins)
+        # Expand cumsum_probs to match the sample dimensions, then gather.
+        cumsum_probs_for_gather = cumsum_probs.view((1,) * num_sample_dims + cumsum_probs.shape)
+        cumsum_probs_for_gather = cumsum_probs_for_gather.expand(*num_bins_active.shape, -1)
+        num_bins_active_for_gather = num_bins_active.unsqueeze(-1)  # shape: (*sample_shape, *batch_shape, 1)
+        cdf_values = torch.gather(cumsum_probs_for_gather, dim=-1, index=num_bins_active_for_gather)
+        cdf_values = cdf_values.squeeze(-1)
 
-        # Sum the bins for this value by their weighted "activation"=probability.
-        return torch.sum(mask * probs_expanded, dim=-1)
+        return cdf_values
 
     def icdf(self, value: torch.Tensor) -> torch.Tensor:
         """Compute the inverse CDF, i.e., the quantile function, at the given values.