|
58 | 58 |
|
59 | 59 | __all__ = [
|
60 | 60 | "sample",
|
61 |
| - "iter_sample", |
62 | 61 | "init_nuts",
|
63 | 62 | ]
|
64 | 63 |
|
@@ -770,73 +769,6 @@ def _sample(
|
770 | 769 | return strace
|
771 | 770 |
|
772 | 771 |
|
773 |
| -def iter_sample( |
774 |
| - draws: int, |
775 |
| - step, |
776 |
| - start: PointType, |
777 |
| - trace=None, |
778 |
| - chain: int = 0, |
779 |
| - tune: int = 0, |
780 |
| - model: Optional[Model] = None, |
781 |
| - random_seed: RandomSeed = None, |
782 |
| - callback=None, |
783 |
| -) -> Iterator[MultiTrace]: |
784 |
| - """Generate a trace on each iteration using the given step method. |
785 |
| -
|
786 |
| - Multiple step methods ared supported via compound step methods. Returns the |
787 |
| - amount of time taken. |
788 |
| -
|
789 |
| - Parameters |
790 |
| - ---------- |
791 |
| - draws : int |
792 |
| - The number of samples to draw |
793 |
| - step : function |
794 |
| - Step function |
795 |
| - start : dict |
796 |
| - Starting point in parameter space (or partial point). |
797 |
| - trace : backend or list |
798 |
| - This should be a backend instance, or a list of variables to track. |
799 |
| - If None or a list of variables, the NDArray backend is used. |
800 |
| - chain : int, optional |
801 |
| - Chain number used to store sample in backend. |
802 |
| - tune : int, optional |
803 |
| - Number of iterations to tune (defaults to 0). |
804 |
| - model : Model (optional if in ``with`` context) |
805 |
| - random_seed : single random seed, optional |
806 |
| - callback : |
807 |
| - A function which gets called for every sample from the trace of a chain. The function is |
808 |
| - called with the trace and the current draw and will contain all samples for a single trace. |
809 |
| - the ``draw.chain`` argument can be used to determine which of the active chains the sample |
810 |
| - is drawn from. |
811 |
| - Sampling can be interrupted by throwing a ``KeyboardInterrupt`` in the callback. |
812 |
| -
|
813 |
| - Yields |
814 |
| - ------ |
815 |
| - trace : MultiTrace |
816 |
| - Contains all samples up to the current iteration |
817 |
| -
|
818 |
| - Examples |
819 |
| - -------- |
820 |
| - :: |
821 |
| -
|
822 |
| - for trace in iter_sample(500, step): |
823 |
| - ... |
824 |
| - """ |
825 |
| - sampling = _iter_sample( |
826 |
| - draws=draws, |
827 |
| - step=step, |
828 |
| - start=start, |
829 |
| - trace=trace, |
830 |
| - chain=chain, |
831 |
| - tune=tune, |
832 |
| - model=model, |
833 |
| - random_seed=random_seed, |
834 |
| - callback=callback, |
835 |
| - ) |
836 |
| - for i, (strace, _) in enumerate(sampling): |
837 |
| - yield MultiTrace([strace[: i + 1]]) |
838 |
| - |
839 |
| - |
840 | 772 | def _iter_sample(
|
841 | 773 | *,
|
842 | 774 | draws: int,
|
|
0 commit comments