Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,7 +874,12 @@ class SubsMeta(FunsorMeta):
def __call__(cls, arg, subs):
subs = tuple(
(k, to_funsor(v, arg.inputs[k])) for k, v in subs if k in arg.inputs
# (k, to_funsor(v, arg.inputs[k]))
# for k, v in subs
# if k in arg.inputs and k is not v
)
# if not subs:
# return arg
return super().__call__(arg, subs)


Expand Down
16 changes: 16 additions & 0 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,3 +259,19 @@ def test_sequential_sum_product_adjoint(
)
expected_bwd = expected_bwds[operand]
assert (actual_bwd_t - expected_bwd).abs().data.max() < 5e-3 * num_steps


@pytest.mark.parametrize(
"use_subs", [False, xfail_param(True, reason="doubled adjoint value")]
)
def test_subs_adjoint(use_subs):
x = random_tensor(OrderedDict(i=Bint[3]))

with AdjointTape() as tape:
y = 2 * x
if use_subs:
y = y(i="i")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This substitution shouldn't change adjoint value.


# use_subs=True returns Number(4.0)
actual = tape.adjoint(ops.add, ops.mul, y, (x,))[x]
assert actual is funsor.Number(2.0)