Skip to content

Commit 7b6bd15

Browse files
authored
Merge branch 'main' into ciguaran_fix_smc_after_bump
2 parents b9f6b3c + e96d07f commit 7b6bd15

20 files changed

+2028
-1198
lines changed

.github/workflows/pypi.yml

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -45,35 +45,10 @@ jobs:
4545
with:
4646
name: artifact
4747
path: dist/*
48-
test:
49-
name: upload to test PyPI
50-
needs: [build]
51-
runs-on: ubuntu-latest
52-
if: github.event_name == 'release' && github.event.action == 'published'
53-
steps:
54-
- uses: actions/download-artifact@v3
55-
with:
56-
name: artifact
57-
path: dist
58-
- uses: pypa/gh-action-pypi-publish@release/v1
59-
with:
60-
skip_existing: true
61-
user: __token__
62-
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
63-
repository_url: https://test.pypi.org/legacy/
64-
- uses: actions/setup-python@v5
65-
with:
66-
python-version: "3.10"
67-
- name: Test pip install from test.pypi
68-
run: |
69-
python -m venv venv-test-pypi
70-
venv-test-pypi/bin/python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple pymc-experimental
71-
echo "Checking import and version number"
72-
venv-test-pypi/bin/python -c "import pymc_experimental; assert pymc_experimental.__version__ == '${{ github.ref_name }}'[1:]"
7348

7449
publish:
7550
name: upload release to PyPI
76-
needs: [build, test]
51+
needs: [build]
7752
runs-on: ubuntu-latest
7853
if: github.event_name == 'release' && github.event.action == 'published'
7954
steps:

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ jobs:
7373
- uses: mamba-org/setup-micromamba@v1
7474
with:
7575
environment-file: conda-envs/windows-environment-test.yml
76+
micromamba-version: "1.5.10-0" # Until https://github.com/mamba-org/mamba/issues/3467 is not fixed
7677
create-args: >-
7778
python=${{matrix.python-version}}
7879
environment-name: pymc-experimental-test

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
ci:
2+
autofix_prs: false
3+
14
repos:
25
- repo: https://github.com/pre-commit/pre-commit-hooks
36
rev: v4.6.0

conda-envs/environment-test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.16.1 # CI was failing to resolve
14-
- blackjax>=1.2.3
13+
- pymc>=5.17.0 # CI was failing to resolve
14+
- blackjax
1515
- scikit-learn

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,6 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.16.1 # CI was failing to resolve
13+
- pymc>=5.17.0 # CI was failing to resolve
1414
- blackjax
1515
- scikit-learn

pymc_experimental/__init__.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414
import logging
1515

16-
from pymc_experimental import distributions, gp, statespace, utils
16+
from pymc_experimental import gp, statespace, utils
17+
from pymc_experimental.distributions import *
1718
from pymc_experimental.inference.fit import fit
18-
from pymc_experimental.model.marginal_model import MarginalModel
19+
from pymc_experimental.model.marginal.marginal_model import MarginalModel
1920
from pymc_experimental.model.model_api import as_model
2021
from pymc_experimental.version import __version__
2122

@@ -26,15 +27,3 @@
2627
if len(_log.handlers) == 0:
2728
handler = logging.StreamHandler()
2829
_log.addHandler(handler)
29-
30-
31-
__all__ = [
32-
"distributions",
33-
"gp",
34-
"statespace",
35-
"utils",
36-
"fit",
37-
"MarginalModel",
38-
"as_model",
39-
"__version__",
40-
]

pymc_experimental/distributions/timeseries.py

Lines changed: 81 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,12 @@
2020
from pymc.logprob.abstract import _logprob
2121
from pymc.logprob.basic import logp
2222
from pymc.pytensorf import constant_fold, intX
23-
from pymc.util import check_dist_not_registered
23+
from pymc.step_methods import STEP_METHODS
24+
from pymc.step_methods.arraystep import ArrayStep
25+
from pymc.step_methods.compound import Competence
26+
from pymc.step_methods.metropolis import CategoricalGibbsMetropolis
27+
from pymc.util import check_dist_not_registered, get_value_vars_from_user_vars
28+
from pytensor import Mode
2429
from pytensor.graph.basic import Node
2530
from pytensor.tensor import TensorVariable
2631
from pytensor.tensor.random.op import RandomVariable
@@ -101,10 +106,15 @@ class DiscreteMarkovChain(Distribution):
101106
Create a Markov Chain of length 100 with 3 states. The number of states is given by the shape of P,
102107
3 in this case.
103108
104-
>>> with pm.Model() as markov_chain:
105-
>>> P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
106-
>>> init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
107-
>>> markov_chain = pm.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
109+
.. code-block:: python
110+
111+
import pymc as pm
112+
import pymc_experimental as pmx
113+
114+
with pm.Model() as markov_chain:
115+
P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
116+
init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
117+
markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
108118
109119
"""
110120

@@ -266,3 +276,69 @@ def discrete_mc_logp(op, values, P, steps, init_dist, state_rng, **kwargs):
266276
"P must sum to 1 along the last axis, "
267277
"First dimension of init_dist must be n_lags",
268278
)
279+
280+
281+
class DiscreteMarkovChainGibbsMetropolis(CategoricalGibbsMetropolis):
282+
name = "discrete_markov_chain_gibbs_metropolis"
283+
284+
def __init__(self, vars, proposal="uniform", order="random", model=None):
285+
model = pm.modelcontext(model)
286+
vars = get_value_vars_from_user_vars(vars, model)
287+
initial_point = model.initial_point()
288+
289+
dimcats = []
290+
# The above variable is a list of pairs (aggregate dimension, number
291+
# of categories). For example, if vars = [x, y] with x being a 2-D
292+
# variable with M categories and y being a 3-D variable with N
293+
# categories, we will have dimcats = [(0, M), (1, M), (2, N), (3, N), (4, N)].
294+
for v in vars:
295+
v_init_val = initial_point[v.name]
296+
rv_var = model.values_to_rvs[v]
297+
rv_op = rv_var.owner.op
298+
299+
if not isinstance(rv_op, DiscreteMarkovChainRV):
300+
raise TypeError("All variables must be DiscreteMarkovChainRV")
301+
302+
k_graph = rv_var.owner.inputs[0].shape[-1]
303+
(k_graph,) = model.replace_rvs_by_values((k_graph,))
304+
k = model.compile_fn(
305+
k_graph,
306+
inputs=model.value_vars,
307+
on_unused_input="ignore",
308+
mode=Mode(linker="py", optimizer=None),
309+
)(initial_point)
310+
start = len(dimcats)
311+
dimcats += [(dim, k) for dim in range(start, start + v_init_val.size)]
312+
313+
if order == "random":
314+
self.shuffle_dims = True
315+
self.dimcats = dimcats
316+
else:
317+
if sorted(order) != list(range(len(dimcats))):
318+
raise ValueError("Argument 'order' has to be a permutation")
319+
self.shuffle_dims = False
320+
self.dimcats = [dimcats[j] for j in order]
321+
322+
if proposal == "uniform":
323+
self.astep = self.astep_unif
324+
elif proposal == "proportional":
325+
# Use the optimized "Metropolized Gibbs Sampler" described in Liu96.
326+
self.astep = self.astep_prop
327+
else:
328+
raise ValueError("Argument 'proposal' should either be 'uniform' or 'proportional'")
329+
330+
# Doesn't actually tune, but it's required to emit a sampler stat
331+
# that indicates whether a draw was done in a tuning phase.
332+
self.tune = True
333+
334+
# We bypass CategoryGibbsMetropolis's __init__ to avoid it's specialiazed initialization logic
335+
ArrayStep.__init__(self, vars, [model.compile_logp()])
336+
337+
@staticmethod
338+
def competence(var):
339+
if isinstance(var.owner.op, DiscreteMarkovChainRV):
340+
return Competence.IDEAL
341+
return Competence.INCOMPATIBLE
342+
343+
344+
STEP_METHODS.append(DiscreteMarkovChainGibbsMetropolis)

pymc_experimental/model/marginal/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)