.github/workflows/test.yml (5 changes: 2 additions & 3 deletions)

@@ -203,15 +203,14 @@ jobs:
         run: |
 
           if [[ $OS == "macos-15" ]]; then
-            micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
+            micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx libblas=*=*accelerate;
           else
-            micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
+            micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
           fi
           if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
           if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi
           if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
           if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
-          pip install pytest-sphinx
 
           pip install -e ./
           micromamba list && pip freeze
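The hunk above and the two environment-file hunks below install pytest-sphinx from conda-forge alongside the rest of the test stack instead of as a trailing pip step, so one resolver handles everything. The plugin's job is unchanged; presumably it is used here, as usual, to run Sphinx-style doctest blocks under pytest. A hedged sketch of the kind of example it collects (the function is invented for illustration):

def double(x: int) -> int:
    """Double ``x``.

    .. testcode::

        print(double(21))

    .. testoutput::

        42
    """
    return 2 * x

Run with `pytest --doctest-modules`; with pytest-sphinx installed, the `testcode`/`testoutput` pair should execute like an ordinary doctest.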
environment-osx-arm64.yml (3 changes: 1 addition & 2 deletions)

@@ -31,8 +31,7 @@ dependencies:
   - pytest-xdist
   - pytest-benchmark
   - pytest-mock
-  - pip:
-      - pytest-sphinx
+  - pytest-sphinx
   # For building docs
   - sphinx>=5.1.0,<6
   - sphinx_rtd_theme
environment.yml (3 changes: 1 addition & 2 deletions)

@@ -33,8 +33,7 @@ dependencies:
   - pytest-xdist
   - pytest-benchmark
   - pytest-mock
-  - pip:
-      - pytest-sphinx
+  - pytest-sphinx
   # For building docs
   - sphinx>=5.1.0,<6
   - sphinx_rtd_theme
pytensor/link/jax/linker.py (4 changes: 3 additions & 1 deletion)

@@ -9,8 +9,10 @@
 class JAXLinker(JITLinker):
     """A `Linker` that JIT-compiles NumPy-based operations using JAX."""
 
+    scalar_shape_inputs: tuple[int, ...]
+
     def __init__(self, *args, **kwargs):
-        self.scalar_shape_inputs: tuple[int] = ()  # type: ignore[annotation-unchecked]
+        self.scalar_shape_inputs = ()
         super().__init__(*args, **kwargs)
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
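Moving the annotation from the `self.` assignment to a class-level declaration has two effects: the annotation no longer sits inside the untyped (hence unchecked) `__init__` body, which is what triggered the suppressed `annotation-unchecked` note, and the type is corrected from `tuple[int]` (exactly one int) to `tuple[int, ...]` (any number of ints). A minimal standalone sketch of the pattern, not PyTensor's actual class:

class Linker:
    # Class-level declaration keeps the annotation out of the unchecked
    # __init__ body, so mypy emits no "annotation-unchecked" note.
    scalar_shape_inputs: tuple[int, ...]

    def __init__(self, *args, **kwargs):
        # Plain assignment, no inline annotation needed. Note the fix:
        # the old tuple[int] meant "a 1-tuple of int" and did not match
        # the () default; tuple[int, ...] does.
        self.scalar_shape_inputs = ()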
pytensor/link/numba/dispatch/vectorize_codegen.py (8 changes: 4 additions & 4 deletions)

@@ -517,17 +517,17 @@ def make_loop_call(
     output_slices = []
     for output, output_type, bc in zip(outputs, output_types, output_bc, strict=True):
         core_ndim = output_type.ndim - len(bc)
-        size_type = output.shape.type.element  # type: ignore
-        output_shape = cgutils.unpack_tuple(builder, output.shape)  # type: ignore
-        output_strides = cgutils.unpack_tuple(builder, output.strides)  # type: ignore
+        size_type = output.shape.type.element  # pyright: ignore[reportAttributeAccessIssue]
+        output_shape = cgutils.unpack_tuple(builder, output.shape)  # pyright: ignore[reportAttributeAccessIssue]
+        output_strides = cgutils.unpack_tuple(builder, output.strides)  # pyright: ignore[reportAttributeAccessIssue]
 
         idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc, strict=True)] + [
             zero
         ] * core_ndim
         ptr = cgutils.get_item_pointer2(
             context,
             builder,
-            output.data,  # type:ignore
+            output.data,
             output_shape,
             output_strides,
             output_type.layout,
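The switch from bare `# type: ignore` to `# pyright: ignore[reportAttributeAccessIssue]` narrows the suppression to one tool and one diagnostic, which matches the actual problem: these attributes exist on Numba's runtime objects but are invisible to static checkers. A small self-contained illustration of the difference (the `Point` class is invented):

class Point:
    def __init__(self) -> None:
        # Set dynamically, so static checkers don't know the attribute exists.
        setattr(self, "y", 1)

p = Point()

# Blanket form: silences every checker and every error code on this line.
attr_a = p.y  # type: ignore

# Scoped form: silences only pyright, and only its attribute-access rule;
# any other diagnostic on the same line is still reported.
attr_b = p.y  # pyright: ignore[reportAttributeAccessIssue]

print(attr_a, attr_b)  # 1 1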
pytensor/npy_2_compat.py (2 changes: 1 addition & 1 deletion)

@@ -41,7 +41,7 @@
 
 
 if using_numpy_2:
-    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()
+    ndarray_c_version = np._core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
 else:
     ndarray_c_version = np.core._multiarray_umath._get_ndarray_c_version()  # type: ignore[attr-defined]
 
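Both branches now carry `# type: ignore[attr-defined]`, which makes sense: the checker analyzes against whichever NumPy is installed, so exactly one of `np._core` / `np.core` is always unknown to it, and which one depends on the environment. A hedged sketch of the version-gating pattern (not PyTensor's exact code):

import numpy as np

using_numpy_2 = np.lib.NumpyVersion(np.__version__) >= "2.0.0"

if using_numpy_2:
    # np._core does not exist in NumPy 1.x, so 1.x stubs reject this line.
    core = np._core  # type: ignore[attr-defined]
else:
    # np.core was renamed away in NumPy 2.x; same problem, mirrored.
    core = np.core  # type: ignore[attr-defined]

print(core.__name__)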
pytensor/scan/utils.py (2 changes: 1 addition & 1 deletion)

@@ -109,7 +109,7 @@ def safe_new(
     except TestValueError:
         pass
 
-    return nw_x
+    return type_cast(Variable, nw_x)
 
 
 class until:
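`type_cast` is presumably `typing.cast` under an import alias. It has no runtime effect; it only re-labels the value for the checker so the function's declared return type is satisfied. A toy sketch with invented types:

from typing import cast as type_cast

class Variable: ...
class TensorVariable(Variable): ...

def safe_clone() -> Variable:
    # Imagine the checker can only infer `object` for nw_x here.
    nw_x: object = TensorVariable()
    # cast() is a no-op at runtime; it just asserts to the checker that
    # nw_x is a Variable, so the declared return type holds.
    return type_cast(Variable, nw_x)

print(type(safe_clone()).__name__)  # TensorVariable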
pytensor/tensor/einsum.py (12 changes: 8 additions & 4 deletions)

@@ -597,10 +597,14 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
             # Numpy einsum_path requires arrays even though only the shapes matter
             # It's not trivial to duck-type our way around because of internal call to `asanyarray`
             *[np.empty(shape) for shape in shapes],
-            einsum_call=True,  # Not part of public API
+            # einsum_call is not part of public API
+            einsum_call=True,  # type: ignore[arg-type]
             optimize="optimal",
-        )  # type: ignore
-        np_path = tuple(contraction[0] for contraction in contraction_list)
+        )
+        np_path: PATH | tuple[tuple[int, ...]] = tuple(
+            contraction[0]  # type: ignore[misc]
+            for contraction in contraction_list
+        )
 
         if len(np_path) == 1 and len(np_path[0]) > 2:
             # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
@@ -610,7 +614,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
                 subscripts, tensor_operands, path
             )
         else:
-            path = np_path
+            path = cast(PATH, np_path)
 
         optimized = True
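For context on what is being typed here: `np.einsum_path` computes a contraction order whose entries are index tuples, which is why `np_path` is annotated as a tuple of tuples. A quick demonstration with the public API (shapes invented; the private `einsum_call=True` form used in the diff returns the contraction list directly):

import numpy as np

a = np.empty((10, 3))
b = np.empty((3, 20))
c = np.empty((20, 5))

# Each tuple after the leading "einsum_path" marker names the pair of
# operands to contract next.
path, info = np.einsum_path("ij,jk,kl->il", a, b, c, optimize="optimal")
print(path)  # e.g. ['einsum_path', (1, 2), (0, 1)]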
pytensor/tensor/random/rewriting/numba.py (4 changes: 2 additions & 2 deletions)

@@ -53,10 +53,10 @@ def introduce_explicit_core_shape_rv(fgraph, node):
     #    ← dirichlet_rv{"(a)->(a)"}.1 [id F]
     #       └─ ···
     """
-    op: RandomVariable = node.op  # type: ignore[annotation-unchecked]
+    op: RandomVariable = node.op
 
     next_rng, rv = node.outputs
-    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
     if shape_feature:
         core_shape = [
             shape_feature.get_shape(rv, -i - 1) for i in reversed(range(op.ndim_supp))
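This hunk and the two that follow delete `# type: ignore[annotation-unchecked]` comments. `annotation-unchecked` is a mypy note about annotated locals in functions whose signatures are untyped, and notes cannot be silenced per line, so the comments presumably had no effect. A minimal reproduction of the note and the suppression that does work:

# An annotated local in an untyped function triggers mypy's note
# "By default the bodies of untyped functions are not checked ...":
def rewrite(node):  # no parameter/return annotations => untyped to mypy
    op: object = node.op  # note [annotation-unchecked] reported here
    return op

# Since the note ignores "# type: ignore[...]", it is silenced globally:
#   mypy --disable-error-code=annotation-unchecked pytensor/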
pytensor/tensor/rewriting/blockwise.py (2 changes: 1 addition & 1 deletion)

@@ -102,7 +102,7 @@ def local_blockwise_alloc(fgraph, node):
     This is critical to remove many unnecessary Blockwise, or to reduce the work done by it
     """
 
-    op: Blockwise = node.op  # type: ignore
+    op: Blockwise = node.op
 
     batch_ndim = op.batch_ndim(node)
     if not batch_ndim:
pytensor/tensor/rewriting/numba.py (4 changes: 2 additions & 2 deletions)

@@ -65,10 +65,10 @@ def introduce_explicit_core_shape_blockwise(fgraph, node):
     # [Blockwise{SVD{full_matrices=True, compute_uv=True}, (m,n)->(m,m),(k),(n,n)}].2 [id A] 6
     #    └─ ···
     """
-    op: Blockwise = node.op  # type: ignore[annotation-unchecked]
+    op: Blockwise = node.op
     batch_ndim = op.batch_ndim(node)
 
-    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)  # type: ignore[annotation-unchecked]
+    shape_feature: ShapeFeature | None = getattr(fgraph, "shape_feature", None)
     if shape_feature:
         core_shapes = [
             [shape_feature.get_shape(out, i) for i in range(batch_ndim, out.type.ndim)]
scripts/mypy-failing.txt (1 change: 0 additions & 1 deletion)

@@ -3,7 +3,6 @@ pytensor/compile/debugmode.py
 pytensor/compile/function/pfunc.py
 pytensor/compile/function/types.py
 pytensor/compile/mode.py
-pytensor/compile/sharedvalue.py
 pytensor/graph/rewriting/basic.py
 pytensor/ifelse.py
 pytensor/link/numba/dispatch/elemwise.py
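Dropping `pytensor/compile/sharedvalue.py` from the list records that the file now passes mypy. Known-failure lists like this usually back a CI check that fails both on new regressions and on stale entries that have started passing; a hedged sketch of such a checker (the file name handling and mypy invocation are assumptions, not PyTensor's actual script):

import subprocess
from pathlib import Path

expected_failures = set(Path("scripts/mypy-failing.txt").read_text().split())

result = subprocess.run(
    ["mypy", "pytensor", "--no-error-summary"],
    capture_output=True,
    text=True,
)
# mypy lines look like "path/to/file.py:12: error: ..."; keep the path part.
actual_failures = {
    line.split(":", 1)[0] for line in result.stdout.splitlines() if ":" in line
}

# Fail on regressions, and on entries that now pass, so the list only
# shrinks as files are fixed; this diff is exactly such a shrink.
new = actual_failures - expected_failures
fixed = expected_failures - actual_failures
assert not new, f"new mypy failures: {sorted(new)}"
assert not fixed, f"remove from mypy-failing.txt: {sorted(fixed)}"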