Skip to content

Commit b252f14

Browse files
committed
style: Reformat some code
1 parent db793f0 commit b252f14

File tree

5 files changed

+23
-13
lines changed

5 files changed

+23
-13
lines changed

python/nutpie/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from nutpie.sample import sample
55

66
__version__: str = _lib.__version__
7-
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
7+
__all__ = ["__version__", "compile_pymc_model", "compile_stan_model", "sample"]

python/nutpie/compile_pymc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import dataclasses
22
import itertools
33
import warnings
4+
from collections.abc import Iterable
45
from dataclasses import dataclass
56
from functools import wraps
67
from importlib.util import find_spec
78
from math import prod
8-
from typing import TYPE_CHECKING, Any, Callable, Iterable, Literal, Optional, Union
9+
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
910

1011
import numpy as np
1112
import pandas as pd
@@ -500,7 +501,10 @@ def compile_pymc_model(
500501
if gradient_backend == "jax":
501502
raise ValueError("Gradient backend cannot be jax when using numba backend")
502503
return _compile_pymc_model_numba(
503-
model=model, pymc_initial_point_fn=initial_point_fn, var_names=var_names, **kwargs
504+
model=model,
505+
pymc_initial_point_fn=initial_point_fn,
506+
var_names=var_names,
507+
**kwargs,
504508
)
505509
elif backend.lower() == "jax":
506510
return _compile_pymc_model_jax(

src/pyfunc.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use std::sync::Arc;
33
use anyhow::{anyhow, bail, Context, Result};
44
use arrow::{
55
array::{
6-
Array, ArrayBuilder, BooleanBuilder, FixedSizeListBuilder, Float32Builder, Float64Builder,
7-
Int64Builder, LargeListBuilder, ListBuilder, PrimitiveBuilder, StructBuilder,
6+
Array, ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int64Builder,
7+
LargeListBuilder, PrimitiveBuilder, StructBuilder,
88
},
99
datatypes::{DataType, Field, Float32Type, Float64Type, Int64Type},
1010
};
@@ -16,7 +16,7 @@ use pyo3::{
1616
Bound, Py, PyAny, PyErr, Python,
1717
};
1818
use rand::Rng;
19-
use rand_distr::{Distribution, StandardNormal, Uniform};
19+
use rand_distr::{Distribution, Uniform};
2020
use smallvec::SmallVec;
2121
use thiserror::Error;
2222

src/pymc.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ use std::{ffi::c_void, fmt::Display, sync::Arc};
22

33
use anyhow::{bail, Context, Result};
44
use arrow::{
5-
array::{
6-
Array, FixedSizeListArray, Float64Array, LargeListArray, LargeListBuilder, StructArray,
7-
},
5+
array::{Array, Float64Array, LargeListArray, StructArray},
86
buffer::OffsetBuffer,
97
datatypes::{DataType, Field, Fields},
108
};
@@ -16,7 +14,6 @@ use pyo3::{
1614
types::{PyAnyMethods, PyList},
1715
Bound, Py, PyAny, PyObject, PyResult, Python,
1816
};
19-
use rand::{distributions::Uniform, prelude::Distribution};
2017

2118
use thiserror::Error;
2219

tests/test_pymc.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,10 @@ def test_pymc_var_names(backend, gradient_backend):
204204
pm.Deterministic("c", mu * b)
205205

206206
compiled = nutpie.compile_pymc_model(
207-
model, backend=backend, gradient_backend=gradient_backend, var_names=None,
207+
model,
208+
backend=backend,
209+
gradient_backend=gradient_backend,
210+
var_names=None,
208211
)
209212
trace = nutpie.sample(compiled, chains=1, seed=1)
210213

@@ -213,7 +216,10 @@ def test_pymc_var_names(backend, gradient_backend):
213216
assert hasattr(trace.posterior, "c")
214217

215218
compiled = nutpie.compile_pymc_model(
216-
model, backend=backend, gradient_backend=gradient_backend, var_names=[],
219+
model,
220+
backend=backend,
221+
gradient_backend=gradient_backend,
222+
var_names=[],
217223
)
218224
trace = nutpie.sample(compiled, chains=1, seed=1)
219225

@@ -222,7 +228,10 @@ def test_pymc_var_names(backend, gradient_backend):
222228
assert not hasattr(trace.posterior, "c")
223229

224230
compiled = nutpie.compile_pymc_model(
225-
model, backend=backend, gradient_backend=gradient_backend, var_names=["b"],
231+
model,
232+
backend=backend,
233+
gradient_backend=gradient_backend,
234+
var_names=["b"],
226235
)
227236
trace = nutpie.sample(compiled, chains=1, seed=1)
228237

0 commit comments

Comments
 (0)