Skip to content

Commit 119c8ab

Browse files
authored
Merge pull request #179 from RacimoLab/generator_func_v
cleanups
2 parents e11ef8e + 18c3b12 commit 119c8ab

File tree

15 files changed

+101
-109
lines changed

15 files changed

+101
-109
lines changed

dinf/__init__.py

Lines changed: 23 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,28 @@
66
except ImportError:
77
pass
88

9-
import os
10-
11-
if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
12-
# Mute tensorflow/xla.
13-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
14-
15-
if "KMP_AFFINITY" not in os.environ:
16-
# Pin threads to cpus. This can improve blas performance.
17-
os.environ["KMP_AFFINITY"] = "granularity=fine,noverbose,compact,1,0"
18-
19-
# https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
20-
# if "XLA_PYTHON_CLIENT_PREALLOCATE" not in os.environ:
21-
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
22-
# if "XLA_PYTHON_CLIENT_ALLOCATOR" not in os.environ:
23-
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
24-
25-
9+
from .misc import ts_individuals
10+
from .store import Store as Store
11+
from .parameters import Param, Parameters
12+
from .dinf_model import DinfModel
13+
from .vcf import (
14+
BagOfVcf,
15+
get_contig_lengths,
16+
get_samples_from_1kgp_metadata,
17+
)
18+
from .feature_extractor import (
19+
HaplotypeMatrix,
20+
MultipleHaplotypeMatrices,
21+
BinnedHaplotypeMatrix,
22+
MultipleBinnedHaplotypeMatrices,
23+
)
24+
from .discriminator import (
25+
Discriminator,
26+
Surrogate,
27+
ExchangeableCNN,
28+
ExchangeablePGGAN,
29+
Symmetric,
30+
)
2631
from .dinf import (
2732
abc_gan,
2833
alfi_mcmc_gan,
@@ -32,49 +37,5 @@
3237
train,
3338
save_results,
3439
load_results,
40+
sample_smooth,
3541
)
36-
from .discriminator import (
37-
Discriminator,
38-
Surrogate,
39-
ExchangeableCNN,
40-
ExchangeablePGGAN,
41-
Symmetric,
42-
)
43-
from .feature_extractor import (
44-
HaplotypeMatrix,
45-
MultipleHaplotypeMatrices,
46-
BinnedHaplotypeMatrix,
47-
MultipleBinnedHaplotypeMatrices,
48-
)
49-
from .dinf_model import DinfModel
50-
from .parameters import Param, Parameters
51-
from .store import Store
52-
from .vcf import BagOfVcf, get_contig_lengths, get_samples_from_1kgp_metadata
53-
54-
__all__ = [
55-
"__version__",
56-
"BagOfVcf",
57-
"BinnedHaplotypeMatrix",
58-
"Discriminator",
59-
"ExchangeableCNN",
60-
"ExchangeablePGGAN",
61-
"MultipleBinnedHaplotypeMatrices",
62-
"MultipleHaplotypeMatrices",
63-
"DinfModel",
64-
"HaplotypeMatrix",
65-
"Param",
66-
"Parameters",
67-
"Surrogate",
68-
"Store",
69-
"Symmetric",
70-
"get_contig_lengths",
71-
"get_samples_from_1kgp_metadata",
72-
"abc_gan",
73-
"alfi_mcmc_gan",
74-
"mcmc_gan",
75-
"load_results",
76-
"pg_gan",
77-
"predict",
78-
"save_results",
79-
"train",
80-
]

dinf/dinf.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def _process_pool_init(parallelism, dinf_model):
5151
_pool = ctx.Pool(
5252
processes=parallelism,
5353
initializer=_initializer,
54-
initargs=(dinf_model._filename,),
54+
initargs=(dinf_model.filename,),
5555
)
5656

5757

