15 | 15 | import re
16 | 16 | import sys
17 | 17 |
   | 18 | +from datetime import datetime
18 | 19 | from functools import partial
19 | 20 | from typing import Any, Callable, Dict, List, Optional, Sequence, Union
20 | 21 |
21 |    | -from pytensor.tensor.random.type import RandomType
22 |    | -
23 |    | -from pymc.initial_point import StartDict
24 |    | -from pymc.sampling.mcmc import _init_jitter
25 |    | -
26 |    | -xla_flags = os.getenv("XLA_FLAGS", "")
27 |    | -xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
28 |    | -os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
29 |    | -
30 |    | -from datetime import datetime
31 |    | -
32 | 22 | import arviz as az
33 | 23 | import jax
34 | 24 | import numpy as np

43 | 33 | from pytensor.link.jax.dispatch import jax_funcify
44 | 34 | from pytensor.raise_op import Assert
45 | 35 | from pytensor.tensor import TensorVariable
   | 36 | +from pytensor.tensor.random.type import RandomType
46 | 37 | from pytensor.tensor.shape import SpecifyShape
47 | 38 |
48 | 39 | from pymc import Model, modelcontext
49 | 40 | from pymc.backends.arviz import find_constants, find_observations
   | 41 | +from pymc.initial_point import StartDict
50 | 42 | from pymc.logprob.utils import CheckParameterValue
   | 43 | +from pymc.sampling.mcmc import _init_jitter
51 | 44 | from pymc.util import (
52 | 45 |     RandomSeed,
53 | 46 |     RandomState,
54 | 47 |     _get_seeds_per_chain,
55 | 48 |     get_default_varnames,
56 | 49 | )
57 | 50 |
   | 51 | +xla_flags_env = os.getenv("XLA_FLAGS", "")
   | 52 | +xla_flags = re.sub(r"--xla_force_host_platform_device_count=.+\s", "", xla_flags_env).split()
   | 53 | +os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"] + xla_flags)
   | 54 | +
58 | 55 | __all__ = (
59 | 56 |     "get_jaxified_graph",
60 | 57 |     "get_jaxified_logp",
@@ -111,7 +108,7 @@ def get_jaxified_graph(
111 | 108 | ) -> List[TensorVariable]:
112 | 109 |     """Compile an PyTensor graph into an optimized JAX function"""
113 | 110 |
114 |     | -    graph = _replace_shared_variables(outputs)
    | 111 | +    graph = _replace_shared_variables(outputs) if outputs is not None else None
115 | 112 |
116 | 113 |     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
117 | 114 |     # We need to add a Supervisor to the fgraph to be able to run the
@@ -254,12 +251,10 @@ def _get_batched_jittered_initial_points(
|
254 | 251 | jitter=jitter,
|
255 | 252 | jitter_max_retries=jitter_max_retries,
|
256 | 253 | )
|
257 |
| - initial_points = [list(initial_point.values()) for initial_point in initial_points] |
| 254 | + initial_points_values = [list(initial_point.values()) for initial_point in initial_points] |
258 | 255 | if chains == 1:
|
259 |
| - initial_points = initial_points[0] |
260 |
| - else: |
261 |
| - initial_points = [np.stack(init_state) for init_state in zip(*initial_points)] |
262 |
| - return initial_points |
| 256 | + return initial_points_values[0] |
| 257 | + return [np.stack(init_state) for init_state in zip(*initial_points_values)] |
263 | 258 |
|
264 | 259 |
|
265 | 260 | def _update_coords_and_dims(
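Note (illustrative, not part of the diff): for the multi-chain case, the rewritten return batches the per-chain initial points variable by variable. `zip(*initial_points_values)` regroups the values so that each tuple holds one variable across all chains, and `np.stack` then adds a leading chain dimension. A small sketch with made-up values for two chains and two variables:

```python
import numpy as np

# Hypothetical initial points for two chains, each a list of per-variable
# values (a length-2 vector "mu" and a scalar "sigma"), for illustration only.
initial_points_values = [
    [np.array([0.1, 0.2]), np.array(1.0)],  # chain 0
    [np.array([0.3, 0.4]), np.array(1.5)],  # chain 1
]

# zip(*...) groups the values per variable across chains; np.stack then
# prepends a chain axis to each variable.
batched = [np.stack(init_state) for init_state in zip(*initial_points_values)]

print(batched[0].shape)  # (2, 2) -> (chains, mu dimensions)
print(batched[1].shape)  # (2,)   -> one sigma per chain
```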