Thanks for the clear repro. It looks like an issue with reverse-mode autodiff of lax.conv. Compare the Jacobian from forward mode with the one from reverse mode:

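from jax import jacfwd, random

# `convolve` is the function from the original repro (not reproduced here).
params = {'distribution': random.normal(random.PRNGKey(0), (1, 3)),
          'function': random.normal(random.PRNGKey(1), (2,))}

print(jacfwd(convolve)(params))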
# {'distribution':
#    DeviceArray([[[[ 2.2125056 , -0.11617047,  0.        ]],
#                  [[ 0.        ,  2.2125056 , -0.11617047]],
#                  [[ 0.        ,  0.        ,  2.2125056 ]]]], dtype=float32),
#  'function':
#    DeviceArray([[[-0.4826233 ,  1.8160859 ],
#                  [ 0.33988902, -0.4826233 ],
#                  [ 0.        ,  …
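
The repro's `convolve` is not quoted in this reply, so the sketch below uses a hypothetical stand-in built on `lax.conv` with 'SAME' padding. It was chosen only so the Jacobian shapes match the ones printed above (output of shape (1, 3)); the values will differ, and it is not claimed to reproduce the original discrepancy. It shows two ways to compare the modes: full Jacobians via `jacfwd`/`jacrev`, and a dot-product test checking that the VJP is the transpose of the JVP.

```python
import numpy as np
import jax.numpy as jnp
from jax import jacfwd, jacrev, jvp, lax, random, vjp


def convolve(params):
    # HYPOTHETICAL stand-in for the repro's `convolve`; the original is not shown.
    x = params['distribution']           # (1, n) signal
    w = params['function']               # (k,) kernel
    lhs = x[:, None, :]                  # lax.conv wants (batch, in_channels, width)
    rhs = w[None, None, :]               # and (out_channels, in_channels, width)
    # lax.conv computes a cross-correlation (no kernel flip); 'SAME' keeps width n.
    out = lax.conv(lhs, rhs, window_strides=(1,), padding='SAME')
    return out[:, 0, :]                  # back to (1, n)


params = {'distribution': random.normal(random.PRNGKey(0), (1, 3)),
          'function': random.normal(random.PRNGKey(1), (2,))}

# 1. Full Jacobians: forward mode (built from JVPs) vs. reverse mode (from VJPs).
fwd = jacfwd(convolve)(params)
rev = jacrev(convolve)(params)
for name in params:
    print(name, fwd[name].shape, np.allclose(fwd[name], rev[name], atol=1e-5))

# 2. Dot-product test: <v, J u> should equal <J^T v, u> for any u, v.
#    A mismatch means the VJP is not the transpose of the JVP.
tangent = {'distribution': random.normal(random.PRNGKey(2), (1, 3)),
           'function': random.normal(random.PRNGKey(3), (2,))}
cotangent = random.normal(random.PRNGKey(4), (1, 3))

_, Ju = jvp(convolve, (params,), (tangent,))      # J u, shape (1, 3)
_, pullback = vjp(convolve, params)
(JTv,) = pullback(cotangent)                      # J^T v, same structure as params

lhs_dot = jnp.sum(cotangent * Ju)
rhs_dot = sum(jnp.sum(JTv[name] * tangent[name]) for name in params)
print(float(lhs_dot), float(rhs_dot))
```

The dot-product test is cheap (one JVP and one VJP) and isolates the transpose rule, which is what reverse mode depends on; the full-Jacobian comparison is easier to read for small inputs like these.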

Answer selected by marcosrdac