1- import torch
21
2+ from typing import Tuple
3+ import torch
4+ from ...utils .types import unpack_tensor_tuple , pack_tensor_in_list
35
46__all__ = ["Sampler" ]
57
68
79class Sampler (torch .nn .Module ):
8-
9- def __init__ (self , ** kwargs ):
10+ """Abstract base class for samplers.
11+
12+ Parameters
13+ ----------
14+ return_hook : Callable, optional
15+ A function to postprocess the samples. This can (for example) be used to
16+ only return samples at a selected thermodynamic state of a replica exchange sampler
17+ or to combine the batch and sample dimension.
18+ The function takes a list of tensors and should return a list of tensors.
19+ Each tensor contains a batch of samples.
20+ """
21+
22+ def __init__ (self , return_hook = lambda x : x , ** kwargs ):
1023 super ().__init__ (** kwargs )
24+ self .return_hook = return_hook
1125
1226 def _sample_with_temperature (self , n_samples , temperature , * args , ** kwargs ):
1327 raise NotImplementedError ()
@@ -16,7 +30,39 @@ def _sample(self, n_samples, *args, **kwargs):
1630 raise NotImplementedError ()
1731
1832 def sample (self , n_samples , temperature = 1.0 , * args , ** kwargs ):
33+ """Create a number of samples.
34+
35+ Parameters
36+ ----------
37+ n_samples : int
38+ The number of samples to be created.
39+ temperature : float, optional
40+ The relative temperature at which to create samples.
41+ Only available for sampler that implement `_sample_with_temperature`.
42+
43+ Returns
44+ -------
45+ samples : Union[torch.Tensor, Tuple[torch.Tensor, ...]]
46+ If this sampler reflects a joint distribution of multiple tensors,
47+ it returns a tuple of tensors, each of which have length n_samples.
48+ Otherwise it returns a single tensor of length n_samples.
49+ """
1950 if isinstance (temperature , float ) and temperature == 1.0 :
20- return self ._sample (n_samples , * args , ** kwargs )
51+ samples = self ._sample (n_samples , * args , ** kwargs )
2152 else :
22- return self ._sample_with_temperature (n_samples , temperature , * args , ** kwargs )
53+ samples = self ._sample_with_temperature (n_samples , temperature , * args , ** kwargs )
54+ samples = pack_tensor_in_list (samples )
55+ return unpack_tensor_tuple (self .return_hook (samples ))
56+
57+ def sample_to_cpu (self , n_samples , batch_size = 64 , * args , ** kwargs ):
58+ """A utility method for creating many samples that might not fit into GPU memory."""
59+ with torch .no_grad ():
60+ samples = self .sample (min (n_samples , batch_size ), * args , ** kwargs )
61+ samples = pack_tensor_in_list (samples )
62+ samples = [tensor .detach ().cpu () for tensor in samples ]
63+ while len (samples [0 ]) < n_samples :
64+ new_samples = self .sample (min (n_samples - len (samples [0 ]), batch_size ), * args , ** kwargs )
65+ new_samples = pack_tensor_in_list (new_samples )
66+ for i , new in enumerate (new_samples ):
67+ samples [i ] = torch .cat ([samples [i ], new .detach ().cpu ()], dim = 0 )
68+ return unpack_tensor_tuple (samples )
0 commit comments