|
77 | 77 | MeasurableElemwise,
|
78 | 78 | MeasurableVariable,
|
79 | 79 | _get_measurable_outputs,
|
| 80 | + _icdf, |
| 81 | + _icdf_helper, |
| 82 | + _logcdf, |
| 83 | + _logcdf_helper, |
80 | 84 | _logprob,
|
81 | 85 | _logprob_helper,
|
82 | 86 | )
|
@@ -390,6 +394,38 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
|
390 | 394 | return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
|
391 | 395 |
|
392 | 396 |
|
| 397 | +@_logcdf.register(MeasurableTransform) |
| 398 | +def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs): |
| 399 | + """Compute the log-CDF graph for a `MeasurabeTransform`.""" |
| 400 | + other_inputs = list(inputs) |
| 401 | + measurable_input = other_inputs.pop(op.measurable_input_idx) |
| 402 | + |
| 403 | + backward_value = op.transform_elemwise.backward(value, *other_inputs) |
| 404 | + |
| 405 | + # Some transformations, like squaring may produce multiple backward values |
| 406 | + if isinstance(backward_value, tuple): |
| 407 | + raise NotImplementedError |
| 408 | + |
| 409 | + input_logcdf = _logcdf_helper(measurable_input, backward_value) |
| 410 | + |
| 411 | + # The jacobian is used to ensure a value in the supported domain was provided |
| 412 | + jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs) |
| 413 | + |
| 414 | + return pt.switch(pt.isnan(jacobian), -np.inf, input_logcdf) |
| 415 | + |
| 416 | + |
| 417 | +@_icdf.register(MeasurableTransform) |
| 418 | +def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs): |
| 419 | + """Compute the inverse CDF graph for a `MeasurabeTransform`.""" |
| 420 | + other_inputs = list(inputs) |
| 421 | + measurable_input = other_inputs.pop(op.measurable_input_idx) |
| 422 | + |
| 423 | + input_icdf = _icdf_helper(measurable_input, value) |
| 424 | + icdf = op.transform_elemwise.forward(input_icdf, *other_inputs) |
| 425 | + |
| 426 | + return icdf |
| 427 | + |
| 428 | + |
393 | 429 | @node_rewriter([reciprocal])
|
394 | 430 | def measurable_reciprocal_to_power(fgraph, node):
|
395 | 431 | """Convert reciprocal of `MeasurableVariable`s to power."""
|
|
0 commit comments