Skip to content

Conversation

colehaus
Copy link
Contributor

@colehaus colehaus commented Sep 5, 2022

Jax's betainc doesn't have gradients defined for all parameters while tfp's does.

See the related PR here: #1471 and the initial discussion here: #1452.

I'm not sure exactly how you want to handle the dependency declarations since tensorflow and tensorflow-probability are sort of heavy dependencies to bring in (i.e. should they be promoted to install_requires?).

Also, the type casting stuff is a bit ugly but tfp checks that array types match and self.df sometimes had a float64 dtype in tests while beta_value has a float32 dtype in each test.

Jax's `betainc` doesn't have gradients defined for all parameters while tfp's does
@colehaus
Copy link
Contributor Author

colehaus commented Sep 6, 2022

Ah, sorry. Was slightly non-trivial to run the lint checks with a .venv, but I cherry-picked my config changes from the other PR and ran make lint locally so it should pass this time.

@fehiepsi
Copy link
Member

fehiepsi commented Sep 6, 2022

This is a great addition! Thanks, @colehaus.

@fehiepsi fehiepsi merged commit 9d5d235 into pyro-ppl:master Sep 6, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants