Skip to content

Commit 655f259

Browse files
committed
Merge remote-tracking branch 'upstream' into update-test-time-action
2 parents 30eeab7 + 5d4e9e0 commit 655f259

File tree

25 files changed

+250
-141
lines changed

25 files changed

+250
-141
lines changed

.github/workflows/test.yml

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ jobs:
5454
if: ${{ needs.changes.outputs.changes == 'true' }}
5555
strategy:
5656
matrix:
57-
python-version: ["3.10", "3.12"]
57+
python-version: ["3.10", "3.13"]
5858
steps:
5959
- uses: actions/checkout@v4
6060
with:
@@ -75,15 +75,14 @@ jobs:
7575
fail-fast: false
7676
matrix:
7777
os: ["ubuntu-latest"]
78-
python-version: ["3.10", "3.12"]
78+
python-version: ["3.10", "3.13"]
7979
numpy-version: ["~=1.26.0", ">=2.0"]
8080
fast-compile: [0, 1]
8181
float32: [0, 1]
8282
install-numba: [0]
8383
install-jax: [0]
8484
install-torch: [0]
8585
part:
86-
- "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
8786
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
8887
- "tests/scan"
8988
- "tests/sparse"
@@ -98,23 +97,24 @@ jobs:
9897
fast-compile: 1
9998
- python-version: "3.10"
10099
float32: 1
101-
- python-version: "3.10"
102-
part: "tests/tensor/test_math.py"
103100
- fast-compile: 1
104101
float32: 1
105-
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
106-
float32: 1
107-
- part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
108-
fast-compile: 1
109102
- numpy-version: "~=1.26.0"
110103
fast-compile: 1
111104
- numpy-version: "~=1.26.0"
112105
float32: 1
113106
- numpy-version: "~=1.26.0"
114-
python-version: "3.12"
115-
- numpy-version: "~=1.26.0"
116-
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
107+
python-version: "3.13"
117108
include:
109+
- os: "ubuntu-latest"
110+
part: "--doctest-modules pytensor --ignore=pytensor/misc/check_duplicate_key.py --ignore=pytensor/link"
111+
python-version: "3.12"
112+
numpy-version: ">=2.0"
113+
fast-compile: 0
114+
float32: 0
115+
install-numba: 0
116+
install-jax: 0
117+
install-torch: 0
118118
- install-numba: 1
119119
os: "ubuntu-latest"
120120
python-version: "3.10"
@@ -124,7 +124,7 @@ jobs:
124124
part: "tests/link/numba"
125125
- install-numba: 1
126126
os: "ubuntu-latest"
127-
python-version: "3.12"
127+
python-version: "3.13"
128128
numpy-version: "~=2.1.0"
129129
fast-compile: 0
130130
float32: 0
@@ -138,7 +138,7 @@ jobs:
138138
part: "tests/link/jax"
139139
- install-jax: 1
140140
os: "ubuntu-latest"
141-
python-version: "3.12"
141+
python-version: "3.13"
142142
numpy-version: ">=2.0"
143143
fast-compile: 0
144144
float32: 0
@@ -151,23 +151,14 @@ jobs:
151151
float32: 0
152152
part: "tests/link/pytorch"
153153
- os: macos-15
154-
python-version: "3.12"
154+
python-version: "3.13"
155155
numpy-version: ">=2.0"
156156
fast-compile: 0
157157
float32: 0
158158
install-numba: 0
159159
install-jax: 0
160160
install-torch: 0
161161
part: "tests/tensor/test_blas.py tests/tensor/test_elemwise.py tests/tensor/test_math_scipy.py"
162-
- os: "ubuntu-latest"
163-
python-version: "3.10"
164-
numpy-version: "~=1.26.0"
165-
fast-compile: 0
166-
float32: 0
167-
install-numba: 0
168-
install-jax: 0
169-
install-torch: 0
170-
part: "tests/tensor/test_math.py"
171162

172163
steps:
173164
- uses: actions/checkout@v4
@@ -198,13 +189,13 @@ jobs:
198189
run: |
199190
200191
if [[ $OS == "macos-15" ]]; then
201-
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
192+
micromamba install --yes -q "python~=${PYTHON_VERSION}" "numpy${NUMPY_VERSION}" scipy pip graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock libblas=*=*accelerate;
202193
else
203-
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
194+
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock;
204195
fi
205-
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
206-
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
207-
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
196+
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
197+
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
198+
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
208199
pip install pytest-sphinx
209200
210201
pip install -e ./
@@ -269,7 +260,7 @@ jobs:
269260
- name: Install dependencies
270261
shell: micromamba-shell {0}
271262
run: |
272-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
263+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
273264
pip install -e ./
274265
micromamba list && pip freeze
275266
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
@@ -322,7 +313,7 @@ jobs:
322313
- name: Set up Python
323314
uses: actions/setup-python@v5
324315
with:
325-
python-version: "3.12"
316+
python-version: "3.13"
326317

327318
- name: Install dependencies
328319
run: |

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"
1010
[project]
1111
name = "pytensor"
1212
dynamic = ['version']
13-
requires-python = ">=3.10,<3.13"
13+
requires-python = ">=3.10,<3.14"
1414
authors = [{ name = "pymc-devs", email = "[email protected]" }]
1515
description = "Optimizing compiler for evaluating mathematical expressions on CPUs and GPUs."
1616
readme = "README.rst"
@@ -33,6 +33,7 @@ classifiers = [
3333
"Programming Language :: Python :: 3.10",
3434
"Programming Language :: Python :: 3.11",
3535
"Programming Language :: Python :: 3.12",
36+
"Programming Language :: Python :: 3.13",
3637
]
3738

3839
keywords = [

pytensor/compile/function/pfunc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def construct_pfunc_ins_and_outs(
569569
if not fgraph:
570570
# Extend the outputs with the updates on input variables so they are
571571
# also cloned
572-
additional_outputs = [i.update for i in inputs if i.update]
572+
additional_outputs = [i.update for i in inputs if i.update is not None]
573573
if outputs is None:
574574
out_list = []
575575
else:
@@ -608,7 +608,7 @@ def construct_pfunc_ins_and_outs(
608608
new_i.variable = iv
609609

610610
# If needed, replace the input's update by its cloned equivalent
611-
if i.update:
611+
if i.update is not None:
612612
new_i.update = clone_d[i.update]
613613

614614
new_inputs.append(new_i)

pytensor/compile/function/types.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def std_fgraph(
198198
update_mapping = {}
199199
out_idx = len(output_specs)
200200
for idx, input_spec in enumerate(input_specs):
201-
if input_spec.update:
201+
if input_spec.update is not None:
202202
updates.append(input_spec.update)
203203
update_mapping[out_idx] = idx
204204
out_idx += 1
@@ -1195,7 +1195,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
11951195
updated_fgraph_inputs = {
11961196
fgraph_i
11971197
for i, fgraph_i in zip(wrapped_inputs, fgraph.inputs, strict=True)
1198-
if getattr(i, "update", False)
1198+
if getattr(i, "update", None) is not None
11991199
}
12001200

12011201
# We can't use fgraph.inputs as this don't include Constant Value.
@@ -1351,7 +1351,11 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
13511351
ancestors(
13521352
(
13531353
[o.variable for o in outputs]
1354-
+ [i.update for i in inputs if getattr(i, "update", False)]
1354+
+ [
1355+
i.update
1356+
for i in inputs
1357+
if getattr(i, "update", None) is not None
1358+
]
13551359
),
13561360
blockers=[i.variable for i in inputs],
13571361
)

pytensor/compile/nanguardmode.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _is_numeric_value(arr, var):
3636
return False
3737
elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
3838
return False
39-
elif var and isinstance(var.type, RandomType):
39+
elif var is not None and isinstance(var.type, RandomType):
4040
return False
4141
elif isinstance(arr, slice):
4242
return False

pytensor/link/numba/dispatch/random.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
6464
@numba_core_rv_funcify.register(ptr.LaplaceRV)
6565
@numba_core_rv_funcify.register(ptr.BinomialRV)
6666
@numba_core_rv_funcify.register(ptr.NegBinomialRV)
67-
@numba_core_rv_funcify.register(ptr.MultinomialRV)
6867
@numba_core_rv_funcify.register(ptr.PermutationRV)
6968
@numba_core_rv_funcify.register(ptr.IntegersRV)
7069
def numba_core_rv_default(op, node):
@@ -132,6 +131,15 @@ def random(rng, b, scale):
132131
return random
133132

134133

134+
@numba_core_rv_funcify.register(ptr.InvGammaRV)
135+
def numba_core_InvGammaRV(op, node):
136+
@numba_basic.numba_njit
137+
def random(rng, shape, scale):
138+
return 1 / rng.gamma(shape, 1 / scale)
139+
140+
return random
141+
142+
135143
@numba_core_rv_funcify.register(ptr.CategoricalRV)
136144
def core_CategoricalRV(op, node):
137145
@numba_basic.numba_njit
@@ -142,6 +150,29 @@ def random_fn(rng, p):
142150
return random_fn
143151

144152

153+
@numba_core_rv_funcify.register(ptr.MultinomialRV)
154+
def core_MultinomialRV(op, node):
155+
dtype = op.dtype
156+
157+
@numba_basic.numba_njit
158+
def random_fn(rng, n, p):
159+
n_cat = p.shape[0]
160+
draws = np.zeros(n_cat, dtype=dtype)
161+
remaining_p = np.float64(1.0)
162+
remaining_n = n
163+
for i in range(n_cat - 1):
164+
draws[i] = rng.binomial(remaining_n, p[i] / remaining_p)
165+
remaining_n -= draws[i]
166+
if remaining_n <= 0:
167+
break
168+
remaining_p -= p[i]
169+
if remaining_n > 0:
170+
draws[n_cat - 1] = remaining_n
171+
return draws
172+
173+
return random_fn
174+
175+
145176
@numba_core_rv_funcify.register(ptr.MvNormalRV)
146177
def core_MvNormalRV(op, node):
147178
method = op.method

pytensor/scalar/basic.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,37 @@ def get_scalar_type(dtype, cache: dict[str, ScalarType] = {}) -> ScalarType:
823823

824824

825825
class _scalar_py_operators:
826+
# These can't work because Python requires native output types
827+
def __bool__(self):
828+
raise TypeError(
829+
"ScalarVariable cannot be converted to Python boolean. "
830+
"Call `.astype(bool)` for the symbolic equivalent."
831+
)
832+
833+
def __index__(self):
834+
raise TypeError(
835+
"ScalarVariable cannot be converted to Python integer. "
836+
"Call `.astype(int)` for the symbolic equivalent."
837+
)
838+
839+
def __int__(self):
840+
raise TypeError(
841+
"ScalarVariable cannot be converted to Python integer. "
842+
"Call `.astype(int)` for the symbolic equivalent."
843+
)
844+
845+
def __float__(self):
846+
raise TypeError(
847+
"ScalarVariable cannot be converted to Python float. "
848+
"Call `.astype(float)` for the symbolic equivalent."
849+
)
850+
851+
def __complex__(self):
852+
raise TypeError(
853+
"ScalarVariable cannot be converted to Python complex number. "
854+
"Call `.astype(complex)` for the symbolic equivalent."
855+
)
856+
826857
# So that we can simplify checking code when we have a mixture of ScalarType
827858
# variables and Tensor variables
828859
ndim = 0
@@ -843,11 +874,6 @@ def __abs__(self):
843874
def __neg__(self):
844875
return neg(self)
845876

846-
# CASTS
847-
# def __int__(self): return AsInt(self).out
848-
# def __float__(self): return AsDouble(self).out
849-
# def __complex__(self): return AsComplex(self).out
850-
851877
# BITWISE
852878
def __invert__(self):
853879
return invert(self)

pytensor/scalar/loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ def __init__(
6060
constant = []
6161
if not len(init) == len(update):
6262
raise ValueError("An update must be given for each init variable")
63-
if until:
63+
if until is not None:
6464
inputs, outputs = clone([*init, *constant], [*update, until])
6565
else:
6666
inputs, outputs = clone([*init, *constant], update)
6767

68-
self.is_while = bool(until)
68+
self.is_while = until is not None
6969
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
7070
self._validate_updates(self.inputs, self.outputs)
7171

pytensor/scalar/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,7 @@ def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x):
856856
dfac = k_minus_one_minus_n * dfac + fac
857857
fac *= k_minus_one_minus_n
858858
delta = dfac / xpow
859-
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), ()
859+
return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), None
860860

861861
init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac]
862862
constant = [x]

pytensor/scan/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -979,7 +979,7 @@ def wrap_into_list(x):
979979
# user-specified within the inner-function (e.g. by returning an update
980980
# `dict`) or the `SharedVariable.default_update`s of a shared variable
981981
# created in the inner-function.
982-
if input.update and (is_local or input.variable in updates):
982+
if input.update is not None and (is_local or input.variable in updates):
983983
# We need to remove the `default_update`s on the shared
984984
# variables created within the context of the loop function
985985
# (e.g. via use of `RandomStream`); otherwise, they'll get

0 commit comments

Comments
 (0)