If we want to leverage @NeilGirdhar's tjax (e.g., a more flexible custom_vjp, pytree typing, pytree dataclasses), we'll need to require Python 3.8 or later, since tjax and some of its dependencies use language features first added in 3.8.
@pierrelux Thoughts?