diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 57059a0a4..a4e6e22f5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,9 +4,9 @@ name: CI on: push: - branches: [ master ] + branches: [master] pull_request: - branches: [ master ] + branches: [master] env: PYTEST_ADDOPTS: "--cov=numpyro --cov-append" @@ -17,37 +17,37 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.9","3.10","3.13"] + python-version: ["3.9", "3.10", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - sudo apt install -y pandoc gsfonts - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install '.[doc,test]' - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -r docs/requirements.txt - pip freeze - - name: Lint with mypy and ruff - if: matrix.python-version != '3.9' - run: | - make lint - - name: Build documentation - if: matrix.python-version != '3.9' - run: | - make docs - - name: Test documentation - if: matrix.python-version != '3.9' - run: | - make doctest - python -m doctest -v README.md + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt install -y pandoc gsfonts + python -m pip install --upgrade pip + pip install jaxlib + pip install jax + pip install '.[doc,test]' + pip install https://github.com/pyro-ppl/funsor/archive/master.zip + pip install -r docs/requirements.txt + pip freeze + - name: Lint with mypy and ruff + if: matrix.python-version != '3.9' + run: | + make lint + - name: Build documentation + if: matrix.python-version != '3.9' + run: | + make docs + - name: Test documentation + if: matrix.python-version != '3.9' + run: | + make doctest + python -m doctest -v README.md test-modeling: @@ -59,46 +59,46 @@ jobs: python-version: ["3.9", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - sudo apt install -y graphviz - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze - - name: Test with pytest - run: | - CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ - - name: Test x64 - run: | - JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum" - - name: Test tracer leak - if: matrix.python-version == '3.13' - env: - JAX_CHECK_TRACER_LEAKS: 1 - run: | - pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit - pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke - pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run - pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths - pytest -vs test/test_distributions.py::test_mean_var -k Gompertz - - - name: Coveralls - if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' - uses: coverallsapp/github-action@v2 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel: true - flag-name: test-modeling + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + sudo apt install -y graphviz + python -m pip install --upgrade pip + # Keep track of pyro-api master branch + pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip + pip install jaxlib + pip install jax + pip install https://github.com/pyro-ppl/funsor/archive/master.zip + pip install -e '.[dev,test]' + pip freeze + - name: Test with pytest + run: | + CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/ + - name: Test x64 + run: | + JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k "powerLaw or Dagum" + - name: Test tracer leak + if: matrix.python-version == '3.13' + env: + JAX_CHECK_TRACER_LEAKS: 1 + run: | + pytest -vs test/infer/test_mcmc.py::test_chain_inside_jit + pytest -vs test/infer/test_mcmc.py::test_chain_jit_args_smoke + pytest -vs test/infer/test_mcmc.py::test_reuse_mcmc_run + pytest -vs test/infer/test_mcmc.py::test_model_with_multiple_exec_paths + pytest -vs test/test_distributions.py::test_mean_var -k Gompertz + + - name: Coveralls + if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' + uses: coverallsapp/github-action@v2 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: test-modeling test-inference: @@ -110,48 +110,48 @@ jobs: python-version: ["3.9", "3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,test]' - pip freeze - - name: Test with pytest - run: | - pytest -vs --durations=20 test/infer/test_mcmc.py - pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py - pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py - - name: Test x64 - run: | - JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 - - name: Test chains - run: | - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py - XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" - - name: Test custom prng - run: | - JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py - - name: Test nested sampling - run: | - JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py - - name: Coveralls - if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' - uses: coverallsapp/github-action@v2 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel: true - flag-name: test-inference + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + # Keep track of pyro-api master branch + pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip + pip install jaxlib + pip install jax + pip install https://github.com/pyro-ppl/funsor/archive/master.zip + pip install -e '.[dev,test]' + pip freeze + - name: Test with pytest + run: | + pytest -vs --durations=20 test/infer/test_mcmc.py + pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py + pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py + - name: Test x64 + run: | + JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64 + - name: Test chains + run: | + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap" + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain" + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py + XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain" + - name: Test custom prng + run: | + JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py + - name: Test nested sampling + run: | + JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py + - name: Coveralls + if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' + uses: coverallsapp/github-action@v2 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: test-inference examples: @@ -163,29 +163,29 @@ jobs: python-version: ["3.13"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install jaxlib - pip install jax - pip install https://github.com/pyro-ppl/funsor/archive/master.zip - pip install -e '.[dev,examples,test]' - pip freeze - - name: Test with pytest - run: | - CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example - - name: Coveralls - if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' - uses: coverallsapp/github-action@v2 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel: true - flag-name: examples + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install jaxlib + pip install jax + pip install https://github.com/pyro-ppl/funsor/archive/master.zip + pip install -e '.[dev,examples,test]' + pip freeze + - name: Test with pytest + run: | + CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example + - name: Coveralls + if: github.repository == 'pyro-ppl/numpyro' && matrix.python-version == '3.13' + uses: coverallsapp/github-action@v2 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel: true + flag-name: examples finish: @@ -193,10 +193,10 @@ jobs: needs: [test-modeling, test-inference, examples] runs-on: ubuntu-latest steps: - - name: Coveralls finished - uses: coverallsapp/github-action@v2 - with: - github-token: ${{ secrets.GITHUB_TOKEN }} - parallel-finished: true - carryforward: "test-modeling,test-inference,examples" + - name: Coveralls finished + uses: coverallsapp/github-action@v2 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + parallel-finished: true + carryforward: "test-modeling,test-inference,examples" diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index ce2f4c343..6f12a0d79 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -11,7 +11,7 @@ name: Upload Python Package on: release: - types: [ published ] + types: [published] jobs: deploy: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c3be4e01..cdaf5e623 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,3 +44,9 @@ repos: stages: [pre-commit, commit-msg] args: [--ignore-words-list, "Teh,aas,ans,dout", --check-filenames, --skip, "*.ipynb"] + # Format yaml files + - repo: https://github.com/google/yamlfmt + rev: v0.20.0 + hooks: + - id: yamlfmt + args: [-formatter, retain_line_breaks=true] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 26b13e1d4..5904cb20e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,11 +6,11 @@ build: python: "3.10" sphinx: - configuration: docs/source/conf.py + configuration: docs/source/conf.py formats: - - pdf + - pdf python: - install: - - requirements: docs/requirements.txt + install: + - requirements: docs/requirements.txt diff --git a/numpyro/examples/datasets.py b/numpyro/examples/datasets.py index 0c6563da6..7bfd5be5c 100644 --- a/numpyro/examples/datasets.py +++ b/numpyro/examples/datasets.py @@ -279,7 +279,14 @@ def _load_jsb_chorales(): file_path = os.path.join(DATA_DIR, "jsb_chorales.pickle") with open(file_path, "rb") as f: - data = pickle.load(f) + # Filter numpy deprecation warning from loading legacy pickle file + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="dtype.*align should be passed as Python or NumPy boolean", + category=np.exceptions.VisibleDeprecationWarning, + ) + data = pickle.load(f) # XXX: we might expose those in `load_dataset` keywords min_note = 21