|
21 | 21 | import warnings
|
22 | 22 |
|
23 | 23 | from abc import ABC
|
24 |
| -from typing import List |
25 | 24 |
|
26 | 25 | import aesara.tensor as at
|
27 | 26 | import numpy as np
|
28 | 27 |
|
29 |
| -from pymc.backends.report import SamplerReport, merge_reports |
| 28 | +from pymc.backends.report import SamplerReport |
30 | 29 | from pymc.model import modelcontext
|
31 | 30 | from pymc.util import get_var_name
|
32 | 31 |
|
@@ -570,43 +569,6 @@ def points(self, chains=None):
|
570 | 569 | return itl.chain.from_iterable(self._straces[chain] for chain in chains)
|
571 | 570 |
|
572 | 571 |
|
573 |
| -def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace: |
574 |
| - """Merge MultiTrace objects. |
575 |
| -
|
576 |
| - Parameters |
577 |
| - ---------- |
578 |
| - mtraces: list of MultiTraces |
579 |
| - Each instance should have unique chain numbers. |
580 |
| -
|
581 |
| - Raises |
582 |
| - ------ |
583 |
| - A ValueError is raised if any traces have overlapping chain numbers, |
584 |
| - or if chains are of different lengths. |
585 |
| -
|
586 |
| - Returns |
587 |
| - ------- |
588 |
| - A MultiTrace instance with merged chains |
589 |
| - """ |
590 |
| - if len(mtraces) == 0: |
591 |
| - raise ValueError("Cannot merge an empty set of traces.") |
592 |
| - base_mtrace = mtraces[0] |
593 |
| - chain_len = len(base_mtrace) |
594 |
| - # check base trace |
595 |
| - if any( |
596 |
| - len(st) != chain_len for _, st in base_mtrace._straces.items() |
597 |
| - ): # pylint: disable=line-too-long |
598 |
| - raise ValueError("Chains are of different lengths.") |
599 |
| - for new_mtrace in mtraces[1:]: |
600 |
| - for new_chain, strace in new_mtrace._straces.items(): |
601 |
| - if new_chain in base_mtrace._straces: |
602 |
| - raise ValueError("Chains are not unique.") |
603 |
| - if len(strace) != chain_len: |
604 |
| - raise ValueError("Chains are of different lengths.") |
605 |
| - base_mtrace._straces[new_chain] = strace |
606 |
| - base_mtrace._report = merge_reports([trace.report for trace in mtraces]) |
607 |
| - return base_mtrace |
608 |
| - |
609 |
| - |
610 | 572 | def _squeeze_cat(results, combine, squeeze):
|
611 | 573 | """Squeeze and concatenate the results depending on values of
|
612 | 574 | `combine` and `squeeze`."""
|
|
0 commit comments