parallelizing multiple inferences with dask fails #7819

@apatlpo

Description

Describe the issue:

I am trying to run multiple inferences in parallel on an HPC cluster with dask.
Some of the spawned calculations fail with a cryptic error message.
This may be a PyTensor issue; I am not sure.

Reproducible code example:

import numpy as np
import pandas as pd

import pymc as pm
from pymc import HalfCauchy, Model, Normal, sample

from dask.distributed import Client, LocalCluster
from dask_jobqueue import PBSCluster

# setting up the dask cluster
cluster = PBSCluster()
cluster.scale(jobs=5)  # request 5 worker jobs
client = Client(cluster)

# inference to be parallelized
def inference(RANDOM_SEED):

    rng = np.random.default_rng(RANDOM_SEED)

    size = 200
    true_intercept = 1
    true_slope = 2
    
    x = np.linspace(0, 1, size)
    # y = a + b*x
    true_regression_line = true_intercept + true_slope * x
    # add noise
    y = true_regression_line + rng.normal(scale=0.5, size=size)
    
    data = pd.DataFrame({"x": x, "y": y})

    with Model() as model:  # model specifications in PyMC are wrapped in a with-statement
        # Define priors
        sigma = HalfCauchy("sigma", beta=10)
        intercept = Normal("Intercept", 0, sigma=20)
        slope = Normal("slope", 0, sigma=20)
    
        # Define likelihood
        likelihood = Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)
    
        # Inference!
        # draw 3000 posterior samples using NUTS sampling
        idata = sample(3000, cores=1) # cores=1 suppresses daemonic child spawn multiprocessing error

    return idata

# spawn calculation
futures = client.map(
    inference,
    range(10),
)

# gather result
client.gather(futures)
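
As an aside while debugging: since only some of the futures fail, the successful results can still be collected by passing errors="skip" to gather, and the failed futures can then be inspected individually. A sketch (this does not fix the underlying failure):

# collect successful results only; failed futures are skipped instead of raising
results = client.gather(futures, errors="skip")

# inspect the exception carried by each failed future
failed = [f for f in futures if f.status == "error"]
for f in failed:
    print(f.exception())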

Error message:

<details>
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[197], line 1
----> 1 client.gather(futures)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/distributed/client.py:2556, in Client.gather(self, futures, errors, direct, asynchronous)
   2553     local_worker = None
   2555 with shorten_traceback():
-> 2556     return self.sync(
   2557         self._gather,
   2558         futures,
   2559         errors=errors,
   2560         direct=direct,
   2561         local_worker=local_worker,
   2562         asynchronous=asynchronous,
   2563     )

Cell In[186], line 29, in inference()
     25     likelihood = Normal("y", mu=intercept + slope * x, sigma=sigma, observed=y)
     27     # Inference!
     28     # draw 3000 posterior samples using NUTS sampling
---> 29     idata = sample(3000, cores=1)
     31 return idata

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/sampling/mcmc.py:718, in sample()
    715         auto_nuts_init = False
    717 initial_points = None
--> 718 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    720 if nuts_sampler != "pymc":
    721     if not isinstance(step, NUTS):

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/sampling/mcmc.py:237, in assign_step_methods()
    229         selected = max(
    230             methods_list,
    231             key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(  # type: ignore
    232                 var, has_gradient
    233             ),
    234         )
    235         selected_steps.setdefault(selected, []).append(var)
--> 237 return instantiate_steppers(model, steps, selected_steps, step_kwargs)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/sampling/mcmc.py:138, in instantiate_steppers()
    136         args = step_kwargs.get(name, {})
    137         used_keys.add(name)
--> 138         step = step_class(vars=vars, model=model, **args)
    139         steps.append(step)
    141 unused_args = set(step_kwargs).difference(used_keys)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/step_methods/hmc/nuts.py:180, in __init__()
    122 def __init__(self, vars=None, max_treedepth=10, early_max_treedepth=8, **kwargs):
    123     r"""Set up the No-U-Turn sampler.
    124 
    125     Parameters
   (...)
    178     `pm.sample` to the desired number of tuning steps.
    179     """
--> 180     super().__init__(vars, **kwargs)
    182     self.max_treedepth = max_treedepth
    183     self.early_max_treedepth = early_max_treedepth

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/step_methods/hmc/base_hmc.py:109, in __init__()
    107 else:
    108     vars = get_value_vars_from_user_vars(vars, self._model)
--> 109 super().__init__(vars, blocked=blocked, model=self._model, dtype=dtype, **pytensor_kwargs)
    111 self.adapt_step_size = adapt_step_size
    112 self.Emax = Emax

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/step_methods/arraystep.py:163, in __init__()
    160 model = modelcontext(model)
    162 if logp_dlogp_func is None:
--> 163     func = model.logp_dlogp_function(vars, dtype=dtype, **pytensor_kwargs)
    164 else:
    165     func = logp_dlogp_func

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/model/core.py:624, in logp_dlogp_function()
    621     costs = [self.logp()]
    623 input_vars = {i for i in graph_inputs(costs) if not isinstance(i, Constant)}
--> 624 ip = self.initial_point(0)
    625 extra_vars_and_values = {
    626     var: ip[var.name]
    627     for var in self.value_vars
    628     if var in input_vars and var not in grad_vars
    629 }
    630 return ValueGradFunction(costs, grad_vars, extra_vars_and_values, **kwargs)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/model/core.py:1098, in initial_point()
   1085 def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
   1086     """Computes the initial point of the model.
   1087 
   1088     Parameters
   (...)
   1096         Maps names of transformed variables to numeric initial values in the transformed space.
   1097     """
-> 1098     fn = make_initial_point_fn(model=self, return_transformed=True)
   1099     return Point(fn(random_seed), model=self)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/initial_point.py:152, in make_initial_point_fn()
    149 # Replace original rng shared variables so that we don't mess with them
    150 # when calling the final seeded function
    151 initial_values = replace_rng_nodes(initial_values)
--> 152 func = compile_pymc(inputs=[], outputs=initial_values, mode=pytensor.compile.mode.FAST_COMPILE)
    154 varnames = []
    155 for var in model.free_RVs:

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pymc/pytensorf.py:1056, in compile_pymc()
   1054 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
   1055 mode = Mode(linker=mode.linker, optimizer=opt_qry)
-> 1056 pytensor_function = pytensor.function(
   1057     inputs,
   1058     outputs,
   1059     updates={**rng_updates, **kwargs.pop("updates", {})},
   1060     mode=mode,
   1061     **kwargs,
   1062 )
   1063 return pytensor_function

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/compile/function/__init__.py:318, in function()
    312     fn = orig_function(
    313         inputs, outputs, mode=mode, accept_inplace=accept_inplace, name=name
    314     )
    315 else:
    316     # note: pfunc will also call orig_function -- orig_function is
    317     #      a choke point that all compilation must pass through
--> 318     fn = pfunc(
    319         params=inputs,
    320         outputs=outputs,
    321         mode=mode,
    322         updates=updates,
    323         givens=givens,
    324         no_default_updates=no_default_updates,
    325         accept_inplace=accept_inplace,
    326         name=name,
    327         rebuild_strict=rebuild_strict,
    328         allow_input_downcast=allow_input_downcast,
    329         on_unused_input=on_unused_input,
    330         profile=profile,
    331         output_keys=output_keys,
    332     )
    333 return fn

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/compile/function/pfunc.py:465, in pfunc()
    451     profile = ProfileStats(message=profile)
    453 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    454     params,
    455     outputs,
   (...)
    462     fgraph=fgraph,
    463 )
--> 465 return orig_function(
    466     inputs,
    467     cloned_outputs,
    468     mode,
    469     accept_inplace=accept_inplace,
    470     name=name,
    471     profile=profile,
    472     on_unused_input=on_unused_input,
    473     output_keys=output_keys,
    474     fgraph=fgraph,
    475 )

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/compile/function/types.py:1750, in orig_function()
   1748 try:
   1749     Maker = getattr(mode, "function_maker", FunctionMaker)
-> 1750     m = Maker(
   1751         inputs,
   1752         outputs,
   1753         mode,
   1754         accept_inplace=accept_inplace,
   1755         profile=profile,
   1756         on_unused_input=on_unused_input,
   1757         output_keys=output_keys,
   1758         name=name,
   1759         fgraph=fgraph,
   1760     )
   1761     with config.change_flags(compute_test_value="off"):
   1762         fn = m.create(defaults)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/compile/function/types.py:1523, in __init__()
   1520 self.fgraph = fgraph
   1522 if not no_fgraph_prep:
-> 1523     self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
   1525 assert len(fgraph.outputs) == len(outputs + found_updates)
   1527 # The 'no_borrow' outputs are the ones for which that we can't
   1528 # return the internal storage pointer.

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/compile/function/types.py:1411, in prepare_fgraph()
   1404 rewrite_time = None
   1406 with config.change_flags(
   1407     mode=mode,
   1408     compute_test_value=config.compute_test_value_opt,
   1409     traceback__limit=config.traceback__compile_limit,
   1410 ):
-> 1411     rewriter_profile = rewriter(fgraph)
   1413     end_rewriter = time.perf_counter()
   1414     rewrite_time = end_rewriter - start_rewriter

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:125, in __call__()
    123 def __call__(self, fgraph):
    124     """Rewrite a `FunctionGraph`."""
--> 125     return self.rewrite(fgraph)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:121, in rewrite()
    112 """
    113 
    114 This is meant as a shortcut for the following::
   (...)
    118 
    119 """
    120 self.add_requirements(fgraph)
--> 121 return self.apply(fgraph, *args, **kwargs)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:291, in apply()
    289 nb_nodes_before = len(fgraph.apply_nodes)
    290 t0 = time.perf_counter()
--> 291 sub_prof = rewriter.apply(fgraph)
    292 l.append(float(time.perf_counter() - t0))
    293 sub_profs.append(sub_prof)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:2443, in apply()
   2441 nb = change_tracker.nb_imported
   2442 t_rewrite = time.perf_counter()
-> 2443 sub_prof = grewrite.apply(fgraph)
   2444 time_rewriters[grewrite] += time.perf_counter() - t_rewrite
   2445 sub_profs.append(sub_prof)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:2027, in apply()
   2025             continue
   2026         current_node = node
-> 2027         nb += self.process_node(fgraph, node)
   2028     loop_t = time.perf_counter() - t0
   2029 finally:

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:1912, in process_node()
   1910 except Exception as e:
   1911     if self.failure_callback is not None:
-> 1912         self.failure_callback(
   1913             e, self, [(x, None) for x in node.outputs], node_rewriter, node
   1914         )
   1915         return False
   1916     else:

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:1767, in warn_inplace()
   1765 if isinstance(exc, InconsistencyError):
   1766     return
-> 1767 return cls.warn(exc, nav, repl_pairs, node_rewriter, node)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:1755, in warn()
   1751     pdb.post_mortem(sys.exc_info()[2])
   1752 elif isinstance(exc, AssertionError) or config.on_opt_error == "raise":
   1753     # We always crash on AssertionError because something may be
   1754     # seriously wrong if such an exception is raised.
-> 1755     raise exc

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:1909, in process_node()
   1907 assert node_rewriter is not None
   1908 try:
-> 1909     replacements = node_rewriter.transform(fgraph, node)
   1910 except Exception as e:
   1911     if self.failure_callback is not None:

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py:1081, in transform()
   1076     if not (
   1077         node.op in self._tracks or isinstance(node.op, self._tracked_types)
   1078     ):
   1079         return False
-> 1081 return self.fn(fgraph, node)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/tensor/rewriting/basic.py:1122, in constant_folding()
   1119     storage_map[o] = [None]
   1120     compute_map[o] = [False]
-> 1122 thunk = node.op.make_thunk(node, storage_map, compute_map, no_recycling=[])
   1123 required = thunk()
   1125 # A node whose inputs are all provided should always return successfully

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/op.py:119, in make_thunk()
    115 self.prepare_node(
    116     node, storage_map=storage_map, compute_map=compute_map, impl="c"
    117 )
    118 try:
--> 119     return self.make_c_thunk(node, storage_map, compute_map, no_recycling)
    120 except (NotImplementedError, MethodNotDefined):
    121     # We requested the c code, so don't catch the error.
    122     if impl == "c":

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/op.py:84, in make_c_thunk()
     82         print(f"Disabling C code for {self} due to unsupported float16")
     83         raise NotImplementedError("float16")
---> 84 outputs = cl.make_thunk(
     85     input_storage=node_input_storage, output_storage=node_output_storage
     86 )
     87 thunk, node_input_filters, node_output_filters = outputs
     89 @is_cthunk_wrapper_type
     90 def rval():

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/basic.py:1182, in make_thunk()
   1147 """Compile this linker's `self.fgraph` and return a function that performs the computations.
   1148 
   1149 The return values can be used as follows:
   (...)
   1179 
   1180 """
   1181 init_tasks, tasks = self.get_init_tasks()
-> 1182 cthunk, module, in_storage, out_storage, error_storage = self.__compile__(
   1183     input_storage, output_storage, storage_map, cache
   1184 )
   1186 res = _CThunk(cthunk, init_tasks, tasks, error_storage, module)
   1187 res.nodes = self.node_order

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/basic.py:1103, in __compile__()
   1101 input_storage = tuple(input_storage)
   1102 output_storage = tuple(output_storage)
-> 1103 thunk, module = self.cthunk_factory(
   1104     error_storage,
   1105     input_storage,
   1106     output_storage,
   1107     storage_map,
   1108     cache,
   1109 )
   1110 return (
   1111     thunk,
   1112     module,
   (...)
   1121     error_storage,
   1122 )

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/basic.py:1627, in cthunk_factory()
   1625     if cache is None:
   1626         cache = get_module_cache()
-> 1627     module = cache.module_from_key(key=key, lnk=self)
   1629 vars = self.inputs + self.outputs + self.orphans
   1630 # List of indices that should be ignored when passing the arguments
   1631 # (basically, everything that the previous call to uniq eliminated)

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/cmodule.py:1221, in module_from_key()
   1219 # Is the source code already in the cache?
   1220 module_hash = get_module_hash(src_code, key)
-> 1221 module = self._get_from_hash(module_hash, key)
   1222 if module is not None:
   1223     return module

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/cmodule.py:1126, in _get_from_hash()
   1124 with lock_ctx():
   1125     try:
-> 1126         key_data.add_key(key, save_pkl=bool(key[0]))
   1127         key_broken = False
   1128     except pickle.PicklingError:

File /home1/datawork/aponte/miniconda3/envs/nliwave/lib/python3.10/site-packages/pytensor/link/c/cmodule.py:554, in add_key()
    549 def add_key(self, key, save_pkl=True):
    550     """
    551     Add a key to self.keys, and update pickled file if asked to.
    552 
    553     """
--> 554     assert key not in self.keys
    555     self.keys.add(key)
    556     if save_pkl:

AssertionError: 
</details>
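
Note: the traceback bottoms out in PyTensor's compilation cache (add_key in pytensor/link/c/cmodule.py, assert key not in self.keys), which makes me suspect several workers compiling into the same compiledir on the shared filesystem. A possible workaround I have not verified: give each worker job its own compile cache by exporting PYTENSOR_FLAGS before the worker process starts. The sketch below assumes dask_jobqueue >= 0.8 (for job_script_prologue) and that the PBS environment provides $TMPDIR and $PBS_JOBID; it is a guess, not a confirmed fix.

# sketch: point each PBS worker at a private PyTensor compile cache
cluster = PBSCluster(
    job_script_prologue=[
        'export PYTENSOR_FLAGS="compiledir=$TMPDIR/pytensor-$PBS_JOBID"',
    ],
)

The flag has to be set before pytensor is first imported in the worker process, which is why it belongs in the job script rather than inside the inference function.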

PyMC version information:

pymc='5.17.0' and pytensor='2.25.5', installed with conda

Context for the issue:

No response
