SimonKoop/common_jax_utils

This package collects general-purpose utilities for working with jax; it is not specific to any project. Part of the package is centered around making the tools from common_dl_utils work well with jax/equinox. As such, it relies heavily on common_dl_utils, which can be installed using pip install git+https://github.com/SimonKoop/common_dl_utils.git --upgrade

The package is organized as follows:

  • types: registers relevant types in the type_registry from common_dl_utils
  • tree_utils: some utilities for working with pytrees
  • run_utils: extends the tools from common_dl_utils.config_realization to create models or run experiments from a config together with a jax prng key.
  • wandb_utils: hooks the tools from run_utils to wandb.
  • debug_utils: tools for debugging
  • metrics: a collection of metrics based on the framework set up in common_dl_utils.metrics
  • decorators: a collection of decorators that are useful when working with jax
  • masks: a collection of masking functions for masking parts of equinox modules for use in optax optimizers (e.g. for weight decay)
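
To illustrate the idea behind the masks module, here is a minimal sketch of a weight-decay mask written with plain jax only. It does not use or reproduce the actual API of common_jax_utils: the function name weight_decay_mask, the nested dict standing in for an equinox module, and the rule "decay arrays with two or more dimensions, skip biases" are all assumptions made for this example.

```python
import jax
import jax.numpy as jnp


def weight_decay_mask(params):
    """Return a pytree of booleans with the same structure as `params`.

    Hypothetical helper, not part of common_jax_utils.
    Assumption: leaves with 2+ dimensions are weight matrices (decayed),
    while 1-D leaves are biases (not decayed).
    """
    return jax.tree_util.tree_map(lambda leaf: leaf.ndim >= 2, params)


# A plain nested dict stands in for the parameters of an equinox module.
params = {
    "linear1": {"weight": jnp.ones((4, 3)), "bias": jnp.zeros((4,))},
    "linear2": {"weight": jnp.ones((2, 4)), "bias": jnp.zeros((2,))},
}

mask = weight_decay_mask(params)
```

A mask of this shape, or a callable that produces one, is what optax transformations such as optax.add_decayed_weights accept through their mask argument, so that weight decay is applied only to the leaves marked True.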
