Skip to content

Commit b0bdeb4

Browse files
committed
Allow debug evaling IR logp graphs
1 parent d0face4 commit b0bdeb4

File tree

3 files changed

+30
-2
lines changed

3 files changed

+30
-2
lines changed

pymc/logprob/abstract.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,13 +224,16 @@ class ValuedRV(Op):
224224
and breaking the dependency of `b` on `a`. The new nodes isolate the graphs between conditioning points.
225225
"""
226226

227+
view_map = {0: [0]}
228+
227229
def make_node(self, rv, value):
228230
assert isinstance(rv, Variable)
229231
assert isinstance(value, Variable)
230232
return Apply(self, [rv, value], [rv.type(name=rv.name)])
231233

232234
def perform(self, node, inputs, out):
233-
raise NotImplementedError("ValuedVar should not be present in the final graph!")
235+
warnings.warn("ValuedVar should not be present in the final graph!")
236+
out[0][0] = inputs[0]
234237

235238
def infer_shape(self, fgraph, node, input_shapes):
236239
return [input_shapes[0]]

pymc/logprob/transform_value.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
1415

1516
from collections.abc import Sequence
1617

@@ -40,7 +41,8 @@ def make_node(self, tran_value: TensorVariable, value: TensorVariable):
4041
return Apply(self, [tran_value, value], [tran_value.type()])
4142

4243
def perform(self, node, inputs, outputs):
43-
raise NotImplementedError("These `Op`s should be removed from graphs used for computation.")
44+
warnings.warn("TransformedValue should not be present in the final graph!")
45+
outputs[0][0] = inputs[0]
4446

4547
def infer_shape(self, fgraph, node, input_shapes):
4648
return [input_shapes[0]]

tests/logprob/test_basic.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,26 @@ def test_ir_rewrite_does_not_disconnect_valued_rvs():
436436
logp_b.eval({a_value: np.pi, b_value: np.e}),
437437
stats.norm.logpdf(np.e, np.pi * 8, 1),
438438
)
439+
440+
441+
def test_ir_ops_can_be_evaluated_with_warning():
442+
_eval_values = [None, None]
443+
444+
def my_logp(value, lam):
445+
nonlocal _eval_values
446+
_eval_values[0] = value.eval()
447+
_eval_values[1] = lam.eval({"lam_log__": -1.5})
448+
return value * lam
449+
450+
with pm.Model() as m:
451+
lam = pm.Exponential("lam")
452+
pm.CustomDist("y", lam, logp=my_logp, observed=[0, 1, 2])
453+
454+
with pytest.warns(
455+
UserWarning, match="TransformedValue should not be present in the final graph"
456+
):
457+
with pytest.warns(UserWarning, match="ValuedVar should not be present in the final graph"):
458+
m.logp()
459+
460+
assert _eval_values[0].sum() == 3
461+
assert _eval_values[1] == np.exp(-1.5)

0 commit comments

Comments
 (0)