diff --git a/pyproject.toml b/pyproject.toml index d9bee4d68..1f80feac8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,7 +70,8 @@ lines-between-types = 1 [tool.ruff.lint.per-file-ignores] 'tests/*.py' = [ 'F841', # Unused variable warning for test files -- common in pymc model declarations - 'D106' # Missing docstring for public method -- unittest test subclasses don't need docstrings + 'D106', # Missing docstring for public method -- unittest test subclasses don't need docstrings + 'E402' # Import at top, not respected when pytest.importorskip is required ] 'tests/statespace/*.py' = [ 'F401', # Unused import warning for test files -- this check removes imports of fixtures diff --git a/tests/test_blackjax_smc.py b/tests/test_blackjax_smc.py index 49db7de7f..4a3894139 100644 --- a/tests/test_blackjax_smc.py +++ b/tests/test_blackjax_smc.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import jax import numpy as np import pymc as pm import pytensor.tensor as pt @@ -21,6 +20,9 @@ from numpy import dtype from xarray.core.utils import Frozen +jax = pytest.importorskip("jax") +pytest.importorskip("blackjax") + from pymc_experimental.inference.smc.sampling import ( arviz_from_particles, blackjax_particles_from_pymc_population,