diff --git a/tests/optimizer/linear/test_gbp_linear_solver.py b/tests/optimizer/linear/test_gbp_linear_solver.py
new file mode 100644
index 000000000..09e04c484
--- /dev/null
+++ b/tests/optimizer/linear/test_gbp_linear_solver.py
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+import theseus as th
+
+
+"""
+Build linear 1D surface estimation problem.
+Solve using GBP and using matrix inversion and compare answers.
+GBP exactly computes the marginal means on convergence.
+
+All the following cases should not affect the converged solution:
+- with / without vectorization
+- with / without factor to variable message damping
+- with / without dropout
+- with / without factor linear system damping
+"""
+
+
+def _check_info(info, batch_size, max_iterations, initial_error, objective):
+    assert info.err_history.shape == (batch_size, max_iterations + 1)
+    assert info.err_history[:, 0].allclose(initial_error)
+    assert info.err_history.argmin(dim=1).allclose(info.best_iter + 1)
+    last_error = objective.error_squared_norm() / 2
+    last_convergence_idx = info.converged_iter.max().item()
+    assert info.err_history[:, last_convergence_idx].allclose(last_error)
+
+
+def run_gbp_linear_solver(
+    frac_loops,
+    vectorize=True,
+    ftov_damping=0.0,
+    dropout=0.0,
+    lin_system_damping=torch.tensor([1e-4]),
+):
+    max_iterations = 200
+
+    n_variables = 100
+    batch_size = 1
+
+    torch.manual_seed(0)
+
+    # initial input tensors
+    # measurements come from x = sin(t / 10) * t**2 / 250 + 1 with random noise added
+    ts = torch.arange(n_variables)
+    true_meas = torch.sin(ts / 10.0) * ts * ts / 250.0 + 1
+    noisy_meas = true_meas[None, :].repeat(batch_size, 1)
+    noisy_meas += torch.normal(torch.zeros_like(noisy_meas), 1.0)
+
+    variables = []
+    meas_vars = []
+    for i in range(n_variables):
+        variables.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"x_{i}"))
+        meas_vars.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"meas_x{i}"))
+
+    objective = th.Objective()
+
+    # measurement cost functions
+    meas_weight = th.ScaleCostWeight(5.0, name="meas_weight")
+    for var, meas in zip(variables, meas_vars):
+        objective.add(th.Difference(var, meas, meas_weight))
+
+    # smoothness cost functions between adjacent variables
+    smoothness_weight = th.ScaleCostWeight(2.0, name="smoothness_weight")
+    zero = th.Vector(tensor=torch.zeros(batch_size, 1), name="zero")
+    for i in range(n_variables - 1):
+        objective.add(
+            th.Between(variables[i], variables[i + 1], zero, smoothness_weight)
+        )
+
+    # difference cost functions between non-adjacent variables to give
+    # off diagonal elements in information matrix
+    difference_weight = th.ScaleCostWeight(1.0, name="difference_weight")
+    for i in range(int(n_variables * frac_loops)):
+        ix1, ix2 = torch.randint(n_variables, (2,))
+        diff = th.Vector(
+            tensor=torch.tensor([[true_meas[ix2] - true_meas[ix1]]]), name=f"diff{i}"
+        )
+        diff.tensor += torch.normal(torch.zeros(1, 1), 0.2)
+        objective.add(
+            th.Between(variables[ix1], variables[ix2], diff, difference_weight)
+        )
+
+    input_tensors = {}
+    for var in variables:
+        input_tensors[var.name] = var.tensor
+    for i in range(len(noisy_meas[0])):
+        input_tensors[f"meas_x{i}"] = noisy_meas[:, i][:, None]
+
+    # Solve with GBP
+    optimizer = th.GaussianBeliefPropagation(
+        objective, max_iterations=max_iterations, vectorize=vectorize
+    )
+
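+    # GBP runs iterative message passing on the factor graph defined by the
+    # objective; on this linear problem it converges to the exact marginal means,
+    # which the checks below compare against direct linear solves.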
optimizer.set_params(max_iterations=max_iterations) + objective.update(input_tensors) + initial_error = objective.error_squared_norm() / 2 + + callback_expected_iter = [0] + + def callback(opt_, info_, _, it_): + assert opt_ is optimizer + assert isinstance(info_, th.optimizer.OptimizerInfo) + assert it_ == callback_expected_iter[0] + callback_expected_iter[0] += 1 + + info = optimizer.optimize( + track_best_solution=True, + track_err_history=True, + end_iter_callback=callback, + ftov_msg_damping=ftov_damping, + dropout=dropout, + lin_system_damping=lin_system_damping, + verbose=True, + ) + gbp_solution = [var.tensor.clone() for var in variables] + + # Solve with linear solver + objective.update(input_tensors) + linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) + linear_optimizer.optimize(verbose=True) + lin_solution = [var.tensor.clone() for var in variables] + + # Solve with Gauss-Newton + # If problem is poorly conditioned solving with Gauss-Newton can yield + # a slightly different solution to one linear solve, so check both + objective.update(input_tensors) + gn_optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver) + gn_optimizer.optimize(verbose=True) + gn_solution = [var.tensor.clone() for var in variables] + + # checks + for x, x_target in zip(gbp_solution, lin_solution): + assert x.allclose(x_target, rtol=1e-3) + for x, x_target in zip(gbp_solution, gn_solution): + assert x.allclose(x_target, rtol=1e-3) + _check_info(info, batch_size, max_iterations, initial_error, objective) + + # # Visualise reconstructed surface + # soln_vec = torch.cat(gbp_solution, dim=1)[0] + # import matplotlib.pylab as plt + # plt.scatter(torch.arange(n_variables), soln_vec, label="solution") + # plt.scatter(torch.arange(n_variables), noisy_meas[0], label="meas") + # plt.legend() + # plt.show() + + +def test_gbp_linear_solver(): + + # problems with increasing loopyness + # the loopier the fewer iterations to solve + frac_loops = [0.1, 0.2, 0.5] + for frac in frac_loops: + + run_gbp_linear_solver(frac_loops=frac) + + # with factor to variable message damping, may take too many steps to converge + # run_gbp_linear_solver(vectorize=vectorize, frac_loops=frac, ftov_damping=0.1) + # with dropout + run_gbp_linear_solver(frac_loops=frac, dropout=0.1) + + # test linear system damping + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([0.0])) + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-2])) + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-6])) + + # test without vectorization once + run_gbp_linear_solver(frac_loops=0.5, vectorize=False) diff --git a/theseus/__init__.py b/theseus/__init__.py index 1edae7c3b..bcd93c52d 100644 --- a/theseus/__init__.py +++ b/theseus/__init__.py @@ -66,6 +66,7 @@ ) from .optimizer import ( # usort: skip DenseLinearization, + GaussianBeliefPropagation, Linearization, ManifoldGaussian, OptimizerInfo, diff --git a/theseus/core/objective.py b/theseus/core/objective.py index ebf57a33a..6dc22d5d3 100644 --- a/theseus/core/objective.py +++ b/theseus/core/objective.py @@ -111,6 +111,10 @@ def __init__( # If vectorization is on, this will also handle vectorized containers self._vectorization_to: Optional[Callable] = None + self.vectorized_cost_fns: Optional[List[CostFunction]] = None + # nested list of name of each base cost function in the vectorized cfs + self.vectorized_cf_names: Optional[List[List[str]]] = None + # If vectorization is on, this gets replaced by a vectorized 
version self._retract_method = Objective._retract_base @@ -682,6 +686,8 @@ def _enable_vectorization( vectorization_run_fn: Callable, vectorized_to: Callable, vectorized_retract_fn: Callable, + vectorized_cost_fns: List[CostFunction], + vectorized_cf_names: List[List[str]], error_iter_fn: Callable[[], Iterable[CostFunction]], enabler: Any, ): @@ -694,6 +700,8 @@ def _enable_vectorization( self._vectorization_run = vectorization_run_fn self._vectorization_to = vectorized_to self._retract_method = vectorized_retract_fn + self.vectorized_cost_fns = vectorized_cost_fns + self.vectorized_cf_names = vectorized_cf_names self._get_error_iter = error_iter_fn self._vectorized = True @@ -703,6 +711,8 @@ def disable_vectorization(self): self._vectorization_run = None self._vectorization_to = None self._retract_method = Objective._retract_base + self.vectorized_cost_fns = None + self.vectorized_cf_names = None self._get_error_iter = self._get_error_iter_base self._vectorized = False @@ -713,6 +723,9 @@ def vectorized(self): == (self._vectorized_jacobians_iter is None) == (self._vectorization_run is None) == (self._vectorization_to is None) + == (self._retract_method is Objective._retract_base) + == (self.vectorized_cost_fns is None) + == (self.vectorized_cf_names is None) == (self._get_error_iter == self._get_error_iter_base) == (self._retract_method == Objective._retract_base) ) diff --git a/theseus/core/vectorizer.py b/theseus/core/vectorizer.py index 50ba74be1..1fe410bde 100644 --- a/theseus/core/vectorizer.py +++ b/theseus/core/vectorizer.py @@ -116,6 +116,8 @@ def __init__(self, objective: Objective, empty_cuda_cache: bool = False): _CostFunctionSchema, List[_CostFunctionWrapper] ] = defaultdict(list) + schema_cf_names_dict: Dict[_CostFunctionSchema, List[str]] = defaultdict(list) + # Create wrappers for all cost functions and also get their schemas for cost_fn in objective.cost_functions.values(): wrapper = _CostFunctionWrapper(cost_fn) @@ -123,6 +125,8 @@ def __init__(self, objective: Objective, empty_cuda_cache: bool = False): schema = _get_cost_function_schema(cost_fn) self._schema_dict[schema].append(wrapper) + schema_cf_names_dict[schema].append(cost_fn.name) + # Now create a vectorized cost function for each unique schema self._vectorized_cost_fns: Dict[_CostFunctionSchema, CostFunction] = {} for schema in self._schema_dict: @@ -146,6 +150,8 @@ def __init__(self, objective: Objective, empty_cuda_cache: bool = False): self._vectorize, self._to, self._vectorized_retract_optim_vars, + list(self._vectorized_cost_fns.values()), + list(schema_cf_names_dict.values()), self._get_vectorized_error_iter, self, ) @@ -391,10 +397,10 @@ def _vectorize( } ret = [cf for cf_list in schema_dict.values() for cf in cf_list] for schema, cost_fn_wrappers in schema_dict.items(): - if len(cost_fn_wrappers) == 1: - self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode) - else: - self._handle_schema_vectorization(schema, cost_fn_wrappers, mode) + # if len(cost_fn_wrappers) == 1: + # self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode) + # else: + self._handle_schema_vectorization(schema, cost_fn_wrappers, mode) return ret def _get_vectorized_error_iter(self) -> Iterable[_CostFunctionWrapper]: diff --git a/theseus/optimizer/__init__.py b/theseus/optimizer/__init__.py index 9860df0f3..34461a7d0 100644 --- a/theseus/optimizer/__init__.py +++ b/theseus/optimizer/__init__.py @@ -9,3 +9,4 @@ from .optimizer import Optimizer, OptimizerInfo from .sparse_linearization import SparseLinearization from 
.variable_ordering import VariableOrdering
+from .gbp import GaussianBeliefPropagation, GBPSchedule
diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py
new file mode 100644
index 000000000..8c4ad2adb
--- /dev/null
+++ b/theseus/optimizer/gbp.py
@@ -0,0 +1,1008 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import abc
+
+# import time
+from enum import Enum
+from itertools import count
+from typing import (
+    Callable,
+    Dict,
+    List,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    Union,
+)
+
+import torch
+
+import theseus as th
+import theseus.constants
+from theseus.core import CostFunction, Objective
+from theseus.geometry import Manifold
+from theseus.optimizer import VariableOrdering, ManifoldGaussian
+from theseus.optimizer.nonlinear.nonlinear_optimizer import (
+    BackwardMode,
+    NonlinearOptimizer,
+    NonlinearOptimizerInfo,
+    NonlinearOptimizerStatus,
+)
+
+"""
+TODO
+- Remove implicit backward mode with Gauss-Newton, or at least modify it
+to make sure it detaches the hessian.
+
+Summary.
+This file contains the factor class used to wrap cost functions for GBP.
+Factor to variable message passing functions are within the factor class.
+Variable to factor message passing functions are within the GBP optimizer class.
+
+
+References for understanding GBP:
+- https://gaussianbp.github.io/
+- https://arxiv.org/abs/1910.14139
+Reference for GBP with non-Euclidean variables:
+- https://arxiv.org/abs/2202.03314
+"""
+
+
+"""
+Utility functions
+"""
+
+EndIterCallbackType = Callable[
+    ["GaussianBeliefPropagation", NonlinearOptimizerInfo, None, int], NoReturn
+]
+
+
+class GBPSchedule(Enum):
+    SYNCHRONOUS = 0
+
+
+def synchronous_schedule(max_iters, n_edges) -> torch.Tensor:
+    return torch.full([max_iters, n_edges], True)
+
+
+# def random_schedule(max_iters, n_edges) -> torch.Tensor:
+#     schedule = torch.full([max_iters, n_edges], False)
+#     # on first step send messages along all edges
+#     schedule[0] = True
+#     ixs = torch.randint(0, n_edges, [max_iters])
+#     schedule[torch.arange(max_iters), ixs] = True
+#     return schedule
+
+
+# GBP message class, messages are Gaussian distributions
+# Has additional fn to initialise messages with zero precision
+class Message(ManifoldGaussian):
+    def __init__(
+        self,
+        mean: Sequence[Manifold],
+        precision: Optional[torch.Tensor] = None,
+        name: Optional[str] = None,
+    ):
+        if precision is None:
+            dof = sum([v.dof() for v in mean])
+            precision = torch.zeros(mean[0].shape[0], dof, dof).to(
+                dtype=mean[0].dtype, device=mean[0].device
+            )
+        super(Message, self).__init__(mean, precision=precision, name=name)
+
+    # sets mean to the group identity and zero precision matrix
+    def zero_message(self, batch_ignore_mask: Optional[torch.Tensor] = None):
+        new_mean = []
+        batch_size = self.mean[0].shape[0]
+        for var in self.mean:
+            if var.__class__ == th.Vector:
+                new_mean_i = var.__class__(var.dof())
+            else:
+                new_mean_i = var.__class__()
+            repeats = torch.ones(var.ndim).int()
+            repeats[0] = batch_size
+            new_mean_i.tensor = new_mean_i.tensor.repeat(repeats.tolist())
+            new_mean_i.tensor = new_mean_i.tensor.to(
+                dtype=self.dtype, device=self.device
+            )
+            new_mean.append(new_mean_i)
+        new_precision = torch.zeros(batch_size, self.dof, self.dof).to(
+            dtype=self.dtype, device=self.device
+        )
+        self.update(
+            mean=new_mean, precision=new_precision, batch_ignore_mask=batch_ignore_mask
+        )
+
+
+# Factor class, one is
created for each cost function +class Factor: + _ids = count(0) + + def __init__( + self, + cf: CostFunction, + var_ixs: torch.Tensor, + lin_system_damping: torch.Tensor, + name: Optional[str] = None, + ): + self._id = next(Factor._ids) + if name: + self.name = name + else: + self.name = f"{self.__class__.__name__}__{self._id}" + + self.cf = cf + self.var_ixs = var_ixs + + # batch_size of the vectorized factor. In general != objective.batch_size. + # They are equal without vectorization or for unique cost function schema. + self.batch_size = cf.optim_var_at(0).shape[0] + + device = cf.optim_var_at(0).device + dtype = cf.optim_var_at(0).dtype + self._dof = sum([var.dof() for var in cf.optim_vars]) + # for storing factor linearization + self.potential_eta = torch.zeros(self.batch_size, self.dof).to( + dtype=dtype, device=device + ) + self.potential_lam = torch.zeros(self.batch_size, self.dof, self.dof).to( + dtype=dtype, device=device + ) + self.lin_point: List[Manifold] = [ + var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars + ] + self.steps_since_lin = torch.zeros( + self.batch_size, device=device, dtype=torch.int + ) + + self.lm_damping = lin_system_damping.repeat(self.batch_size).to(device) + self.min_damping = torch.Tensor([1e-4]).to(dtype=dtype, device=device) + self.max_damping = torch.Tensor([1e2]).to(dtype=dtype, device=device) + self.last_err: torch.Tensor = (self.cf.error() ** 2).sum(dim=1) + self.a = 2 + self.b = 10 + + # messages incoming and outgoing from the factor, they are updated in place + self.vtof_msgs: List[Message] = [] + self.ftov_msgs: List[Message] = [] + for var in cf.optim_vars: + # Set mean of initial message to identity of the group + # NB doesn't matter what it is as long as precision is zero + vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") + ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg.zero_message() + ftov_msg.zero_message() + self.vtof_msgs.append(vtof_msg) + self.ftov_msgs.append(ftov_msg) + + # for vectorized vtof message passing + self.vectorized_var_ixs: List[torch.Tensor] = [None] * cf.num_optim_vars() + + # Linearizes factors at current belief if beliefs have deviated + # from the linearization point by more than the threshold. 
+ def linearize( + self, + relin_threshold: float = None, + detach_hessian: bool = False, + lie=True, + ): + self.steps_since_lin += 1 + + if relin_threshold is None: + do_lin = torch.full( + [self.batch_size], + True, + device=self.cf.optim_var_at(0).device, + ) + else: + lp_dists = torch.cat( + [ + lp.local(self.cf.optim_var_at(j)).norm(dim=1)[..., None] + for j, lp in enumerate(self.lin_point) + ], + dim=1, + ) + max_dists = lp_dists.max(dim=1)[0] + do_lin = max_dists > relin_threshold + + if torch.sum(do_lin) > 0: # if any factor in the batch needs relinearization + J, error = self.cf.weighted_jacobians_error() + J_stk = torch.cat(J, dim=-1) + + # eqn 30 - https://arxiv.org/pdf/2202.03314.pdf + lam = ( + torch.bmm(J_stk.transpose(-2, -1), J_stk).detach() + if detach_hessian + else torch.bmm(J_stk.transpose(-2, -1), J_stk) + ) + # eqn 31 - https://arxiv.org/pdf/2202.03314.pdf + eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) + if lie is False: + optim_vars_stk = torch.cat( + [v.tensor for v in self.cf.optim_vars], dim=-1 + ) + eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) + eta = eta.squeeze(-1) + + # update linear system damping parameter (this is non-differentiable) + with torch.no_grad(): + err = (self.cf.error() ** 2).sum(dim=1) + damp_ixs = torch.logical_and(err > self.last_err, do_lin) + undamp_ixs = torch.logical_and(err < self.last_err, do_lin) + self.lm_damping[undamp_ixs] = torch.max( + self.lm_damping[undamp_ixs] / self.a, self.min_damping + ) + self.lm_damping[damp_ixs] = torch.min( + self.lm_damping[damp_ixs] * self.b, self.max_damping + ) + self.last_err[do_lin] = err[do_lin] + + # damp precision matrix + damped_D = self.lm_damping[:, None, None] * torch.eye( + lam.shape[1], device=lam.device, dtype=lam.dtype + ).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.lm_damping.device) + lam = lam + damped_D + + self.potential_eta[do_lin] = eta[do_lin] + self.potential_lam[do_lin] = lam[do_lin] + + for j, var in enumerate(self.cf.optim_vars): + self.lin_point[j].update(var.tensor, batch_ignore_mask=~do_lin) + + self.steps_since_lin[do_lin] = 0 + + # Compute all outgoing messages from the factor. + def comp_mess(self, msg_damping, schedule): + num_optim_vars = self.cf.num_optim_vars() + new_messages = [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = self.potential_eta.clone() + lam_factor = self.potential_lam.clone() + lam_factor_copy = lam_factor.clone() + + # Take product of factor with incoming messages. + # Convert mesages to tangent space at linearisation point. 
+ # eqns 34-38 - https://arxiv.org/pdf/2202.03314.pdf + start = 0 + for i in range(num_optim_vars): + var_dofs = self.cf.optim_var_at(i).dof() + if i != v: + eta_mess, lam_mess = th.local_gaussian( + self.lin_point[i], self.vtof_msgs[i], return_mean=False + ) + eta_factor[:, start : start + var_dofs] += eta_mess + lam_factor[ + :, start : start + var_dofs, start : start + var_dofs + ] += lam_mess + + start += var_dofs + + dofs = self.cf.optim_var_at(v).dof() + + inc_messages = ( + ~torch.isclose(lam_factor, lam_factor_copy).all(dim=1).all(dim=1) + ) + # if no incoming messages then send out zero message + if not inc_messages.any() and num_optim_vars > 1: + # print(self.cf.name, "---> not updating, incoming precision is zero") + new_mess = Message([self.cf.optim_var_at(v).copy()]) + new_mess.zero_message() + + else: + # print(self.cf.name, "---> sending message") + # Divide up parameters of distribution to compute schur complement + # *_out = parameters for receiver variable (outgoing message vars) + # *_notout = parameters for other variables (not outgoing message vars) + eta_out = eta_factor[:, sdim : sdim + dofs] + eta_notout = torch.cat( + (eta_factor[:, :sdim], eta_factor[:, sdim + dofs :]), dim=1 + ) + + lam_out_out = lam_factor[:, sdim : sdim + dofs, sdim : sdim + dofs] + lam_out_notout = torch.cat( + ( + lam_factor[:, sdim : sdim + dofs, :sdim], + lam_factor[:, sdim : sdim + dofs, sdim + dofs :], + ), + dim=2, + ) + lam_notout_out = torch.cat( + ( + lam_factor[:, :sdim, sdim : sdim + dofs], + lam_factor[:, sdim + dofs :, sdim : sdim + dofs], + ), + dim=1, + ) + lam_notout_notout = torch.cat( + ( + torch.cat( + ( + lam_factor[:, :sdim, :sdim], + lam_factor[:, :sdim, sdim + dofs :], + ), + dim=2, + ), + torch.cat( + ( + lam_factor[:, sdim + dofs :, :sdim], + lam_factor[:, sdim + dofs :, sdim + dofs :], + ), + dim=2, + ), + ), + dim=1, + ) + + # Schur complement computation + new_mess_lam = ( + lam_out_out + - lam_out_notout + @ torch.linalg.inv(lam_notout_notout) + @ lam_notout_out + ) + new_mess_eta = eta_out - torch.bmm( + torch.bmm(lam_out_notout, torch.linalg.inv(lam_notout_notout)), + eta_notout.unsqueeze(-1), + ).squeeze(-1) + + # message damping in tangent space at linearisation point as message + # is already in this tangent space. Could equally do damping + # in the tangent space of the new or old message mean. + # Damping is applied to the mean parameters. 
+ # do_damping = torch.logical_and(msg_damping[v] > 0, self.steps_since_lin >= 0) + do_damping = msg_damping[v] + if do_damping.sum() != 0: + damping_check = torch.logical_and( + new_mess_lam.count_nonzero(1, 2) != 0, + self.ftov_msgs[v].precision.count_nonzero(1, 2) != 0, + ) + do_damping = torch.logical_and(do_damping, damping_check) + if do_damping.sum() > 0: + prev_mess_mean, prev_mess_lam = th.local_gaussian( + self.lin_point[v], self.ftov_msgs[v], return_mean=True + ) + new_mess_mean = torch.bmm( + torch.inverse(new_mess_lam), new_mess_eta.unsqueeze(-1) + ).squeeze(-1) + msg_damping[v][~do_damping] = 0.0 + new_mess_mean = ( + 1 - msg_damping[v][:, None] + ) * new_mess_mean + msg_damping[v][:, None] * prev_mess_mean + new_mess_eta = torch.bmm( + new_mess_lam, new_mess_mean.unsqueeze(-1) + ).squeeze(-1) + + # don't send messages if schedule is False + if not schedule[v].all(): + # if any are False set these to prev message + prev_mess_eta, prev_mess_lam = th.local_gaussian( + self.lin_point[v], self.ftov_msgs[v], return_mean=False + ) + no_update = ~schedule[v] + new_mess_eta[no_update] = prev_mess_eta[no_update] + new_mess_lam[no_update] = prev_mess_lam[no_update] + + # eqns 39-41 - https://arxiv.org/pdf/2202.03314.pdf + new_mess_mean = th.LUDenseSolver._solve_sytem( + new_mess_eta[..., None], new_mess_lam + ) + new_mess_params = th.retract_gaussian( + self.lin_point[v], new_mess_mean, new_mess_lam + ) + new_mess = Message( + mean=new_mess_params.mean, precision=new_mess_params.precision + ) + + # set zero msg for factors with no incoming messages + if not inc_messages.all() and num_optim_vars > 1: + # i.e. ignore (don't zero) if there are incoming messages + new_mess.zero_message(batch_ignore_mask=inc_messages) + + new_messages.append(new_mess) + + sdim += dofs + + # update messages + for v in range(num_optim_vars): + self.ftov_msgs[v].update( + mean=new_messages[v].mean, precision=new_messages[v].precision + ) + + @property + def dof(self) -> int: + return self._dof + + +# Follows notation from https://arxiv.org/pdf/2202.03314.pdf +class GaussianBeliefPropagation(NonlinearOptimizer, abc.ABC): + def __init__( + self, + objective: Objective, + vectorize: bool = True, + abs_err_tolerance: float = 1e-10, + rel_err_tolerance: float = 1e-8, + max_iterations: int = 20, + ): + super().__init__(objective, vectorize=vectorize) + + # ordering is required to identify which messages to send where + self.ordering = VariableOrdering(objective, default_order=True) + + """ + GBP functions + """ + + # eqns in sec 3.1 - https://arxiv.org/pdf/2202.03314.pdf + def _pass_var_to_fac_messages_loop(self, update_belief=True): + for i, var in enumerate(self.ordering): + # Collect all incoming messages in the tangent space at the current belief + # eqns 4-7 - https://arxiv.org/pdf/2202.03314.pdf + etas_tp = [] # message etas + lams_tp = [] # message lams + for factor in self.factors: + for j, msg in enumerate(factor.ftov_msgs): + if factor.var_ixs[j] == i: + eta_tp, lam_tp = th.local_gaussian(var, msg, return_mean=False) + etas_tp.append(eta_tp[None, ...]) + lams_tp.append(lam_tp[None, ...]) + + etas_tp = torch.cat(etas_tp) + lams_tp = torch.cat(lams_tp) + + lam_tau = lams_tp.sum(dim=0) + + # Compute outgoing messages + # eqns 8-10 - https://arxiv.org/pdf/2202.03314.pdf + ix = 0 # index for msg in list of msgs to variables + for factor in self.factors: + for j, msg in enumerate(factor.vtof_msgs): + if factor.var_ixs[j] == i: + etas_inc = torch.cat((etas_tp[:ix], etas_tp[ix + 1 :])) + lams_inc = 
torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) + + lam_a = lams_inc.sum(dim=0) + if lam_a.count_nonzero() == 0: + msg.zero_message() + else: + inv_lam_a = torch.linalg.inv(lam_a) + sum_etas = etas_inc.sum(dim=0) + mean_a = torch.matmul( + inv_lam_a, sum_etas.unsqueeze(-1) + ).squeeze(-1) + new_mess = th.retract_gaussian(var, mean_a, lam_a) + msg.update(new_mess.mean, new_mess.precision) + ix += 1 + + # update belief mean and variance + # if no incoming messages then leave current belief unchanged + if update_belief and lam_tau.count_nonzero() != 0: + inv_lam_tau = torch.inverse(lam_tau) + sum_taus = etas_tp.sum(dim=0) + tau = torch.matmul(inv_lam_tau, sum_taus.unsqueeze(-1)).squeeze(-1) + + new_belief = th.retract_gaussian(var, tau, lam_tau) + self.beliefs[i].update(new_belief.mean, new_belief.precision) + + # Similar to above fn but vectorization require tracking extra indices + def _pass_var_to_fac_messages_vectorized(self, update_belief=True): + # Each (variable-type, dof) gets mapped to a tuple with: + # - the variable that will hold the vectorized data + # - all the variables of that type that will be vectorized together + # - list of variable indices from ordering + # - tensor that will hold incoming messages [eta, lam] in the belief tangent plane + var_info: Dict[ + Tuple[Type[Manifold], int], + Tuple[Manifold, List[Manifold], List[int], List[torch.Tensor]], + ] = {} + batch_size = -1 + for ix, var in enumerate(self.ordering): + if batch_size == -1: + batch_size = var.shape[0] + else: + assert batch_size == var.shape[0] + + var_type = (var.__class__, var.dof()) + if var_type not in var_info: + var_info[var_type] = (var.copy(), [], [], []) + var_info[var_type][1].append(var) + var_info[var_type][2].append(ix) + + # For each variable-type, create tensors to accumulate incoming messages + for var_type, (vectorized_var, var_list, _, eta_lam) in var_info.items(): + n_vars, dof = len(var_list), vectorized_var.dof() + + # Get the vectorized tensor that has the current variable data. 
+ # The resulting shape is (N * b, M), b is batch size, N is the number of + # variables in the group, and M is the data shape for this class + vectorized_data = torch.cat([v.tensor for v in var_list], dim=0) + assert ( + vectorized_data.shape + == (n_vars * batch_size,) + vectorized_data.shape[1:] + ) + vectorized_var.update(vectorized_data) + + eta_tp_acc = torch.zeros(n_vars * batch_size, dof) + lam_tp_acc = torch.zeros(n_vars * batch_size, dof, dof) + eta_tp_acc = eta_tp_acc.to(vectorized_data.device, vectorized_data.dtype) + lam_tp_acc = lam_tp_acc.to(vectorized_data.device, vectorized_data.dtype) + eta_lam.extend([eta_tp_acc, lam_tp_acc]) + + # add ftov messages to eta_tp and lam_tp accumulator tensors + # eqns 4-7 - https://arxiv.org/pdf/2202.03314.pdf + for factor in self.factors: + for i, msg in enumerate(factor.ftov_msgs): + # transform messages to tangent plane at the current variable value + eta_tp, lam_tp = th.local_gaussian( + factor.cf.optim_var_at(i), msg, return_mean=False + ) + + receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) + # get indices of the vectorized variables that receive each message + if factor.vectorized_var_ixs[i] is None: + receiver_var_ixs = factor.var_ixs[ + :, i + ] # ixs of the receiver variables + var_type_ixs = torch.tensor( + var_info[receiver_var_type][2] + ) # all ixs for variables of this type + var_type_ixs = var_type_ixs[None, :].repeat( + len(receiver_var_ixs), 1 + ) + is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 + indices = is_receiver.nonzero()[:, 1] + # expand indices for all batch variables when batch size > 1 + if self.objective.batch_size != 1: + indices = indices[:, None].repeat(1, self.objective.batch_size) + shift = ( + torch.arange(self.objective.batch_size)[None, :] + .long() + .to(indices.device) + ) + indices = indices + shift + indices = indices.flatten() + indices = indices.to(factor.cf.optim_var_at(0).device) + factor.vectorized_var_ixs[i] = indices + + # add messages to correct variable using indices + eta_tp_acc = var_info[receiver_var_type][3][0] + lam_tp_acc = var_info[receiver_var_type][3][1] + eta_tp_acc.index_add_(0, factor.vectorized_var_ixs[i], eta_tp) + lam_tp_acc.index_add_(0, factor.vectorized_var_ixs[i], lam_tp) + + # compute variable to factor messages, now all incoming messages are accumulated + # eqns 8-10 - https://arxiv.org/pdf/2202.03314.pdf + for factor in self.factors: + for i, msg in enumerate(factor.vtof_msgs): + # transform messages to tangent plane at the current variable value + ftov_msg = factor.ftov_msgs[i] + eta_tp, lam_tp = th.local_gaussian( + factor.cf.optim_var_at(i), ftov_msg, return_mean=False + ) + + # new outgoing message is belief - last incoming mesage (in log space parameters) + receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) + eta_tp_acc = var_info[receiver_var_type][3][0] + lam_tp_acc = var_info[receiver_var_type][3][1] + sum_etas = eta_tp_acc[factor.vectorized_var_ixs[i]] - eta_tp + lam_a = lam_tp_acc[factor.vectorized_var_ixs[i]] - lam_tp + + if lam_a.count_nonzero() == 0: + msg.zero_message() + else: + valid_lam = lam_a.count_nonzero(1, 2) != 0 + inv_lam_a = torch.zeros_like( + lam_a, dtype=lam_a.dtype, device=lam_a.device + ) + inv_lam_a[valid_lam] = torch.linalg.inv(lam_a[valid_lam]) + mean_a = torch.matmul(inv_lam_a, sum_etas.unsqueeze(-1)).squeeze(-1) + new_mess = th.retract_gaussian( + factor.cf.optim_var_at(i), mean_a, lam_a + ) + msg.update(new_mess.mean, new_mess.precision) + + # compute the new belief for the vectorized 
variables + # eqns 42-45 - https://arxiv.org/pdf/2202.03314.pdf + for vectorized_var, _, var_ixs, eta_lam in var_info.values(): + eta_tp_acc = eta_lam[0] + lam_tau = eta_lam[1] + + if update_belief and lam_tau.count_nonzero() != 0: + valid_lam = lam_tau.count_nonzero(1, 2) != 0 + inv_lam_tau = torch.zeros_like( + lam_tau, dtype=lam_tau.dtype, device=lam_tau.device + ) + inv_lam_tau[valid_lam] = torch.linalg.inv(lam_tau[valid_lam]) + tau = torch.matmul(inv_lam_tau, eta_tp_acc.unsqueeze(-1)).squeeze(-1) + + new_belief = th.retract_gaussian(vectorized_var, tau, lam_tau) + + # update non vectorized beliefs with slices + start_idx = 0 + for ix in var_ixs: + belief_mean_slice = new_belief.mean[0][ + start_idx : start_idx + batch_size + ] + belief_precision_slice = new_belief.precision[ + start_idx : start_idx + batch_size + ] + self.beliefs[ix].update([belief_mean_slice], belief_precision_slice) + start_idx += batch_size + + def _linearize_factors( + self, relin_threshold: float = None, detach_hessian: bool = False + ): + relins = 0 + for factor in self.factors: + factor.linearize( + relin_threshold=relin_threshold, detach_hessian=detach_hessian + ) + relins += int((factor.steps_since_lin == 0).sum().item()) + return relins + + def _pass_fac_to_var_messages( + self, schedule: torch.Tensor, ftov_msg_damping: torch.Tensor + ): + start_d = 0 + for j, factor in enumerate(self.factors): + num_optim_vars = factor.cf.num_optim_vars() + n_edges = num_optim_vars * factor.batch_size + damping_tsr = ftov_msg_damping[start_d : start_d + n_edges] + schedule_tsr = schedule[start_d : start_d + n_edges] + damping_tsr = damping_tsr.reshape(num_optim_vars, factor.batch_size) + schedule_tsr = schedule_tsr.reshape(num_optim_vars, factor.batch_size) + start_d += n_edges + + if schedule_tsr.sum() != 0: + factor.comp_mess(damping_tsr, schedule_tsr) + + def _create_factors_beliefs(self, lin_system_damping): + self.factors: List[Factor] = [] + self.beliefs: List[ManifoldGaussian] = [] + for var in self.ordering: + self.beliefs.append(ManifoldGaussian([var])) + + if self.objective.vectorized: + cf_iterator = iter(self.objective.vectorized_cost_fns) + self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_vectorized + else: + cf_iterator = iter(self.objective) + self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop + + # compute factor potentials for the first time + unary_factor = False + for i, cost_function in enumerate(cf_iterator): + if self.objective.vectorized: + # create array for indexing the messages + base_cf_names = self.objective.vectorized_cf_names[i] + + base_cfs = [ + self.objective.get_cost_function(name) for name in base_cf_names + ] + # index of variables connected to vectorized factor + var_ixs = torch.tensor( + [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] + for cf in base_cfs + ] + ).long() + else: + var_ixs = torch.tensor( + [ + self.ordering.index_of(var.name) + for var in cost_function.optim_vars + ] + ).long() + + self.factors.append( + Factor( + cost_function, + name=cost_function.name, + var_ixs=var_ixs, + lin_system_damping=lin_system_damping, + ) + ) + if cost_function.num_optim_vars() == 1: + unary_factor = True + if unary_factor is False: + raise Exception( + "We require at least one unary cost function to act as a prior." + "This is because Gaussian Belief Propagation is performing Bayesian inference." 
+ ) + if self.objective.vectorized: + self.objective.update_vectorization_if_needed() + self._linearize_factors() + + self.n_individual_factors = ( + len(self.objective.cost_functions) * self.objective.batch_size + ) + self.n_edges = sum( + [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] + ) + + """ + Optimization loop functions + """ + + def _optimize_loop( + self, + num_iter: int, + info: NonlinearOptimizerInfo, + verbose: bool, + relin_threshold: float, + ftov_msg_damping: float, + dropout: float, + schedule: GBPSchedule, + lin_system_damping: torch.Tensor, + clear_messages: bool = True, + implicit_gbp_loop: bool = False, + end_iter_callback: Optional[EndIterCallbackType] = None, + **kwargs, + ): + # we only create the factors and beliefs right before runnig GBP as they are + # not automatically updated when objective.update is called. + if clear_messages: + self._create_factors_beliefs(lin_system_damping) + else: + self.objective.update_vectorization_if_needed() + if not implicit_gbp_loop: + self._linearize_factors() + + if implicit_gbp_loop: + relin_threshold = 1e10 # no relinearisation + if self.objective.vectorized: + self.objective.update_vectorization_if_needed() + self._linearize_factors(detach_hessian=True) + + if schedule == GBPSchedule.SYNCHRONOUS: + ftov_schedule = synchronous_schedule(num_iter, self.n_edges) + + self.ftov_msgs_history = {} + + converged_indices = torch.zeros_like(info.last_err).bool() + iters_done = 0 + for it_ in range(num_iter): + iters_done += 1 + curr_ftov_msgs = [] + for factor in self.factors: + curr_ftov_msgs.extend([msg.copy() for msg in factor.ftov_msgs]) + self.ftov_msgs_history[it_] = curr_ftov_msgs + + # damping + ftov_damping_arr = torch.full( + [self.n_edges], + ftov_msg_damping, + device=self.ordering[0].device, + dtype=self.ordering[0].dtype, + ) + # dropout is implemented by changing the schedule + if dropout != 0.0 and it_ > 1: + dropout_ixs = torch.rand(self.n_edges) < dropout + ftov_schedule[it_, dropout_ixs] = False + + # t0 = time.time() + relins = self._linearize_factors(relin_threshold) + # t_relin = time.time() - t0 + + # t1 = time.time() + self._pass_fac_to_var_messages(ftov_schedule[it_], ftov_damping_arr) + # t_ftov = time.time() - t1 + + # t1 = time.time() + self._pass_var_to_fac_messages(update_belief=True) + # t_vtof = time.time() - t1 + + # t_vec = 0.0 + if self.objective.vectorized: + # t1 = time.time() + self.objective.update_vectorization_if_needed() + # t_vec = time.time() - t1 + + # if verbose: + # t_tot = time.time() - t0 + # print( + # f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," + # f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" + # ) + + # check for convergence + if it_ >= 0: + with torch.no_grad(): + err = self.objective.error_metric() / 2 + self._update_info(info, it_, err, converged_indices) + if verbose: + print( + f"GBP. Iteration: {it_+1}. Error: {err.mean().item()}. 
" + f"Relins: {relins} / {self.n_individual_factors}" + ) + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + converged_indices.cpu().numpy() + ] = NonlinearOptimizerStatus.CONVERGED + if converged_indices.all() and it_ > 1: + break # nothing else will happen at this point + info.last_err = err + + if end_iter_callback is not None: + end_iter_callback(self, info, None, it_) + + info.status[ + info.status == NonlinearOptimizerStatus.START + ] = NonlinearOptimizerStatus.MAX_ITERATIONS + return iters_done + + # `track_best_solution` keeps a **detached** copy (as in no gradient info) + # of the best variables found, but it is optional to avoid unnecessary copying + # if this is not needed + def _optimize_impl( + self, + track_best_solution: bool = False, + track_err_history: bool = False, + track_state_history: bool = False, + verbose: bool = False, + backward_mode: Union[str, BackwardMode] = BackwardMode.UNROLL, + relin_threshold: float = 1e-8, + ftov_msg_damping: float = 0.0, + dropout: float = 0.0, + schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, + lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), + implicit_step_size: float = 1e-4, + implicit_method: str = "gbp", + end_iter_callback: Optional[EndIterCallbackType] = None, + **kwargs, + ) -> NonlinearOptimizerInfo: + backward_mode = BackwardMode.resolve(backward_mode) + kwargs_plus_bwd_mode = {**kwargs, **{"backward_mode": backward_mode}} + if backward_mode == BackwardMode.DLM: + raise ValueError( + "DLM backward mode not supported for Gaussian Belief Propagation optimizer." + ) + with torch.no_grad(): + info = self._init_info( + track_best_solution, track_err_history, track_state_history + ) + + if ftov_msg_damping > 1.0 or ftov_msg_damping < 0.0: + raise ValueError( + f"Damping must be between 0 and 1. Got {ftov_msg_damping}." + ) + if dropout > 1.0 or dropout < 0.0: + raise ValueError( + f"Dropout probability must be between 0 and 1. Got {dropout}." + ) + if dropout > 0.9: + print( + "Disabling vectorization due to dropout > 0.9 in GBP message schedule." + ) + self.objective.disable_vectorization() + + if not isinstance(lin_system_damping, torch.Tensor): + raise TypeError("lin_system_damping should be an instance of torch.Tensor.") + expected_shape = torch.Size([1]) + if lin_system_damping.shape != expected_shape: + raise ValueError( + f"lin_system_damping should have shape {expected_shape}. " + f"Got shape {lin_system_damping.shape}." + ) + lin_system_damping.to(self.objective.device, self.objective.dtype) + + if verbose: + print( + f"GBP optimizer. Iteration: 0. 
" f"Error: {info.last_err.mean().item()}" + ) + + backward_num_iters, no_grad_num_iters = self._split_backward_iters( + **kwargs_plus_bwd_mode + ) + if backward_mode == BackwardMode.UNROLL: + self._optimize_loop( + num_iter=backward_num_iters, + info=info, + verbose=verbose, + relin_threshold=relin_threshold, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + end_iter_callback=end_iter_callback, + **kwargs, + ) + # If didn't coverge, remove misleading converged_iter value + info.converged_iter[ + info.status == NonlinearOptimizerStatus.MAX_ITERATIONS + ] = -1 + return info + + elif backward_mode in [BackwardMode.IMPLICIT, BackwardMode.TRUNCATED]: + if backward_mode == BackwardMode.IMPLICIT: + self.implicit_method = implicit_method + implicit_methods = ["gauss_newton", "gbp"] + if implicit_method not in implicit_methods: + raise ValueError( + f"implicit_method must be one of {implicit_methods}, " + f"but got {implicit_method}" + ) + backward_num_iters = 0 + no_grad_num_iters = self.params.max_iterations + + with torch.no_grad(): + # actual_num_iters could be < num_iter due to early convergence + no_grad_iters_done = self._optimize_loop( + num_iter=no_grad_num_iters, + info=info, + verbose=verbose, + relin_threshold=relin_threshold, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + end_iter_callback=end_iter_callback, + **kwargs, + ) + + grad_loop_info = self._init_info( + track_best_solution, track_err_history, track_state_history + ) + if backward_mode == BackwardMode.TRUNCATED: + grad_iters_done = self._optimize_loop( + num_iter=no_grad_num_iters, + info=grad_loop_info, + verbose=verbose, + relin_threshold=relin_threshold, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + clear_messages=False, + end_iter_callback=end_iter_callback, + **kwargs, + ) + # Adds grad_loop_info results to original info + self._merge_infos( + grad_loop_info, no_grad_iters_done, grad_iters_done, info + ) + + # Compute the approximate the implicit derivative using a gauss-newton step. + # It is first order approximate, as a full Newton step would be exact. + elif implicit_method == "gauss_newton": + self.implicit_step_size = implicit_step_size + gauss_newton_optimizer = th.GaussNewton(self.objective) + gauss_newton_optimizer.linear_solver.linearization.linearize() + delta = gauss_newton_optimizer.linear_solver.solve() + self.objective.retract_vars_sequence( + delta * implicit_step_size, + gauss_newton_optimizer.linear_solver.linearization.ordering, + force_update=True, + ) + if verbose: + err = self.objective.error_metric() / 2 + print( + f"Nonlinear optimizer implcit step. 
Error: {err.mean().item()}" + ) + # solve normal equation with GBP + elif implicit_method == "gbp": + max_lin_solve_iters = 1000 + grad_iters_done = self._optimize_loop( + num_iter=max_lin_solve_iters, + info=grad_loop_info, + verbose=verbose, + relin_threshold=1e10, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + clear_messages=False, + implicit_gbp_loop=True, + end_iter_callback=end_iter_callback, + **kwargs, + ) + + return info + else: + raise ValueError("Unrecognized backward mode") diff --git a/theseus/optimizer/linear/dense_solver.py b/theseus/optimizer/linear/dense_solver.py index e5e58e53f..3cc7a00f6 100644 --- a/theseus/optimizer/linear/dense_solver.py +++ b/theseus/optimizer/linear/dense_solver.py @@ -137,7 +137,8 @@ def __init__( check_singular=check_singular, ) - def _solve_sytem(self, Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: + @staticmethod + def _solve_sytem(Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: return torch.linalg.solve(AtA, Atb).squeeze(2) @@ -156,6 +157,7 @@ def __init__( check_singular=check_singular, ) - def _solve_sytem(self, Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: + @staticmethod + def _solve_sytem(Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: lower = torch.linalg.cholesky(AtA) return torch.cholesky_solve(Atb, lower).squeeze(2) diff --git a/theseus/optimizer/manifold_gaussian.py b/theseus/optimizer/manifold_gaussian.py index 8cf1d980c..fd37484d7 100644 --- a/theseus/optimizer/manifold_gaussian.py +++ b/theseus/optimizer/manifold_gaussian.py @@ -74,6 +74,7 @@ def update( self, mean: Sequence[Manifold], precision: torch.Tensor, + batch_ignore_mask: Optional[torch.Tensor] = None, ): if len(mean) != len(self.mean): raise ValueError( @@ -82,7 +83,7 @@ def update( f"Expected: {len(self.mean)}" ) for i in range(len(self.mean)): - self.mean[i].update(mean[i]) + self.mean[i].update(mean[i], batch_ignore_mask=batch_ignore_mask) expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof]) if precision.shape != expected_shape: @@ -101,10 +102,16 @@ def update( f"Tried to update using tensor on device: {precision.dtype} but precision " f"is on device: {self.device}." ) - if not torch.allclose(precision, precision.transpose(1, 2)): + if not torch.allclose(precision, precision.transpose(1, 2), atol=1e-5): raise ValueError("Tried to update precision with non-symmetric matrix.") - self.precision = precision + if batch_ignore_mask is not None and batch_ignore_mask.any(): + mask_shape = (-1,) + (1,) * (precision.ndim - 1) + self.precision = torch.where( + batch_ignore_mask.view(mask_shape), self.precision, precision + ) + else: + self.precision = precision # Projects the gaussian (ManifoldGaussian object) into the tangent plane at
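
For reference, a minimal usage sketch of the new GaussianBeliefPropagation optimizer, pieced together from the test file at the top of this diff. It is not part of the patch; the small chain problem, weights, and iteration counts below are illustrative only, while the cost functions (Difference, Between), ScaleCostWeight, and the optimize() keyword arguments are the ones exercised by the test.

import torch
import theseus as th

# Small 1D chain. GBP requires at least one unary cost function to act as a prior,
# so anchor the first variable with a Difference factor.
n = 10
xs = [th.Vector(tensor=torch.rand(1, 1), name=f"x_{i}") for i in range(n)]
objective = th.Objective()
prior_target = th.Vector(tensor=torch.zeros(1, 1), name="prior_target")
objective.add(th.Difference(xs[0], prior_target, th.ScaleCostWeight(5.0), name="prior"))
zero = th.Vector(tensor=torch.zeros(1, 1), name="zero")
for i in range(n - 1):
    objective.add(th.Between(xs[i], xs[i + 1], zero, th.ScaleCostWeight(2.0)))

optimizer = th.GaussianBeliefPropagation(objective, max_iterations=50, vectorize=True)
info = optimizer.optimize(
    track_err_history=True,
    ftov_msg_damping=0.0,  # factor-to-variable message damping, must be in [0, 1]
    dropout=0.0,  # probability of dropping an edge from the message schedule
    lin_system_damping=torch.tensor([1e-4]),  # per-factor linear system damping
    verbose=True,
)
# On a linear problem such as this, the converged belief means should match a
# direct linear solve of the same objective.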