@@ -87,9 +87,9 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
8787
8888 - `resample=True`
8989 - the MCSampler has no `base_samples` attribute.
90- - `shape` is different than `self.base_samples.shape`.
91- - device and/or dtype of posterior are different than those of
92- `self.base_samples`.
90+ - `shape` is different than `self.base_samples.shape` (if
91+ `collapse_batch_dims=True`, then batch dimensions of will be
92+ automatically broadcasted as necessary)
9393
9494 Args:
9595 posterior: The Posterior for which to generate base samples.
@@ -137,11 +137,12 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
137137
138138 This function will generate a new set of base samples and set the
139139 `base_samples` buffer if one of the following is true:
140- - `resample=True`
141- - the MCSampler has no `base_samples` attribute.
142- - `shape` is different than `self.base_samples.shape`.
143- - device and/or dtype of posterior ar different than those of
144- `self.base_samples`.
140+
141+ - `resample=True`
142+ - the MCSampler has no `base_samples` attribute.
143+ - `shape` is different than `self.base_samples.shape` (if
144+ `collapse_batch_dims=True`, then batch dimensions of will be
145+ automatically broadcasted as necessary)
145146
146147 Args:
147148 posterior: The Posterior for which to generate base samples.
@@ -150,16 +151,21 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
150151 if (
151152 self .resample
152153 or not hasattr (self , "base_samples" )
153- or self .base_samples .shape != shape
154- or self .base_samples .device != posterior .device
155- or self .base_samples .dtype != posterior .dtype
154+ or self .base_samples .shape [- 2 :] != shape [- 2 :]
155+ or (not self .collapse_batch_dims and shape != self .base_samples .shape )
156156 ):
157157 with manual_seed (seed = self .seed ):
158158 base_samples = torch .randn (
159159 shape , device = posterior .device , dtype = posterior .dtype
160160 )
161161 self .seed += 1
162162 self .register_buffer ("base_samples" , base_samples )
163+ elif self .collapse_batch_dims and shape != self .base_samples .shape :
164+ self .base_samples = self .base_samples .view (shape )
165+ if self .base_samples .device != posterior .device :
166+ self .to (device = posterior .device ) # pragma: nocover
167+ if self .base_samples .dtype != posterior .dtype :
168+ self .to (dtype = posterior .dtype )
163169
164170
165171class SobolQMCNormalSampler (MCSampler ):
@@ -201,11 +207,12 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
201207
202208 This function will generate a new set of base samples and set the
203209 `base_samples` buffer if one of the following is true:
204- - `resample=True`
205- - the MCSampler has no `base_samples` attribute.
206- - `self.sample_shape` is different than `self.base_samples.shape`.
207- - device and/or dtype of posterior ar different than those of
208- `self.base_samples`.
210+
211+ - `resample=True`
212+ - the MCSampler has no `base_samples` attribute.
213+ - `shape` is different than `self.base_samples.shape` (if
214+ `collapse_batch_dims=True`, then batch dimensions of will be
215+ automatically broadcasted as necessary)
209216
210217 Args:
211218 posterior: The Posterior for which to generate base samples.
@@ -214,9 +221,8 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
214221 if (
215222 self .resample
216223 or not hasattr (self , "base_samples" )
217- or self .base_samples .shape != shape
218- or self .base_samples .device != posterior .device
219- or self .base_samples .dtype != posterior .dtype
224+ or self .base_samples .shape [- 2 :] != shape [- 2 :]
225+ or (not self .collapse_batch_dims and shape != self .base_samples .shape )
220226 ):
221227 output_dim = shape [- 2 :].numel ()
222228 if output_dim > SobolEngine .MAXDIM :
@@ -234,3 +240,9 @@ def _construct_base_samples(self, posterior: Posterior, shape: torch.Size) -> No
234240 self .seed += 1
235241 base_samples = base_samples .view (shape )
236242 self .register_buffer ("base_samples" , base_samples )
243+ elif self .collapse_batch_dims and shape != posterior .event_shape :
244+ self .base_samples = self .base_samples .view (shape )
245+ if self .base_samples .device != posterior .device :
246+ self .to (device = posterior .device ) # pragma: nocover
247+ if self .base_samples .dtype != posterior .dtype :
248+ self .to (dtype = posterior .dtype )
0 commit comments