diff --git a/funsor/tensor.py b/funsor/tensor.py
index a2bac8ff..41770c43 100644
--- a/funsor/tensor.py
+++ b/funsor/tensor.py
@@ -335,94 +335,58 @@ def _sample(self, sampled_vars, sample_inputs, rng_key):
         if not sampled_vars:
             return self
 
-        # Partition inputs into sample_inputs + batch_inputs + event_inputs.
-        sample_inputs = OrderedDict(
-            (k, d) for k, d in sample_inputs.items() if k not in self.inputs
-        )
-        sample_shape = tuple(int(d.dtype) for d in sample_inputs.values())
+        results = []
+        backend = get_backend()
+        remaining_vars = set(sampled_vars)
+        term = self
+        shape = tuple(d.size for k, d in sample_inputs.items() if k not in self.inputs)
+        shape += tuple(d.size for k, d in self.inputs.items() if k not in sampled_vars)
         batch_inputs = OrderedDict(
-            (k, d) for k, d in self.inputs.items() if k not in sampled_vars
+            (k, v) for k, v in sample_inputs.items() if k not in self.inputs
         )
-        event_inputs = OrderedDict(
-            (k, d) for k, d in self.inputs.items() if k in sampled_vars
+        batch_inputs.update(
+            (k, v) for k, v in self.inputs.items() if k not in sampled_vars
         )
-        be_inputs = batch_inputs.copy()
-        be_inputs.update(event_inputs)
-        sb_inputs = sample_inputs.copy()
-        sb_inputs.update(batch_inputs)
-
-        # Sample all variables in a single Categorical call.
-        logits = align_tensor(be_inputs, self)
-        batch_shape = logits.shape[: len(batch_inputs)]
-        flat_logits = logits.reshape(batch_shape + (-1,))
-        sample_shape = tuple(d.dtype for d in sample_inputs.values())
+        while remaining_vars:
+            name = remaining_vars.pop()
+            domain = self.inputs[name]
+            logits = funsor.Lambda(
+                Variable(name, domain), term.reduce(ops.logaddexp, remaining_vars)
+            )
 
-        backend = get_backend()
-        if backend != "numpy":
-            from importlib import import_module
+            if backend != "numpy":
+                from importlib import import_module
 
-            dist = import_module(
-                funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]
-            )
-            sample_args = (
-                (sample_shape,) if rng_key is None else (rng_key, sample_shape)
-            )
-            flat_sample = dist.CategoricalLogits.dist_class(logits=flat_logits).sample(
-                *sample_args
-            )
-        else:  # default numpy backend
-            assert backend == "numpy"
-            shape = sample_shape + flat_logits.shape[:-1]
-            logit_max = np.amax(flat_logits, -1, keepdims=True)
-            probs = np.exp(flat_logits - logit_max)
-            probs = probs / np.sum(probs, -1, keepdims=True)
-            s = np.cumsum(probs, -1)
-            r = np.random.rand(*shape)
-            flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
-
-        assert flat_sample.shape == sample_shape + batch_shape
-        results = []
-        mod_sample = flat_sample
-        for name, domain in reversed(list(event_inputs.items())):
-            size = domain.dtype
-            point = Tensor(mod_sample % size, sb_inputs, size)
-            mod_sample = mod_sample // size
-            results.append(Delta(name, point))
-
-        # Account for the log normalizer factor.
-        # Derivation: Let f be a nonnormalized distribution (a funsor), and
-        #   consider operations in linear space (source code is in log space).
-        #   Let x0 ~ f/|f| be a monte carlo sample from a normalized f/|f|.
-        #                              f(x0) / |f|      # dice numerator
-        #   Let g = delta(x=x0) |f| -----------------
-        #                           detach(f(x0)/|f|)   # dice denominator
-        #                       |detach(f)| f(x0)
-        #         = delta(x=x0) -----------------  be a dice approximation of f.
-        #                          detach(f(x0))
-        #   Then g is an unbiased estimator of f in value and all derivatives.
-        #   In the special case f = detach(f), we can simplify to
-        #       g = delta(x=x0) |f|.
-        if (backend == "torch" and flat_logits.requires_grad) or backend == "jax":
-            # Apply a dice factor to preserve differentiability.
-            index = [
-                ops.new_arange(self.data, n).reshape(
-                    (n,) + (1,) * (len(flat_logits.shape) - i - 2)
+                dist = import_module(
+                    funsor.distribution.BACKEND_TO_DISTRIBUTIONS_BACKEND[backend]
                 )
-                for i, n in enumerate(flat_logits.shape[:-1])
-            ]
-            index.append(flat_sample)
-            log_prob = flat_logits[tuple(index)]
-            assert log_prob.shape == flat_sample.shape
-            results.append(
-                Tensor(
-                    ops.logsumexp(ops.detach(flat_logits), -1)
-                    + (log_prob - ops.detach(log_prob)),
-                    sb_inputs,
+                sample_inputs = OrderedDict(
+                    (k, v) for k, v in sample_inputs.items() if k not in logits.inputs
                 )
-            )
-        else:
-            # This is the special case f = detach(f).
-            results.append(Tensor(ops.logsumexp(flat_logits, -1), batch_inputs))
+                delta = dist.CategoricalLogits(logits=logits, value=name)._sample(
+                    frozenset({name}), sample_inputs, rng_key
+                )
+                point = delta.terms[0][1][0]
+                log_density = delta.terms[0][1][1] + self.reduce(
+                    ops.logaddexp, sampled_vars
+                ) / len(sampled_vars)
+                term = term(**{name: point})
+                sample = Delta(name, point, log_density)
+                results.append(sample)
+            else:  # default numpy backend
+                assert backend == "numpy"
+                probs = (logits - ops.logsumexp(logits)).exp().data
+                # shape = sample_shape + flat_logits.shape[:-1]
+                # logit_max = np.amax(flat_logits, -1, keepdims=True)
+                # probs = np.exp(flat_logits - logit_max)
+                # probs = probs / np.sum(probs, -1, keepdims=True)
+                s = np.cumsum(probs, -1)
+                r = np.random.rand(*shape)
+                flat_sample = np.sum(s < np.expand_dims(r, -1), axis=-1)
+                point = Tensor(flat_sample, batch_inputs, domain.dtype)
+                term = term(**{name: point})
+                sample = Delta(name, point)
+                results.append(sample)
 
         return reduce(ops.add, results)
 
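Not part of the patch, for reviewers: a minimal, self-contained numpy sketch of the scheme the new `while remaining_vars:` loop implements, namely ancestral sampling from an unnormalized joint table (marginalize the remaining variables with logaddexp, draw the current variable, condition on it, repeat). The helper names `logsumexp`, `sample_categorical`, and `ancestral_sample` are illustrative only, not funsor API.

```python
import numpy as np

rng = np.random.default_rng(0)


def logsumexp(a, axis):
    # Numerically stable log-sum-exp, the numpy analogue of
    # term.reduce(ops.logaddexp, remaining_vars) in the patch.
    m = np.max(a, axis=axis, keepdims=True)
    return (m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))).squeeze(axis)


def sample_categorical(logits):
    # Inverse-CDF draw from unnormalized log-probs, mirroring the
    # cumsum / rand / sum trick in the numpy branch of the diff.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return int((np.cumsum(probs) < rng.random()).sum())


def ancestral_sample(joint_logits):
    # Sample the axes of an unnormalized joint log-density table one
    # variable at a time: p(x, y) = p(x) p(y | x).
    point = []
    table = joint_logits
    while table.ndim > 1:
        # Marginal of the first remaining variable.
        marginal = logsumexp(table, axis=tuple(range(1, table.ndim)))
        i = sample_categorical(marginal)
        point.append(i)
        table = table[i]  # condition on the drawn value, like term(**{name: point})
    point.append(sample_categorical(table))  # last variable needs no reduction
    return tuple(point)


# Example: an unnormalized joint over Bint[3] x Bint[4].
print(ancestral_sample(rng.normal(size=(3, 4))))
```

One design consequence worth noting: sampling sequentially yields one `Delta` per variable, each carrying its own point (and, on non-numpy backends, its own `log_density` share of the log normalizer), whereas the old code drew all variables in a single flattened `Categorical` and then unraveled the flat index with mod/div arithmetic.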