@@ -408,7 +408,7 @@ def _train_discriminator(
408408
)
409409
train_x, train_y, train_x_generator = _generate_training_data(
410410
target=dinf_model.target_func,
411-
generator=dinf_model.generator_func,
411+
generator=dinf_model.generator_func_v,
412412
thetas=training_thetas,
413413
parallelism=parallelism,
414414
ss=ss_train,
@@ -418,7 +418,7 @@ def _train_discriminator(
418418
if test_thetas is not None and len(test_thetas) > 0:
419419
val_x, val_y, val_x_generator = _generate_training_data(
420420
target=dinf_model.target_func,
421-
generator=dinf_model.generator_func,
421+
generator=dinf_model.generator_func_v,
422422
thetas=test_thetas,
423423
parallelism=parallelism,
424424
ss=ss_val,
@@ -593,7 +593,7 @@ def predict(
593593
replicates, rng=np.random.default_rng(ss_thetas)
594594
)
595595
x = _generate_data(
596-
generator=dinf_model.generator_func,
596+
generator=dinf_model.generator_func_v,
597597
thetas=thetas,
598598
parallelism=parallelism,
599599
rng=np.random.default_rng(ss_generator),
@@ -734,7 +734,7 @@ def mcmc_gan(
734734
log_prob_func = functools.partial(
735735
_log_prob,
736736
discriminator=discriminator,
737-
generator=dinf_model.generator_func,
737+
generator=dinf_model.generator_func_v,
738738
parameters=parameters,
739739
parallelism=parallelism,
740740
num_replicates=Dx_replicates,
@@ -815,10 +815,10 @@ def sample_smooth(
815815
"""
816816
Sample from a smoothed set of weighted observations.
817817
818-
Samples are drawn from the thetas, weighted by their probability.
818+
Samples are drawn from ``thetas``, weighted by their probability.
819819
New points are drawn within a neighbourhood of the sampled thetas
820820
using a multivariate normal whose covariance is calculated from the
821-
thetas. This is effectively sampling from a Gaussian KDE, but
821+
thetas. This is equivalent to sampling from a Gaussian KDE, but
822822
avoids doing an explicit density estimation.
823823
Scott's rule of thumb is used for bandwidth selection.
824824
@@ -843,11 +843,14 @@ def sample_smooth(
843843
* "transform": thetas are transformed before sampling, and
844844
the sampled values are inverse-transformed before being
845845
returned.
846+
See :meth:`Parameters.transform` and :meth:`Parameters.itransform`.
846847
* "truncate": sampled values are truncated at the parameter limits.
848+
See :meth:`Parameters.truncate`.
847849
* "reflect": sample values that are out of bounds are reflected
848850
inside the parameter limits by the same magnitude that they were
849851
out of bounds. Values that are too far out of bounds to be
850852
reflected are truncated at the parameter limits.
853+
See :meth:`Parameters.reflect`.
851854
852855
:return:
853856
The sampled values.
@@ -1172,7 +1175,7 @@ def pretraining_dinf(
11721175
lp = _log_prob(
11731176
thetas,
11741177
discriminator=discriminator,
1175-
generator=dinf_model.generator_func,
1178+
generator=dinf_model.generator_func_v,
11761179
parameters=parameters,
11771180
num_replicates=1,
11781181
parallelism=parallelism,
@@ -1476,7 +1479,7 @@ def pg_gan(
14761479
lp = _log_prob(
14771480
proposal_thetas,
14781481
discriminator=discriminator,
1479-
generator=dinf_model.generator_func,
1482+
generator=dinf_model.generator_func_v,
14801483
parameters=parameters,
14811484
num_replicates=Dx_replicates,
14821485
parallelism=parallelism,

dinf/dinf_model.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,14 @@ def generator_func3(
166166
Function that simulates features using concrete parameter values.
167167
"""
168168

169+
generator_func_v: Callable = dataclasses.field(init=False)
170+
"""
171+
Wrapper for ``generator_func`` that accepts a single argument containing
172+
the seed and a vector of parameter values (as opposed to keyword arguments).
173+
The signature is ``generator_func_v(a: Tuple[int, np.ndarray])``,
174+
where the argument is a 2-tuple of ``(seed, vector)``.
175+
"""
176+
169177
target_func: Callable | None
170178
"""
171179
Function that samples features from the target distribution.
@@ -182,6 +190,12 @@ def generator_func3(
182190
A :doc:`flax <flax:index>` neural network. May be ``None``.
183191
"""
184192

193+
filename: pathlib.Path | None = dataclasses.field(init=False, default=None)
194+
"""
195+
Path to the file from which the model was loaded (if any).
196+
May be ``None``.
197+
"""
198+
185199
def __post_init__(self):
186200
if len(self.parameters) == 0:
187201
raise ValueError("Must define one or more parameters")
@@ -203,12 +217,12 @@ def __post_init__(self):
203217
# Transform generator_func from a function accepting arbitrary kwargs
204218
# (which limits user error) into a function accepting a sequence of
205219
# args (which is easier to pass to the mcmc).
206-
f = self.generator_func
207-
self.generator_func = functools.update_wrapper(
208-
functools.partial(_sim_shim, func=f, keys=tuple(self.parameters)), f
220+
self.generator_func_v = functools.update_wrapper(
221+
functools.partial(
222+
_sim_shim, func=self.generator_func, keys=tuple(self.parameters)
223+
),
224+
self.generator_func,
209225
)
210-
self._orig_generator_func = f
211-
self._filename = None
212226

213227
def check(self, seed=None):
214228
"""
@@ -229,7 +243,7 @@ def check(self, seed=None):
229243
f"{thetas.shape}, expected shape {(5, len(self.parameters))}."
230244
)
231245

232-
x_g = self.generator_func((rng.integers(low=0, high=2**31), thetas[0]))
246+
x_g = self.generator_func_v((rng.integers(low=0, high=2**31), thetas[0]))
233247
if not tree_equal(tree_shape(x_g), self.feature_shape):
234248
raise ValueError(
235249
f"generator_func produced feature shape {tree_shape(x_g)}, "
@@ -269,5 +283,5 @@ def from_file(filename: str | pathlib.Path) -> DinfModel:
269283
raise AttributeError(f"{filename}: variable 'dinf_model' not found")
270284
if not isinstance(dinf_model, DinfModel):
271285
raise TypeError(f"{filename}: dinf_model is not a dinf.DinfModel object")
272-
dinf_model._filename = filename
286+
dinf_model.filename = pathlib.Path(filename)
273287
return dinf_model

dinf/misc.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@
1414

1515

1616
def ts_individuals(
17-
ts: tskit.TreeSequence, population: str | int | None = None
17+
ts: tskit.TreeSequence,
18+
/,
19+
population: str | int | None = None,
1820
) -> npt.NDArray[np.integer]:
1921
"""
2022
Get the individuals corresponding to the tree sequence's samples.
@@ -41,7 +43,9 @@ def ts_individuals(
4143

4244

4345
def ts_nodes_of_individuals(
44-
ts: tskit.TreeSequence, individuals: npt.NDArray[np.integer]
46+
ts: tskit.TreeSequence,
47+
/,
48+
individuals: npt.NDArray[np.integer],
4549
) -> npt.NDArray[np.integer]:
4650
"""
4751
Get the nodes for the individuals.
@@ -57,7 +61,9 @@ def ts_nodes_of_individuals(
5761

5862

5963
def ts_ploidy_of_individuals(
60-
ts: tskit.TreeSequence, individuals: npt.NDArray[np.integer]
64+
ts: tskit.TreeSequence,
65+
/,
66+
individuals: npt.NDArray[np.integer],
6167
) -> npt.NDArray[np.integer]:
6268
"""
6369
Get the ploidy of the individuals.
@@ -148,7 +154,7 @@ def cache(path: str | pathlib.Path, /, *, split: int = 1000):
148154
"""
149155
A decorator to cache the output of generator and/or target functions.
150156
151-
This is analogous to {func}`functools.cache`, except each function's
157+
This is analogous to :func:`functools.cache`, except each function's
152158
result is stored in a file under the given directory. Caching can create
153159
a large number of small files, so the files are split into subdirectories
154160
to mitigate possible problems.
@@ -201,7 +207,7 @@ def sqlite_cache(db_file: str | pathlib.Path, shape, /):
201207
"""
202208
A decorator for generator or target functions that caches features to disk.
203209
204-
This is analogous to {func}`functools.cache`, except the cache is
210+
This is analogous to :func:`functools.cache`, except the cache is
205211
persisted to disk in an sqlite database.
206212
207213
.. warning::

dinf/parameters.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,11 @@ def reflect(self, x: np.ndarray, /) -> np.ndarray:
124124
"""
125125
Reflect values that are out of bounds by the amount they are out.
126126
127-
Values that are too far out of bounds to be reflected are truncated.
127+
As reflecting does not guarantee values will be within the bounds,
128+
values are first truncated to (2*low - high, 2*high - low),
129+
then reflected. For example, with bounds low=0, high=10,
130+
a value of -11 will be truncated to -10, then reflected to attain
131+
a final value of 10.
128132
129133
:param x:
130134
The values to be reflected.
@@ -288,7 +292,11 @@ def reflect(self, xs: np.ndarray, /) -> np.ndarray:
288292
"""
289293
Reflect values that are out of bounds by the amount they are out.
290294
291-
Values that are too far out of bounds to be reflected are truncated.
295+
As reflecting does not guarantee values will be within the bounds,
296+
values are first truncated to (2*low - high, 2*high - low),
297+
then reflected. For example, with bounds low=0, high=10,
298+
a value of -11 will be truncated to -10, then reflected to attain
299+
a final value of 10.
292300
293301
:param xs:
294302
The values to be reflected.

dinf/plot.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515

1616
from .cli import ADRDFormatter, _DINF_MODEL_HELP
1717
import dinf
18-
from .dinf import sample_smooth
1918

2019

2120
class MultiPage:
@@ -694,7 +693,7 @@ def __call__(self, args: argparse.Namespace):
694693

695694
class _Features(_SubCommand):
696695
"""
697-
Plot a feature matrix or matrices as heatmaps.
696+
Plot feature matrices as heatmaps.
698697
699698
By default, one simulation will be performed with the generator to obtain
700699
a set of features for plotting. To instead extract features from the
@@ -721,7 +720,7 @@ def __call__(self, args: argparse.Namespace):
721720
else:
722721
rng = np.random.default_rng(args.seed)
723722
thetas = dinf_model.parameters.draw_prior(1, rng=rng)
724-
mats = dinf_model.generator_func(
723+
mats = dinf_model.generator_func_v(
725724
(rng.integers(low=0, high=2**31), thetas[0])
726725
)
727726

@@ -1042,7 +1041,7 @@ def __call__(self, args: argparse.Namespace):
10421041
names = list(data.dtype.names)
10431042
probs = data["_Pr"]
10441043
thetas = structured_to_unstructured(data[names[1:]])
1045-
X = sample_smooth(
1044+
X = dinf.sample_smooth(
10461045
thetas=thetas,
10471046
probs=probs,
10481047
size=1_000_000,

0 commit comments

Comments
 (0)