From 1918b45b8f656cced6748f4c5be6d2713d3443da Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 16 Mar 2022 18:19:32 +0000 Subject: [PATCH 01/64] gbp implementation for pose graph problem, euclidean and lie algebra --- theseus/optimizer/gbp/gbp_baseline.py | 718 +++++++++++++++ theseus/optimizer/gbp/jax_torch_test.py | 292 ++++++ theseus/optimizer/gbp/pose_graph_gbp.py | 1092 +++++++++++++++++++++++ 3 files changed, 2102 insertions(+) create mode 100644 theseus/optimizer/gbp/gbp_baseline.py create mode 100644 theseus/optimizer/gbp/jax_torch_test.py create mode 100644 theseus/optimizer/gbp/pose_graph_gbp.py diff --git a/theseus/optimizer/gbp/gbp_baseline.py b/theseus/optimizer/gbp/gbp_baseline.py new file mode 100644 index 000000000..d63e05a1b --- /dev/null +++ b/theseus/optimizer/gbp/gbp_baseline.py @@ -0,0 +1,718 @@ +import numpy as np +from typing import Any, Dict, Optional, Type +from typing import List, Callable, Optional, Union + + +""" +Defines squared loss functions that correspond to Gaussians. +Robust losses are implemented by scaling the Gaussian covariance. +""" + +class Gaussian: + def __init__( + self, + dim, + eta=None, + lam=None, + ): + self.dim = dim + + if eta is not None and eta.shape == (dim,): + self.eta = eta + else: + self.eta = np.zeros(dim) + + if lam is not None and lam.shape == (dim, dim): + self.lam = lam + else: + self.lam = np.zeros([dim, dim]) + + def mean(self) -> np.ndarray: + return np.matmul(np.linalg.inv(self.lam), self.eta) + + def cov(self) -> np.ndarray: + return np.linalg.inv(self.lam) + + def mean_and_cov(self) -> List[np.ndarray]: + cov = self.cov() + mean = np.matmul(cov, self.eta) + return [mean, cov] + + def set_with_cov_form(self, mean: np.ndarray, cov: np.ndarray) -> None: + self.lam = np.linalg.inv(cov) + self.eta = np.matmul(self.lam, mean) + + +class GBPSettings: + def __init__( + self, + damping: float = 0., + beta: float = 0.1, + num_undamped_iters: int = 5, + min_linear_iters: int = 10, + dropout: float = 0., + reset_iters_since_relin: List[int] = [], + ): + # Parameters for damping the eta component of the message + self.damping = damping + # Number of undamped iterations after relin before damping is on + self.num_undamped_iters = num_undamped_iters + + self.dropout = dropout + + # Parameters for just in time factor relinearisation. + # Threshold absolute distance between linpoint + # and adjacent belief means for relinearisation. + self.beta = beta + # Minimum number of linear iterations before + # a factor is allowed to realinearise. + self.min_linear_iters = min_linear_iters + self.reset_iters_since_relin = reset_iters_since_relin + + def get_damping(self, iters_since_relin: int) -> float: + if iters_since_relin > self.num_undamped_iters: + return self.damping + else: + return 0. + + +class SquaredLoss(): + def __init__( + self, + dofs: int, + diag_cov: Union[float, np.ndarray] + ): + """ + dofs: dofs of the measurement + cov: diagonal elements of covariance matrix + """ + assert diag_cov.shape[0] == dofs + mat = np.zeros([dofs, dofs]) + mat[range(dofs), range(dofs)] = diag_cov + self.cov = mat + self.effective_cov = mat.copy() + + def get_effective_cov(self, residual: np.ndarray) -> None: + """ + Returns the covariance of the Gaussian (squared loss) + that matches the loss at the error value. 
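+        For this plain squared loss the effective covariance is simply the
+        fixed covariance. Robust losses (e.g. HuberLoss below) override this
+        method to inflate the covariance as a function of the residual.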
+ """ + self.effective_cov = self.cov.copy() + + def robust(self) -> bool: + return not np.equal(self.cov, self.effective_cov) + + +class HuberLoss(SquaredLoss): + def __init__( + self, + dofs: int, + diag_cov: Union[float, np.ndarray], + stds_transition: float + ): + """ + stds_transition: num standard deviations from minimum at + which quadratic loss transitions to linear. + """ + SquaredLoss.__init__(self, dofs, diag_cov) + self.stds_transition = stds_transition + + def get_effective_cov(self, residual: np.ndarray) -> None: + energy = residual @ np.linalg.inv(self.cov) @ residual + mahalanobis_dist = np.sqrt(energy) + if mahalanobis_dist > self.stds_transition: + denom = (2 * self.stds_transition * mahalanobis_dist - self.stds_transition ** 2) + self.effective_cov = self.cov * mahalanobis_dist**2 / denom + else: + self.effective_cov = self.cov.copy() + + +class MeasModel: + def __init__( + self, + meas_fn: Callable, + jac_fn: Callable, + loss: SquaredLoss, + *args, + ): + self._meas_fn = meas_fn + self._jac_fn = jac_fn + self.loss = loss + self.args = args + self.linear = True + + def jac_fn(self, x: np.ndarray) -> np.ndarray: + return self._jac_fn(x, *self.args) + + def meas_fn(self, x: np.ndarray) -> np.ndarray: + return self._meas_fn(x, *self.args) + + +def lin_meas_fn(x): + length = int(x.shape[0] / 2) + J = np.concatenate((-np.eye(length), np.eye(length)), axis=1) + return J @ x + + +def lin_jac_fn(x): + length = int(x.shape[0] / 2) + return np.concatenate((-np.eye(length), np.eye(length)), axis=1) + + +class LinearDisplacementModel(MeasModel): + def __init__(self, loss: SquaredLoss) -> None: + MeasModel.__init__(self, lin_meas_fn, lin_jac_fn, loss) + self.linear = True + + +""" +Main GBP functions. +Defines classes for variable nodes, factor nodes and edges and factor graph. 
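+
+A minimal usage sketch (mirroring the grid example in __main__ at the bottom
+of this file; the numbers are illustrative only):
+
+    fg = FactorGraph(GBPSettings())
+    fg.add_var_node(2, prior_mean=np.zeros(2), prior_diag_cov=np.array([1., 1.]))
+    fg.add_var_node(2, prior_mean=np.ones(2), prior_diag_cov=np.array([1., 1.]))
+    fg.add_factor(
+        [0, 1],
+        np.array([1., 0.]),
+        LinearDisplacementModel(SquaredLoss(2, np.array([0.01, 0.01]))),
+    )
+    fg.gbp_solve(n_iters=20)
+    print(fg.belief_means())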
+""" + + +class FactorGraph: + def __init__( + self, + gbp_settings: GBPSettings = GBPSettings(), + ): + self.var_nodes = [] + self.factors = [] + self.gbp_settings = gbp_settings + + def add_var_node( + self, + dofs: int, + prior_mean: Optional[np.ndarray] = None, + prior_diag_cov: Optional[Union[float, np.ndarray]] = None, + ) -> None: + variableID = len(self.var_nodes) + self.var_nodes.append(VariableNode(variableID, dofs)) + if prior_mean is not None and prior_diag_cov is not None: + prior_cov = np.zeros([dofs, dofs]) + prior_cov[range(dofs), range(dofs)] = prior_diag_cov + self.var_nodes[-1].prior.set_with_cov_form(prior_mean, prior_cov) + self.var_nodes[-1].update_belief() + + def add_factor( + self, + adj_var_ids: List[int], + measurement: np.ndarray, + meas_model: MeasModel, + ) -> None: + factorID = len(self.factors) + adj_var_nodes = [self.var_nodes[i] for i in adj_var_ids] + self.factors.append( + Factor(factorID, adj_var_nodes, measurement, meas_model)) + for var in adj_var_nodes: + var.adj_factors.append(self.factors[-1]) + + def update_all_beliefs(self) -> None: + for var_node in self.var_nodes: + var_node.update_belief() + + def compute_all_messages(self, apply_dropout: bool = True) -> None: + for factor in self.factors: + dropout_off = apply_dropout and np.random.rand() > self.gbp_settings.dropout + if dropout_off or not apply_dropout: + damping = self.gbp_settings.get_damping( + factor.iters_since_relin) + factor.compute_messages(damping) + + def linearise_all_factors(self) -> None: + for factor in self.factors: + factor.compute_factor() + + def robustify_all_factors(self) -> None: + for factor in self.factors: + factor.robustify_loss() + + def jit_linearisation(self) -> None: + """ + Check for all factors that the current estimate + is close to the linearisation point. + If not, relinearise the factor distribution. + Relinearisation is only allowed at a maximum frequency + of once every min_linear_iters iterations. + """ + for factor in self.factors: + if not factor.meas_model.linear: + adj_belief_means = factor.get_adj_means() + factor.iters_since_relin += 1 + diff_cond = np.linalg.norm(factor.linpoint - adj_belief_means) > self.gbp_settings.beta + iters_cond = factor.iters_since_relin >= self.gbp_settings.min_linear_iters + if diff_cond and iters_cond: + factor.compute_factor() + + def synchronous_iteration(self) -> None: + self.robustify_all_factors() + self.jit_linearisation() # For linear factors, no compute is done + self.compute_all_messages() + self.update_all_beliefs() + + def random_message(self) -> None: + """ + Sends messages to all adjacent nodes from a random factor. 
+ """ + self.robustify_all_factors() + self.jit_linearisation() # For linear factors, no compute is done + ix = np.random.randint(len(self.factors)) + factor = self.factors[ix] + damping = self.gbp_settings.get_damping(factor.iters_since_relin) + factor.compute_messages(damping) + self.update_all_beliefs() + + def gbp_solve( + self, + n_iters: Optional[int] = 20, + converged_threshold: Optional[float] = 1e-6, + include_priors: bool = True + ) -> None: + energy_log = [self.energy()] + print(f"\nInitial Energy {energy_log[0]:.5f}") + + i = 0 + count = 0 + not_converged = True + + while not_converged and i < n_iters: + self.synchronous_iteration() + if i in self.gbp_settings.reset_iters_since_relin: + for f in self.factors: + f.iters_since_relin = 1 + + energy_log.append(self.energy(include_priors=include_priors)) + print( + f"Iter {i+1} --- " + f"Energy {energy_log[-1]:.5f} --- " + ) + i += 1 + if abs(energy_log[-2] - energy_log[-1]) < converged_threshold: + count += 1 + if count == 3: + not_converged = False + else: + count = 0 + + def energy( + self, + eval_point: np.ndarray = None, + include_priors: bool = True + ) -> float: + """ + Computes the sum of all of the squared errors in the graph + using the appropriate local loss function. + """ + if eval_point is None: + energy = sum([factor.get_energy() for factor in self.factors]) + else: + var_dofs = np.ndarray([v.dofs for v in self.var_nodes]) + var_ix = np.concatenate([np.ndarray([0]), np.cumsum(var_dofs, axis=0)[:-1]]) + energy = 0. + for f in self.factors: + local_eval_point = np.concatenate([eval_point[var_ix[v.variableID]: var_ix[v.variableID] + v.dofs] for v in f.adj_var_nodes]) + energy += f.get_energy(local_eval_point) + if include_priors: + prior_energy = sum([var.get_prior_energy() for var in self.var_nodes]) + energy += prior_energy + return energy + + def get_joint_dim(self) -> int: + return sum([var.dofs for var in self.var_nodes]) + + def get_joint(self) -> Gaussian: + """ + Get the joint distribution over all variables in the information form + If nonlinear factors, it is taken at the current linearisation point. 
+ """ + dim = self.get_joint_dim() + joint = Gaussian(dim) + + # Priors + var_ix = [0] * len(self.var_nodes) + counter = 0 + for var in self.var_nodes: + var_ix[var.variableID] = int(counter) + joint.eta[counter:counter + var.dofs] += var.prior.eta + joint.lam[counter:counter + var.dofs, counter:counter + var.dofs] += var.prior.lam + counter += var.dofs + + # Other factors + for factor in self.factors: + factor_ix = 0 + for adj_var_node in factor.adj_var_nodes: + vID = adj_var_node.variableID + # Diagonal contribution of factor + joint.eta[var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ + factor.factor.eta[factor_ix:factor_ix + adj_var_node.dofs] + joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ + factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] + other_factor_ix = 0 + for other_adj_var_node in factor.adj_var_nodes: + if other_adj_var_node.variableID > adj_var_node.variableID: + other_vID = other_adj_var_node.variableID + # Off diagonal contributions of factor + joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs] += \ + factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, other_factor_ix:other_factor_ix + other_adj_var_node.dofs] + joint.lam[var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ + factor.factor.lam[other_factor_ix:other_factor_ix + other_adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] + other_factor_ix += other_adj_var_node.dofs + factor_ix += adj_var_node.dofs + + return joint + + def MAP(self) -> np.ndarray: + return self.get_joint().mean() + + def dist_from_MAP(self) -> np.ndarray: + return np.linalg.norm(self.get_joint().mean() - self.belief_means()) + + def belief_means(self) -> np.ndarray: + """ Get an array containing all current estimates of belief means. """ + return np.concatenate([var.belief.mean() for var in self.var_nodes]) + + def belief_covs(self) -> List[np.ndarray]: + """ Get a list of all belief covariances. """ + covs = [var.belief.cov() for var in self.var_nodes] + return covs + + def print(self, brief=False) -> None: + print("\nFactor Graph:") + print(f"# Variable nodes: {len(self.var_nodes)}") + if not brief: + for i, var in enumerate(self.var_nodes): + print(f"Variable {i}: connects to factors {[f.factorID for f in var.adj_factors]}") + print(f" dofs: {var.dofs}") + print(f" prior mean: {var.prior.mean()}") + print(f" prior covariance: diagonal sigma {np.diag(var.prior.cov())}") + print(f"# Factors: {len(self.factors)}") + if not brief: + for i, factor in enumerate(self.factors): + if factor.meas_model.linear: + print("Linear", end=" ") + else: + print("Nonlinear", end=" ") + print(f"Factor {i}: connects to variables {factor.adj_vIDs}") + print( + f" measurement model: {type(factor.meas_model).__name__}," + f" {type(factor.meas_model.loss).__name__}," + f" diagonal sigma {np.diag(factor.meas_model.loss.effective_cov)}" + ) + print(f" measurement: {factor.measurement}") + print("\n") + + +class VariableNode: + def __init__(self, id: int, dofs: int): + self.variableID = id + self.dofs = dofs + self.adj_factors = [] + # prior factor, implemented as part of variable node + self.prior = Gaussian(dofs) + self.belief = Gaussian(dofs) + + def update_belief(self) -> None: + """ + Update local belief estimate by taking product + of all incoming messages along all edges. 
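+        In the information form this product is a sum of natural parameters:
+        eta = eta_prior + sum_f eta_{f->x} and lam = lam_prior + sum_f lam_{f->x},
+        which is exactly what the loop below accumulates.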
+ """ + # message from prior factor + self.belief.eta = self.prior.eta.copy() + self.belief.lam = self.prior.lam.copy() + # messages from other adjacent variables + for factor in self.adj_factors: + message_ix = factor.adj_vIDs.index(self.variableID) + self.belief.eta += factor.messages[message_ix].eta + self.belief.lam += factor.messages[message_ix].lam + + def get_prior_energy(self) -> float: + energy = 0. + if self.prior.lam[0, 0] != 0.: + residual = self.belief.mean() - self.prior.mean() + energy += 0.5 * residual @ self.prior.lam @ residual + return energy + + +class Factor: + def __init__( + self, + id: int, + adj_var_nodes: List[VariableNode], + measurement: np.ndarray, + meas_model: MeasModel, + ) -> None: + + self.factorID = id + + self.adj_var_nodes = adj_var_nodes + self.dofs = sum([var.dofs for var in adj_var_nodes]) + self.adj_vIDs = [var.variableID for var in adj_var_nodes] + self.messages = [Gaussian(var.dofs) for var in adj_var_nodes] + + self.factor = Gaussian(self.dofs) + self.linpoint = np.zeros(self.dofs) + + self.measurement = measurement + self.meas_model = meas_model + + # For smarter GBP implementations + self.iters_since_relin = 0 + + self.compute_factor() + + def get_adj_means(self) -> np.ndarray: + adj_belief_means = [var.belief.mean() for var in self.adj_var_nodes] + return np.concatenate(adj_belief_means) + + def get_residual(self, eval_point: np.ndarray = None) -> np.ndarray: + """ Compute the residual vector. """ + if eval_point is None: + eval_point = self.get_adj_means() + return self.meas_model.meas_fn(eval_point) - self.measurement + + def get_energy(self, eval_point: np.ndarray = None) -> float: + """ Computes the squared error using the appropriate loss function. """ + residual = self.get_residual(eval_point) + inf_mat = np.linalg.inv(self.meas_model.loss.effective_cov) + return 0.5 * residual @ inf_mat @ residual + + def robust(self) -> bool: + return self.meas_model.loss.robust() + + def compute_factor(self) -> None: + """ + Compute the factor at current adjacente beliefs using robust. + If measurement model is linear then factor will always be + the same regardless of linearisation point. + """ + self.linpoint = self.get_adj_means() + J = self.meas_model.jac_fn(self.linpoint) + pred_measurement = self.meas_model.meas_fn(self.linpoint) + self.meas_model.loss.get_effective_cov(pred_measurement - self.measurement) + effective_lam = np.linalg.inv(self.meas_model.loss.effective_cov) + self.factor.lam = J.T @ effective_lam @ J + self.factor.eta = ((J.T @ effective_lam) @ (J @ self.linpoint + self.measurement - pred_measurement)).flatten() + self.iters_since_relin = 0 + + def robustify_loss(self) -> None: + """ + Rescale the variance of the noise in the Gaussian + measurement model if necessary and update the factor + correspondingly. + """ + old_effective_cov = self.meas_model.loss.effective_cov[0, 0] + self.meas_model.loss.get_effective_cov(self.get_residual()) + self.factor.eta *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] + self.factor.lam *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] + + def compute_messages(self, damping: float = 0.) -> None: + """ Compute all outgoing messages from the factor. 
""" + messages_eta, messages_lam = [], [] + + sdim = 0 + for v in range(len(self.adj_vIDs)): + eta_factor = self.factor.eta.copy() + lam_factor = self.factor.lam.copy() + + # Take product of factor with incoming messages + start = 0 + for var in range(len(self.adj_vIDs)): + if var != v: + var_dofs = self.adj_var_nodes[var].dofs + eta_mess = self.adj_var_nodes[var].belief.eta - self.messages[var].eta + lam_mess = self.adj_var_nodes[var].belief.lam - self.messages[var].lam + eta_factor[start:start + var_dofs] += eta_mess + lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess + start += self.adj_var_nodes[var].dofs + + # Divide up parameters of distribution + dofs = self.adj_var_nodes[v].dofs + eo = eta_factor[sdim:sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + + loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + lono = np.concatenate( + (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), + axis=1) + lnoo = np.concatenate( + (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), + axis=0) + lnono = np.concatenate(( + np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), + np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) + ), axis=0) + + new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno + messages_eta.append((1 - damping) * new_message_eta + damping * self.messages[v].eta) + messages_lam.append((1 - damping) * new_message_lam + damping * self.messages[v].lam) + sdim += self.adj_var_nodes[v].dofs + + for v in range(len(self.adj_vIDs)): + self.messages[v].lam = messages_lam[v] + self.messages[v].eta = messages_eta[v] + + +""" +Visualisation function +""" + + +def draw(i): + fig, ax = plt.subplots(figsize=(7, 6)) + fig.set_tight_layout(True) + plt.title(i) + + # plot beliefs + means = fg.belief_means().reshape([size * size, 2]) + plt.scatter(means[:, 0], means[:, 1], color="blue") + for j, cov in enumerate(fg.belief_covs()): + circle = plt.Circle( + (means[j, 0], means[j, 1]), + np.sqrt(cov[0, 0]), linewidth=0.5, color='blue', fill=False + ) + ax.add_patch(circle) + + # plot true marginals + plt.scatter(map_soln[:, 0], map_soln[:, 1], color="g") + for j, cov in enumerate(marg_covs): + circle = plt.Circle( + (map_soln[j, 0], map_soln[j, 1]), + np.sqrt(marg_covs[j]), linewidth=0.5, color='g', fill=False + ) + ax.add_patch(circle) + + # draw lines for factors + for f in fg.factors: + bels = np.array([means[f.adj_vIDs[0]], means[f.adj_vIDs[1]]]) + plt.plot(bels[:, 0], bels[:, 1], color='black', linewidth=0.3) + + # draw lines for belief error + for i in range(len(means)): + xs = [means[i, 0], map_soln[i, 0]] + ys = [means[i, 1], map_soln[i, 1]] + plt.plot(xs, ys, color='grey', linewidth=0.3, linestyle='dashed') + + plt.axis('scaled') + plt.xlim([-1, size]) + plt.ylim([-1, size]) + + # convert to image + ax.axis('off') + fig.tight_layout(pad=0) + ax.margins(0) + fig.canvas.draw() + img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + + return img + +if __name__ == "__main__": + + np.random.seed(1) + + size = 3 + dim = 2 + + prior_noise_std = 0.2 + + gbp_settings = GBPSettings( + damping=0., + beta=0.1, + num_undamped_iters=1, + min_linear_iters=10, + dropout=0.0, + ) + + # GBP library soln ------------------------------------------ + + noise_cov = 
np.array([0.01, 0.01]) + + prior_sigma = np.array([1.3**2, 1.3**2]) + prior_noise_std = 0.2 + + fg = FactorGraph(gbp_settings) + + init_noises = np.random.normal(np.zeros([size*size, 2]), prior_noise_std) + meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) + + for i in range(size): + for j in range(size): + init = np.array([j, i]) + noise_init = init_noises[j + i * size] + init = init + noise_init + sigma = prior_sigma + if i == 0 and j == 0: + init = np.array([j, i]) + sigma = np.array([0.0001, 0.0001]) + print(init, sigma) + fg.add_var_node(2, init, sigma) + + m = 0 + for i in range(size): + for j in range(size): + if j < size - 1: + meas = np.array([1., 0.]) + meas += meas_noises[m] + fg.add_factor( + [i * size + j, i * size + j + 1], + meas, + LinearDisplacementModel(SquaredLoss(dim, noise_cov)) + ) + m += 1 + if i < size - 1: + meas = np.array([0., 1.]) + meas += meas_noises[m] + fg.add_factor( + [i * size + j, (i + 1) * size + j], + meas, + LinearDisplacementModel(SquaredLoss(dim, noise_cov)) + ) + m += 1 + + fg.print(brief=True) + + # # for vis --------------- + + joint = fg.get_joint() + marg_covs = np.diag(joint.cov())[::2] + map_soln = fg.MAP().reshape([size * size, 2]) + + # # run gbp --------------- + + gbp_settings = GBPSettings( + damping=0., + beta=0.1, + num_undamped_iters=1, + min_linear_iters=10, + dropout=0.0, + ) + + # fg.compute_all_messages() + + import ipdb; ipdb.set_trace() + + # i = 0 + n_iters = 5 + while i <= n_iters: + # img = draw(i) + # cv2.imshow('img', img) + # cv2.waitKey(1) + + print(f"Iter {i} --- Energy {fg.energy():.5f}") + + # fg.random_message() + fg.synchronous_iteration() + i += 1 + + for f in fg.factors: + for m in f.messages: + print(np.linalg.inv(m.lam) @ m.eta) + + print(fg.belief_means()) + + import ipdb; ipdb.set_trace() + + + # time.sleep(0.05) diff --git a/theseus/optimizer/gbp/jax_torch_test.py b/theseus/optimizer/gbp/jax_torch_test.py new file mode 100644 index 000000000..bf6aecca0 --- /dev/null +++ b/theseus/optimizer/gbp/jax_torch_test.py @@ -0,0 +1,292 @@ +import numpy as np +import torch +import jax +import jax.numpy as jnp + +import time + + +def pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + adj_var_dofs_nested, +): + ftov_msgs_eta = [None] * len(vtof_msgs_eta) + ftov_msgs_lam = [None] * len(vtof_msgs_eta) + + start = 0 + for i in range(len(adj_var_dofs_nested)): + adj_var_dofs = adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + + inp_msgs_eta = vtof_msgs_eta[start: start + num_optim_vars] + inp_msgs_lam = vtof_msgs_lam[start: start + num_optim_vars] + + num_optim_vars = len(adj_var_dofs) + ftov_eta, ftov_lam = [], [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potentials_eta[i].clone()[0] + lam_factor = potentials_lam[i].clone()[0] + + # Take product of factor with incoming messages + start_in = 0 + for var in range(num_optim_vars): + var_dofs = adj_var_dofs[var] + if var != v: + eta_mess = vtof_msgs_eta[var] + lam_mess = vtof_msgs_lam[var] + eta_factor[start_in:start_in + var_dofs] += eta_mess + lam_factor[start_in:start_in + var_dofs, start_in:start_in + var_dofs] += lam_mess + start_in += var_dofs + + # Divide up parameters of distribution + dofs = adj_var_dofs[v] + eo = eta_factor[sdim:sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + + loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + lono = np.concatenate( + (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + 
dofs:]), + axis=1) + lnoo = np.concatenate( + (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), + axis=0) + lnono = np.concatenate(( + np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), + np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) + ), axis=0) + + new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno + + ftov_eta.append(new_message_eta[None, :]) + ftov_lam.append(new_message_lam[None, :]) + + sdim += dofs + + ftov_msgs_eta[start: start + num_optim_vars] = ftov_eta + ftov_msgs_lam[start: start + num_optim_vars] = ftov_lam + + start += num_optim_vars + + return ftov_msgs_eta, ftov_msgs_lam + + +@jax.jit +def pass_fac_to_var_messages_jax( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + adj_var_dofs_nested, +): + ftov_msgs_eta = [None] * len(vtof_msgs_eta) + ftov_msgs_lam = [None] * len(vtof_msgs_eta) + + start = 0 + for i in range(len(adj_var_dofs_nested)): + adj_var_dofs = adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + + inp_msgs_eta = vtof_msgs_eta[start: start + num_optim_vars] + inp_msgs_lam = vtof_msgs_lam[start: start + num_optim_vars] + + num_optim_vars = len(adj_var_dofs) + ftov_eta, ftov_lam = [], [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potentials_eta[i][0] + lam_factor = potentials_lam[i][0] + + # Take product of factor with incoming messages + start_in = 0 + for var in range(num_optim_vars): + var_dofs = adj_var_dofs[var] + if var != v: + eta_mess = vtof_msgs_eta[var] + lam_mess = vtof_msgs_lam[var] + eta_factor = eta_factor.at[start_in:start_in + var_dofs].add(eta_mess) + lam_factor = lam_factor.at[start_in:start_in + var_dofs, start_in:start_in + var_dofs].add(lam_mess) + start_in += var_dofs + + # Divide up parameters of distribution + dofs = adj_var_dofs[v] + eo = eta_factor[sdim:sdim + dofs] + eno = jnp.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + + loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + lono = jnp.concatenate( + (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), + axis=1) + lnoo = jnp.concatenate( + (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), + axis=0) + lnono = jnp.concatenate(( + jnp.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), + jnp.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) + ), axis=0) + + new_message_lam = loo - lono @ jnp.linalg.inv(lnono) @ lnoo + new_message_eta = eo - lono @ jnp.linalg.inv(lnono) @ eno + + ftov_eta.append(new_message_eta[None, :]) + ftov_lam.append(new_message_lam[None, :]) + + sdim += dofs + + ftov_msgs_eta[start: start + num_optim_vars] = ftov_eta + ftov_msgs_lam[start: start + num_optim_vars] = ftov_lam + + start += num_optim_vars + + return ftov_msgs_eta, ftov_msgs_lam + + + + + +if __name__ == "__main__": + + adj_var_dofs_nested = [[2], [2], [2], [2], [2], [2], [2], [2], [2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]] + + potentials_eta = [torch.tensor([[0., 0.]]), torch.tensor([[ 0.5292, -0.1270]]), torch.tensor([[ 1.2858, -0.2724]]), torch.tensor([[0.2065, 0.5016]]), torch.tensor([[0.6295, 0.5622]]), torch.tensor([[1.3565, 0.3479]]), torch.tensor([[-0.0382, 1.1380]]), torch.tensor([[0.7259, 1.0533]]), torch.tensor([[1.1630, 1.0795]]), 
torch.tensor([[-100.4221, -5.8282, 100.4221, 5.8282]]), torch.tensor([[ 11.0062, -111.4472, -11.0062, 111.4472]]), torch.tensor([[-109.0159, -5.0249, 109.0159, 5.0249]]), torch.tensor([[ -9.0086, -93.1627, 9.0086, 93.1627]]), torch.tensor([[ 1.2289, -90.6423, -1.2289, 90.6423]]), torch.tensor([[-97.3211, -5.3036, 97.3211, 5.3036]]), torch.tensor([[ 6.9166, -96.0325, -6.9166, 96.0325]]), torch.tensor([[-93.1283, 8.4521, 93.1283, -8.4521]]), torch.tensor([[ 6.7125, -99.8733, -6.7125, 99.8733]]), torch.tensor([[ 11.1731, -102.3442, -11.1731, 102.3442]]), torch.tensor([[-116.5980, -7.4204, 116.5980, 7.4204]]), torch.tensor([[-98.0816, 8.8763, 98.0816, -8.8763]])] + potentials_lam = [torch.tensor([[[10000., 0.], + [ 0., 10000.]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], + [0.0000, 0.5917]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], + [ 0., 100., 0., -100.], + [-100., 0., 100., 0.], + [ 0., -100., 0., 100.]]])] + + vtof_msgs_eta = [torch.tensor([[ 0.8536, -1.5929]]), torch.tensor([[182.3461, 16.7745]]), torch.tensor([[222.8854, 13.1250]]), torch.tensor([[-10.1678, 202.9393]]), torch.tensor([[200.4927, 213.6843]]), torch.tensor([[264.5976, 132.6887]]), torch.tensor([[-17.9007, 222.3988]]), torch.tensor([[127.5813, 277.0478]]), torch.tensor([[191.0187, 201.1600]]), torch.tensor([[ 5.6620, -4.6983]]), torch.tensor([[83.3856, 10.9277]]), torch.tensor([[-4.8085, 3.1053]]), torch.tensor([[ 0.9854, 93.0631]]), torch.tensor([[153.1307, 16.3761]]), torch.tensor([[98.1263, 3.3349]]), torch.tensor([[129.7635, 5.8644]]), torch.tensor([[140.2319, 158.5661]]), torch.tensor([[127.3308, 9.2454]]), torch.tensor([[187.8824, 92.8337]]), torch.tensor([[-16.5414, 145.2973]]), torch.tensor([[152.8149, 148.6686]]), torch.tensor([[ -4.1601, 169.0230]]), torch.tensor([[-12.0344, 99.1287]]), torch.tensor([[153.7062, 168.3496]]), torch.tensor([[149.0974, 
72.7772]]), torch.tensor([[157.2429, 167.7175]]), torch.tensor([[ 70.8858, 152.1307]]), torch.tensor([[196.2848, 100.8102]]), torch.tensor([[ 99.5512, 100.5530]]), torch.tensor([[ -5.9426, 125.5461]]), torch.tensor([[ 87.5787, 197.8408]]), torch.tensor([[ 98.8758, 207.2840]]), torch.tensor([[ 93.7936, 102.7661]])] + vtof_msgs_lam = [torch.tensor([[95.7949, 0.0000], + [ 0.0000, 95.7949]]), torch.tensor([[190.3769, 0.0000], + [ 0.0000, 190.3769]]), torch.tensor([[109.9605, 0.0000], + [ 0.0000, 109.9605]]), torch.tensor([[190.3769, 0.0000], + [ 0.0000, 190.3769]]), torch.tensor([[197.8604, 0.0000], + [ 0.0000, 197.8604]]), torch.tensor([[132.5915, 0.0000], + [ 0.0000, 132.5915]]), torch.tensor([[109.9605, 0.0000], + [ 0.0000, 109.9605]]), torch.tensor([[132.5915, 0.0000], + [ 0.0000, 132.5915]]), torch.tensor([[99.8496, 0.0000], + [ 0.0000, 99.8496]]), torch.tensor([[10047.8975, 0.0000], + [ 0.0000, 10047.8975]]), torch.tensor([[91.9540, 0.0000], + [ 0.0000, 91.9540]]), torch.tensor([[10047.8975, 0.0000], + [ 0.0000, 10047.8975]]), torch.tensor([[91.9540, 0.0000], + [ 0.0000, 91.9540]]), torch.tensor([[158.0642, 0.0000], + [ 0.0000, 158.0642]]), torch.tensor([[49.3043, 0.0000], + [ 0.0000, 49.3043]]), torch.tensor([[132.5106, 0.0000], + [ 0.0000, 132.5106]]), torch.tensor([[141.4631, 0.0000], + [ 0.0000, 141.4631]]), torch.tensor([[61.8396, 0.0000], + [ 0.0000, 61.8396]]), torch.tensor([[94.9975, 0.0000], + [ 0.0000, 94.9975]]), torch.tensor([[132.5106, 0.0000], + [ 0.0000, 132.5106]]), torch.tensor([[141.4631, 0.0000], + [ 0.0000, 141.4631]]), torch.tensor([[158.0642, 0.0000], + [ 0.0000, 158.0642]]), torch.tensor([[49.3043, 0.0000], + [ 0.0000, 49.3043]]), torch.tensor([[156.5110, 0.0000], + [ 0.0000, 156.5110]]), torch.tensor([[72.2502, 0.0000], + [ 0.0000, 72.2502]]), torch.tensor([[156.5110, 0.0000], + [ 0.0000, 156.5110]]), torch.tensor([[72.2502, 0.0000], + [ 0.0000, 72.2502]]), torch.tensor([[99.7104, 0.0000], + [ 0.0000, 99.7104]]), torch.tensor([[50.5165, 0.0000], + [ 0.0000, 50.5165]]), torch.tensor([[61.8396, 0.0000], + [ 0.0000, 61.8396]]), torch.tensor([[94.9975, 0.0000], + [ 0.0000, 94.9975]]), torch.tensor([[99.7104, 0.0000], + [ 0.0000, 99.7104]]), torch.tensor([[50.5165, 0.0000], + [ 0.0000, 50.5165]])] + vtof_msgs_eta = torch.cat(vtof_msgs_eta) + # vtof_msgs_lam = torch.cat([m[None, ...] 
for m in vtof_msgs_lam]) + + t1 = time.time() + times = [] + for i in range(100): + t_start = time.time() + ftov_msgs_eta, ftov_msgs_lam = pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + adj_var_dofs_nested, + ) + times.append(time.time() - t_start) + + t2 = time.time() + print("-------------- TORCH --------------") + print("elapsed", t2 - t1) + print("min max mean", np.min(times), np.max(times), np.mean(times)) + + # print(ftov_msgs_eta) + # print(ftov_msgs_lam) + + + potentials_eta_jax = [jnp.array(pe) for pe in potentials_eta] + potentials_lam_jax = [jnp.array(pe) for pe in potentials_lam] + vtof_msgs_eta_jax = jnp.array(vtof_msgs_eta) + vtof_msgs_lam_jax = [jnp.array(pe) for pe in vtof_msgs_lam] + + t1 = time.time() + times = [] + for i in range(10): + t_start = time.time() + ftov_msgs_eta_jax, ftov_msgs_lam_jax = pass_fac_to_var_messages_jax( + potentials_eta_jax, + potentials_lam_jax, + vtof_msgs_eta_jax, + vtof_msgs_lam_jax, + adj_var_dofs_nested, + ) + times.append(time.time() - t_start) + + t2 = time.time() + print("\n\n") + print("-------------- JAX --------------") + print("elapsed", t2 - t1) + print("min max mean", np.min(times), np.max(times), np.mean(times)) + + # print(ftov_msgs_eta_jax) + # print(ftov_msgs_lam_jax) diff --git a/theseus/optimizer/gbp/pose_graph_gbp.py b/theseus/optimizer/gbp/pose_graph_gbp.py new file mode 100644 index 000000000..af1d2a75d --- /dev/null +++ b/theseus/optimizer/gbp/pose_graph_gbp.py @@ -0,0 +1,1092 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# This example illustrates the Gaussian Belief Propagation (GBP) optimizer +# for a 2D pose graph optimization problem. +# Linear problem where we are estimating the (x, y)position of 9 nodes, +# arranged in a 3x3 grid. +# Linear factors connect each node to its adjacent nodes. 
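+#
+# Rough usage sketch (the objective is built in the __main__ block below;
+# optimize() is expected to be the entry point inherited from the theseus
+# Optimizer base class):
+#
+#   optimizer = GaussianBeliefPropagation(objective, max_iterations=20)
+#   optimizer.optimize(verbose=True, damping=0.1, dropout=0.0)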
+ +import abc +from typing import Any, Dict, Optional, Type +from enum import Enum +from dataclasses import dataclass +import math + +import torch +import theseus as th + +import numpy as np +import numdifftools as nd + +from collections import defaultdict +import time + +from typing import List, Callable, Optional, Union, Sequence + +import matplotlib.pylab as plt +import cv2 + +from theseus.core import Objective, CostFunction +from theseus.optimizer import Linearization, Optimizer, OptimizerInfo, VariableOrdering +from theseus.optimizer.linear import LinearSolver +import theseus.constants + + +""" +TODO + - Parallelise factor to variable message comp + - Benchmark speed + - test jax implementation of message comp functions + - add class for message schedule + - damping for lie algebra vars + - solving inverse problem to compute message mean +""" + + +@dataclass +class GBPOptimizerParams: + abs_err_tolerance: float + rel_err_tolerance: float + max_iterations: int + + def update(self, params_dict): + for param, value in params_dict.items(): + if hasattr(self, param): + setattr(self, param, value) + else: + raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") + + +class NonlinearOptimizerStatus(Enum): + START = 0 + CONVERGED = 1 + MAX_ITERATIONS = 2 + FAIL = -1 + + +# All info information is batched +@dataclass +class NonlinearOptimizerInfo(OptimizerInfo): + converged_iter: torch.Tensor + best_iter: torch.Tensor + err_history: Optional[torch.Tensor] + last_err: torch.Tensor + best_err: torch.Tensor + + +class BackwardMode(Enum): + FULL = 0 + IMPLICIT = 1 + TRUNCATED = 2 + + +class Gaussian: + def __init__(self, mean: th.Variable): + self.name = mean.name + "_gaussian" + self.mean = mean + self.lam = torch.zeros( + mean.shape[0], mean.dof(), mean.dof(), dtype=mean.dtype) + + +class CostFunctionOrdering: + def __init__(self, objective: Objective, default_order: bool = True): + self.objective = objective + self._cf_order: List[CostFunction] = [] + self._cf_name_to_index: Dict[str, int] = {} + if default_order: + self._compute_default_order(objective) + + def _compute_default_order(self, objective: Objective): + assert not self._cf_order and not self._cf_name_to_index + cur_idx = 0 + for cf_name, cf in objective.cost_functions.items(): + if cf_name in self._cf_name_to_index: + continue + self._cf_order.append(cf) + self._cf_name_to_index[cf_name] = cur_idx + cur_idx += 1 + + def index_of(self, key: str) -> int: + return self._cf_name_to_index[key] + + def __getitem__(self, index) -> CostFunction: + return self._cf_order[index] + + def __iter__(self): + return iter(self._cf_order) + + def append(self, cf: CostFunction): + if cf in self._cf_order: + raise ValueError( + f"Cost Function {cf.name} has already been added to the order." + ) + if cf.name not in self.objective.cost_functions: + raise ValueError( + f"Cost Function {cf.name} is not a cost function for the objective." + ) + self._cf_order.append(cf) + self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 + + def remove(self, cf: CostFunction): + self._cf_order.remove(cf) + del self._cf_name_to_index[cf.name] + + def extend(self, cfs: Sequence[CostFunction]): + for cf in cfs: + self.append(cf) + + @property + def complete(self): + return len(self._cf_order) == self.objective.size_variables() + + +# Compute the factor at current adjacent beliefs. 
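+# With J the stacked weighted Jacobian and r the weighted error, both evaluated
+# at the current linearisation point x0, the potential in information form is
+#   lam = J^T J
+#   eta = -J^T r              (lie=True: coordinates in the tangent space at x0)
+#   eta = -J^T r + lam @ x0   (lie=False: global / Euclidean coordinates)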
+def compute_factor(cf, lie=True): + J, error = cf.weighted_jacobians_error() + J_stk = torch.cat(J, dim=-1) + + lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) + + optim_vars_stk = torch.cat([v.data for v in cf.optim_vars], dim=-1) + eta = - torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) + if lie is False: + eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) + eta = eta.squeeze(-1) + + return eta, lam + + +def pass_var_to_fac_messages( + ftov_msgs_eta, + ftov_msgs_lam, + var_ix_for_edges, + n_vars, + max_dofs, +): + belief_eta = torch.zeros( + n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) + belief_lam = torch.zeros( + n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) + + belief_eta = belief_eta.index_add( + 0, var_ix_for_edges, ftov_msgs_eta) + belief_lam = belief_lam.index_add( + 0, var_ix_for_edges, ftov_msgs_lam) + + vtof_msgs_eta = belief_eta[var_ix_for_edges] - ftov_msgs_eta + vtof_msgs_lam = belief_lam[var_ix_for_edges] - ftov_msgs_lam + + return vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam + + +def pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + adj_var_dofs_nested: List[List], +): + ftov_msgs_eta = torch.zeros_like(vtof_msgs_eta) + ftov_msgs_lam = torch.zeros_like(vtof_msgs_lam) + + start = 0 + for i in range(len(adj_var_dofs_nested)): + adj_var_dofs = adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + ftov_eta, ftov_lam = ftov_comp_mess( + adj_var_dofs, + potentials_eta[i], + potentials_lam[i], + vtof_msgs_eta[start: start + num_optim_vars], + vtof_msgs_lam[start: start + num_optim_vars], + ) + + ftov_msgs_eta[start: start + num_optim_vars] = torch.cat(ftov_eta) + ftov_msgs_lam[start: start + num_optim_vars] = torch.cat(ftov_lam) + + start += num_optim_vars + + return ftov_msgs_eta, ftov_msgs_lam + + +def euclidean_jac(var, tau, jacobians=None): + jacobians.extend([torch.eye(2)[None, ...]]) + return tau + + +# Transforms message to tangent plane at var +# if return_mean is True, return the (mean, lam) else return (eta, lam). +# Generalises the local function by transforming the covariance as well as mean. +def local_gaussian( + gauss: Gaussian, + var: th.Manifold, + return_mean: bool = True, +) -> [torch.Tensor, torch.Tensor]: + # mean_tp is message mean in tangent space / plane at var + mean_tp = var.local(gauss.mean) + + jac = [] + # th.exp_map(var, mean_tp, jacobians=jac) + euclidean_jac(var, mean_tp, jacobians=jac) + jac = jac[0] + + # lam_tp is lambda matrix in the tangent plane + lam_tp = torch.bmm(torch.bmm(jac.transpose(-1, -2), gauss.lam), jac) + + if return_mean: + return mean_tp, lam_tp + + else: + eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) + return eta_tp, lam_tp + + +# Transforms Gaussian in the tangent plane at var to Gaussian where the mean +# is a group element and the precision matrix is defined in the tangent plane +# at the mean. +# Generalises the retract function by transforming the covariance as well as mean. +# out_gauss is the transformed Gaussian that is updated in place. 
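+# Concretely: mean = var.retract(mean_tp) and lam = J^T @ lam_tp @ J, where J
+# maps tangent-space coordinates at the mean (identity here via euclidean_jac,
+# since this example only uses th.Vector variables).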
+def retract_gaussian( + mean_tp: torch.Tensor, + lam_tp: torch.Tensor, + var: th.Manifold, + out_gauss: Gaussian, +) -> [th.Manifold, torch.Tensor]: + mean = var.retract(mean_tp) + + jac = [] + # th.exp_map(var, tau_a, jacobians=jac) + euclidean_jac(var, mean, jacobians=jac) + jac = jac[0] + lam = torch.bmm(torch.bmm(jac.transpose(-1, -2), lam_tp), jac) + + out_gauss.mean.update(mean.data) + out_gauss.lam = lam + + +def pass_var_to_fac_messages_and_update_beliefs_lie( + ftov_msgs, + vtof_msgs, + var_ordering, + var_ix_for_edges, +): + belief_covs = [] + + for i, var in enumerate(var_ordering): + + # Collect all incoming messages in the tangent space at the current belief + taus = [] # message means + lams_tp = [] # message lams + for j, msg in enumerate(ftov_msgs): + if var_ix_for_edges[j] == i: + tau, lam_tp = local_gaussian(msg, var, return_mean=True) + taus.append(tau[None, ...]) + lams_tp.append(lam_tp[None, ...]) + + taus = torch.cat(taus) + lams_tp = torch.cat(lams_tp) + + lam_tau = lams_tp.sum(dim=0) + + # Compute outgoing messages + ix = 0 + for j, msg in enumerate(ftov_msgs): + if var_ix_for_edges[j] == i: + taus_inc = torch.cat((taus[:ix], taus[ix + 1:])) + lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1:])) + + lam_a = lams_inc.sum(dim=0) + if lam_a.count_nonzero() == 0: + vtof_msgs[j].mean.data[:] = 0.0 + vtof_msgs[j].lam = lam_a + else: + inv_lam_a = torch.inverse(lam_a) + sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum(dim=0) + tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) + retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) + ix += 1 + + # update belief mean and variance + # if no incoming messages then leave current belief unchanged + if lam_tau.count_nonzero() != 0: + inv_lam_tau = torch.inverse(lam_tau) + sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) + tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) + + belief = Gaussian(var) + retract_gaussian(tau, lam_tau, var, belief) + + +def pass_fac_to_var_messages_lie( + potentials_eta, + potentials_lam, + lin_points, + vtof_msgs, + ftov_msgs, + adj_var_dofs_nested: List[List], + damping: torch.Tensor, +): + start = 0 + for i in range(len(adj_var_dofs_nested)): + adj_var_dofs = adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + new_messages = ftov_comp_mess_lie( + potentials_eta[i], + potentials_lam[i], + lin_points[i], + vtof_msgs[start: start + num_optim_vars], + ftov_msgs[start: start + num_optim_vars], + damping[start: start + num_optim_vars], + ) + + start += num_optim_vars + + +# Compute all outgoing messages from the factor. 
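+# Using the variable names in the code below: add the incoming messages from
+# all variables except the receiver to the factor potential, partition
+# (eta, lam) into the receiver block ("o") and the remaining blocks ("no"),
+# then marginalise the latter out:
+#   new_message_lam = loo - lono @ inv(lnono) @ lnoo
+#   new_message_eta = eo  - lono @ inv(lnono) @ eno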
+def ftov_comp_mess( + adj_var_dofs, + potential_eta, + potential_lam, + vtof_msgs_eta, + vtof_msgs_lam, +): + num_optim_vars = len(adj_var_dofs) + messages_eta, messages_lam = [], [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potential_eta.clone()[0] + lam_factor = potential_lam.clone()[0] + + # Take product of factor with incoming messages + start = 0 + for var in range(num_optim_vars): + var_dofs = adj_var_dofs[var] + if var != v: + eta_mess = vtof_msgs_eta[var] + lam_mess = vtof_msgs_lam[var] + eta_factor[start:start + var_dofs] += eta_mess + lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess + start += var_dofs + + # Divide up parameters of distribution + dofs = adj_var_dofs[v] + eo = eta_factor[sdim:sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + + loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + lono = np.concatenate( + (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), + axis=1) + lnoo = np.concatenate( + (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), + axis=0) + lnono = np.concatenate(( + np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), + np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) + ), axis=0) + + new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno + + messages_eta.append(new_message_eta[None, :]) + messages_lam.append(new_message_lam[None, :]) + + sdim += dofs + + return messages_eta, messages_lam + + +# Compute all outgoing messages from the factor. +def ftov_comp_mess_lie( + potential_eta, + potential_lam, + lin_points, + vtof_msgs, + ftov_msgs, + damping, +): + num_optim_vars = len(lin_points) + new_messages = [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potential_eta.clone()[0] + lam_factor = potential_lam.clone()[0] + + # Take product of factor with incoming messages. + # Convert mesages to tangent space at linearisation point. 
+ start = 0 + for i in range(num_optim_vars): + var_dofs = lin_points[i].dof() + if i != v: + eta_mess, lam_mess = local_gaussian( + vtof_msgs[i], lin_points[i], return_mean=False) + eta_factor[start:start + var_dofs] += eta_mess[0] + lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess[0] + start += var_dofs + + # Divide up parameters of distribution + dofs = lin_points[v].dof() + eo = eta_factor[sdim:sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + + loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + lono = np.concatenate( + (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), + axis=1) + lnoo = np.concatenate( + (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), + axis=0) + lnono = np.concatenate(( + np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), + np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) + ), axis=0) + + new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno + + # damping in tangent space at linearisation point + # prev_mess_eta, prev_mess_lam = local_gaussian( + # vtof_msgs[v], lin_points[v], return_mean=False) + # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] + # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] + + if new_mess_lam.count_nonzero() == 0: + new_mess = Gaussian(lin_points[v].copy()) + new_mess.mean.data[:] = 0.0 + else: + new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) + new_mess_mean = new_mess_mean[None, ...] + new_mess_lam = new_mess_lam[None, ...] + + new_mess = Gaussian(lin_points[v].copy()) + retract_gaussian(new_mess_mean, new_mess_lam, lin_points[v], new_mess) + new_messages.append(new_mess) + + sdim += dofs + + # update messages + for v in range(num_optim_vars): + ftov_msgs[v].mean.update(new_messages[v].mean.data) + ftov_msgs[v].lam = new_messages[v].lam + + return new_messages + +# Follows notation from https://arxiv.org/pdf/2202.03314.pdf +class GaussianBeliefPropagation(Optimizer, abc.ABC): + def __init__( + self, + objective: Objective, + *args, + linearization_cls: Optional[Type[Linearization]] = None, + linearization_kwargs: Optional[Dict[str, Any]] = None, + abs_err_tolerance: float = 1e-10, + rel_err_tolerance: float = 1e-8, + max_iterations: int = 20, + ): + super().__init__(objective) + self.ordering = VariableOrdering(objective, default_order=True) + self.cf_ordering = CostFunctionOrdering(objective) + + self.schedule = None + self.damping = None + + self.params = GBPOptimizerParams( + abs_err_tolerance, rel_err_tolerance, max_iterations + ) + + self.n_edges = sum([cf.num_optim_vars() for cf in self.cf_ordering]) + self.max_dofs = max([var.dof() for var in self.ordering]) + + # create arrays for indexing the messages + var_ixs = [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] for cf in + self.cf_ordering + ] + var_ixs = [item for sublist in var_ixs for item in sublist] + self.var_ix_for_edges = torch.tensor(var_ixs).long() + + self.adj_var_dofs_nested = [ + [var.shape[1] for var in cf.optim_vars] for cf in self.cf_ordering + ] + + lie_groups = False + for v in self.ordering: + if isinstance(v, th.LieGroup) and not isinstance(v, th.Vector): + lie_groups = True + self.lie_groups = lie_groups + print("lie groups:", self.lie_groups) + + def set_params(self, **kwargs): + 
self.params.update(kwargs) + + def _check_convergence(self, err: torch.Tensor, last_err: torch.Tensor): + assert not torch.is_grad_enabled() + if err.abs().mean() < theseus.constants.EPS: + return torch.ones_like(err).bool() + + abs_error = (last_err - err).abs() + rel_error = abs_error / last_err + return (abs_error < self.params.abs_err_tolerance).logical_or( + rel_error < self.params.rel_err_tolerance + ) + + def _maybe_init_best_solution( + self, do_init: bool = False + ) -> Optional[Dict[str, torch.Tensor]]: + if not do_init: + return None + solution_dict = {} + for var in self.ordering: + solution_dict[var.name] = var.data.detach().clone().cpu() + return solution_dict + + def _init_info( + self, track_best_solution: bool, track_err_history: bool, verbose: bool + ) -> NonlinearOptimizerInfo: + with torch.no_grad(): + last_err = self.objective.error_squared_norm() / 2 + best_err = last_err.clone() if track_best_solution else None + if track_err_history: + err_history = ( + torch.ones(self.objective.batch_size, self.params.max_iterations + 1) + * math.inf + ) + assert last_err.grad_fn is None + err_history[:, 0] = last_err.clone().cpu() + else: + err_history = None + return NonlinearOptimizerInfo( + best_solution=self._maybe_init_best_solution(do_init=track_best_solution), + last_err=last_err, + best_err=best_err, + status=np.array( + [NonlinearOptimizerStatus.START] * self.objective.batch_size + ), + converged_iter=torch.zeros_like(last_err, dtype=torch.long), + best_iter=torch.zeros_like(last_err, dtype=torch.long), + err_history=err_history, + ) + + def _update_info( + self, + info: NonlinearOptimizerInfo, + current_iter: int, + err: torch.Tensor, + converged_indices: torch.Tensor, + ): + info.converged_iter += 1 - converged_indices.long() + if info.err_history is not None: + assert err.grad_fn is None + info.err_history[:, current_iter + 1] = err.clone().cpu() + + if info.best_solution is not None: + # Only copy best solution if needed (None means track_best_solution=False) + assert info.best_err is not None + good_indices = err < info.best_err + info.best_iter[good_indices] = current_iter + for var in self.ordering: + info.best_solution[var.name][good_indices] = ( + var.data.detach().clone()[good_indices].cpu() + ) + + info.best_err = torch.minimum(info.best_err, err) + + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + np.array(converged_indices.detach().cpu()) + ] = NonlinearOptimizerStatus.CONVERGED + + # Modifies the (no grad) info in place to add data of grad loop info + def _merge_infos( + self, + grad_loop_info: NonlinearOptimizerInfo, + num_no_grad_iter: int, + backward_num_iterations: int, + info: NonlinearOptimizerInfo, + ): + # Concatenate error histories + if info.err_history is not None: + info.err_history[:, num_no_grad_iter:] = grad_loop_info.err_history[ + :, : backward_num_iterations + 1 + ] + # Merge best solution and best error + if info.best_solution is not None: + best_solution = {} + best_err_no_grad = info.best_err + best_err_grad = grad_loop_info.best_err + idx_no_grad = best_err_no_grad < best_err_grad + best_err = torch.minimum(best_err_no_grad, best_err_grad) + for var_name in info.best_solution: + sol_no_grad = info.best_solution[var_name] + sol_grad = grad_loop_info.best_solution[var_name] + best_solution[var_name] = torch.where( + idx_no_grad, sol_no_grad, sol_grad + ) + info.best_solution = best_solution + info.best_err = best_err + + # Merge the converged status into the info from the detached loop, + M = info.status 
== NonlinearOptimizerStatus.MAX_ITERATIONS + assert np.all( + (grad_loop_info.status[M] == NonlinearOptimizerStatus.MAX_ITERATIONS) + | (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED) + ) + info.status[M] = grad_loop_info.status[M] + info.converged_iter[M] = ( + info.converged_iter[M] + grad_loop_info.converged_iter[M] + ) + # If didn't coverge in either loop, remove misleading converged_iter value + info.converged_iter[ + M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS) + ] = -1 + + # Linearizes factors at current belief if beliefs have deviated + # from the linearization point by more than the threshold. + def _linearize( + self, + potentials_eta, + potentials_lam, + lin_points, + lp_dist_thresh: float = None, + lie=False, + ): + do_lins = [] + for i, cf in enumerate(self.cf_ordering): + + do_lin = False + if lp_dist_thresh is None: + do_lin = True + else: + lp_dists = [ + lp.local(cf.optim_var_at(j)).norm() + for j, lp in enumerate(lin_points[i]) + ] + do_lin = np.max(lp_dists) > lp_dist_thresh + + do_lins.append(do_lin) + + if do_lin: + potential_eta, potential_lam = compute_factor(cf, lie=lie) + + potentials_eta[i] = potential_eta + potentials_lam[i] = potential_lam + + for j, var in enumerate(cf.optim_vars): + lin_points[i][j].update(var.data) + + # print(f"Linearised {np.sum(do_lins)} out of {len(do_lins)} factors.") + return potentials_eta, potentials_lam, lin_points + + # loop for the iterative optimizer + def _optimize_loop( + self, + start_iter: int, + num_iter: int, + info: OptimizerInfo, + verbose: bool, + truncated_grad_loop: bool, + relin_threshold: float = 0.1, + damping: float = 0.0, + dropout: float = 0.0, + lp_dist_thresh: float = 0.1, + **kwargs, + ): + # initialise messages with zeros + vtof_msgs_eta = torch.zeros( + self.n_edges, self.max_dofs, dtype=self.objective.dtype) + vtof_msgs_lam = torch.zeros( + self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype) + ftov_msgs_eta = vtof_msgs_eta.clone() + ftov_msgs_lam = vtof_msgs_lam.clone() + + # compute factor potentials for the first time + potentials_eta = [None] * self.objective.size_cost_functions() + potentials_lam = [None] * self.objective.size_cost_functions() + lin_points = [ + [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] + for cf in self.cf_ordering + ] + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None) + + converged_indices = torch.zeros_like(info.last_err).bool() + for it_ in range(start_iter, start_iter + num_iter): + + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None) + + msgs_eta, msgs_lam = pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + self.adj_var_dofs_nested, + ) + + # damping + # damping = self.gbp_settings.get_damping(iters_since_relin) + if isinstance(damping, float): + damping = torch.full([len(msgs_eta)], damping) + + # dropout can be implemented through damping + if dropout != 0.0: + dropout_ixs = torch.rand(len(msgs_eta)) < dropout + damping[dropout_ixs] = 1.0 + + ftov_msgs_eta = (1 - damping[:, None]) * msgs_eta + damping[:, None] * ftov_msgs_eta + ftov_msgs_lam = (1 - damping[:, None, None]) * msgs_lam + damping[:, None, None] * ftov_msgs_lam + + ( + vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam + ) = pass_var_to_fac_messages( + ftov_msgs_eta, + ftov_msgs_lam, + self.var_ix_for_edges, + 
len(self.ordering._var_order), + self.max_dofs, + ) + + # update beliefs + belief_cov = torch.inverse(belief_lam) + belief_mean = torch.matmul(belief_cov, belief_eta.unsqueeze(-1)).squeeze() + for i, var in enumerate(self.ordering): + var.update(data=belief_mean[i][None, :]) + + # check for convergence + with torch.no_grad(): + err = self.objective.error_squared_norm() / 2 + self._update_info(info, it_, err, converged_indices) + if verbose: + print( + f"GBP. Iteration: {it_+1}. " + f"Error: {err.mean().item()}" + ) + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + converged_indices.cpu().numpy() + ] = NonlinearOptimizerStatus.CONVERGED + if converged_indices.all(): + break # nothing else will happen at this point + info.last_err = err + + info.status[ + info.status == NonlinearOptimizerStatus.START + ] = NonlinearOptimizerStatus.MAX_ITERATIONS + return info + + # loop for the iterative optimizer + def _optimize_loop_lie( + self, + start_iter: int, + num_iter: int, + info: OptimizerInfo, + verbose: bool, + truncated_grad_loop: bool, + relin_threshold: float = 0.1, + damping: float = 0.0, + dropout: float = 0.0, + lp_dist_thresh: float = 0.1, + **kwargs, + ): + # initialise messages with zeros + vtof_msgs = [] + ftov_msgs = [] + for cf in self.cf_ordering: + for var in cf.optim_vars: + vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") + vtof_msg_mu.data[:] = 0 + ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") + vtof_msgs.append(Gaussian(vtof_msg_mu)) + ftov_msgs.append(Gaussian(ftov_msg_mu)) + + # compute factor potentials for the first time + potentials_eta = [None] * self.objective.size_cost_functions() + potentials_lam = [None] * self.objective.size_cost_functions() + lin_points = [ + [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] + for cf in self.cf_ordering + ] + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, + lp_dist_thresh=None, lie=True) + + converged_indices = torch.zeros_like(info.last_err).bool() + for it_ in range(start_iter, start_iter + num_iter): + + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, + lp_dist_thresh=None, lie=True) + + # damping + # damping = self.gbp_settings.get_damping(iters_since_relin) + if isinstance(damping, float): + damping = torch.full([self.n_edges], damping) + + # dropout can be implemented through damping + if dropout != 0.0: + dropout_ixs = torch.rand(self.n_edges) < dropout + damping[dropout_ixs] = 1.0 + + pass_fac_to_var_messages_lie( + potentials_eta, + potentials_lam, + lin_points, + vtof_msgs, + ftov_msgs, + self.adj_var_dofs_nested, + damping, + ) + + belief_covs = pass_var_to_fac_messages_and_update_beliefs_lie( + ftov_msgs, + vtof_msgs, + self.ordering, + self.var_ix_for_edges, + ) + + # check for convergence + if it_ > 0: + with torch.no_grad(): + err = self.objective.error_squared_norm() / 2 + self._update_info(info, it_, err, converged_indices) + if verbose: + print( + f"GBP. Iteration: {it_+1}. 
" + f"Error: {err.mean().item()}" + ) + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + converged_indices.cpu().numpy() + ] = NonlinearOptimizerStatus.CONVERGED + if converged_indices.all() and it_ > 1: + break # nothing else will happen at this point + info.last_err = err + + info.status[ + info.status == NonlinearOptimizerStatus.START + ] = NonlinearOptimizerStatus.MAX_ITERATIONS + return info + + # `track_best_solution` keeps a **detached** copy (as in no gradient info) + # of the best variables found, but it is optional to avoid unnecessary copying + # if this is not needed + def _optimize_impl( + self, + track_best_solution: bool = False, + track_err_history: bool = False, + verbose: bool = False, + backward_mode: BackwardMode = BackwardMode.FULL, + damping: float = 0.0, + dropout: float = 0.0, + **kwargs, + ) -> OptimizerInfo: + if damping > 1.0 or damping < 0.0: + raise NotImplementedError( + "damping must be in between 0 and 1." + ) + if dropout > 1.0 or dropout < 0.0: + raise NotImplementedError( + "dropout probability must be in between 0 and 1." + ) + + with torch.no_grad(): + info = self._init_info(track_best_solution, track_err_history, verbose) + + if verbose: + print( + f"GBP optimizer. Iteration: 0. " + f"Error: {info.last_err.mean().item()}" + ) + + grad = False + if backward_mode == BackwardMode.FULL: + grad = True + + with torch.set_grad_enabled(grad): + + + # if self.lie_groups: + info = self._optimize_loop_lie( + start_iter=0, + num_iter=self.params.max_iterations, + info=info, + verbose=verbose, + truncated_grad_loop=False, + damping=damping, + dropout=dropout, + **kwargs, + ) + # else: + # info = self._optimize_loop( + # start_iter=0, + # num_iter=self.params.max_iterations, + # info=info, + # verbose=verbose, + # truncated_grad_loop=False, + # damping=damping, + # dropout=dropout, + # **kwargs, + # ) + # If didn't coverge, remove misleading converged_iter value + info.converged_iter[ + info.status == NonlinearOptimizerStatus.MAX_ITERATIONS + ] = -1 + return info + + +if __name__ == "__main__": + + np.random.seed(1) + torch.manual_seed(0) + + size = 3 + dim = 2 + + noise_cov = np.array([0.01, 0.01]) + + prior_noise_std = 0.2 + prior_sigma = np.array([1.3**2, 1.3**2]) + + init_noises = np.random.normal(np.zeros([size*size, 2]), prior_noise_std) + meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) + + # create theseus objective ------------------------------------- + + objective = th.Objective() + inputs = {} + + n_poses = size * size + + # create variables + poses = [] + for i in range(n_poses): + poses.append(th.Vector(data=torch.rand(1, 2), name=f"x{i}")) + + # add prior cost constraints with VariableDifference cost + prior_std = 1.3 + anchor_std = 0.01 + prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") + anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") + + p = 0 + for i in range(size): + for j in range(size): + init = torch.Tensor([j, i]) + if i == 0 and j == 0: + w = anchor_w + else: + # noise_init = torch.normal(torch.zeros(2), prior_noise_std) + init = init + torch.FloatTensor(init_noises[p]) + w = prior_w + + prior_target = th.Vector(data=init, name=f"prior_{p}") + inputs[f"x{p}"] = init[None, :] + inputs[f"prior_{p}"] = init[None, :] + + cf = th.eb.VariableDifference( + poses[p], w, prior_target, name=f"prior_cost_{p}") + + objective.add(cf) + + p += 1 + + # Measurement cost functions + + meas_std = 0.1 + meas_w = th.ScaleCostWeight(1 / meas_std, name="prior_weight") + + 
m = 0 + for i in range(size): + for j in range(size): + if j < size - 1: + measurement = torch.Tensor([1., 0.]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = i * size + j + 1 + + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] + + cf = th.eb.Between( + poses[ix0], poses[ix1], + meas_w, meas, name=f"meas_cost_{m}") + objective.add(cf) + m += 1 + + if i < size - 1: + measurement = torch.Tensor([0., 1.]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = (i + 1) * size + j + + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] + + cf = th.eb.Between( + poses[ix0], poses[ix1], + meas_w, meas, name=f"meas_cost_{m}") + objective.add(cf) + m += 1 + + # # objective.update(init_dict) + # print("Initial cost:", objective.error_squared_norm()) + + # fg.print(brief=True) + + # # for vis --------------- + + # joint = fg.get_joint() + # marg_covs = np.diag(joint.cov())[::2] + # map_soln = fg.MAP().reshape([size * size, 2]) + + # Solve with Gauss Newton --------------- + + # print("inputs", inputs) + + optimizer = GaussianBeliefPropagation( + objective, + max_iterations=100, + ) + theseus_optim = th.TheseusLayer(optimizer) + + optim_arg = { + "track_best_solution": True, + "track_err_history": True, + "verbose": True, + "backward_mode": BackwardMode.FULL, + "damping": 0.6, + "dropout": 0.0, + } + updated_inputs, info = theseus_optim.forward(inputs, optim_arg) + + print("updated_inputs", updated_inputs) + print("info", info) + + import ipdb; ipdb.set_trace() + + + # optimizer = th.GaussNewton( + # objective, + # max_iterations=15, + # step_size=0.5, + # ) + # theseus_optim = th.TheseusLayer(optimizer) + + # with torch.no_grad(): + # optim_args = {"track_best_solution": True, "verbose": True} + # updated_inputs, info = theseus_optim.forward(inputs, optim_args) + # print("updated_inputs", updated_inputs) + # print("info", info) + + # import ipdb; ipdb.set_trace() + From 31a8874fed86e99add3eca7d3a793e70b4340656 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 16 Mar 2022 18:53:54 +0000 Subject: [PATCH 02/64] gbp uses exp_map jacobians --- theseus/optimizer/gbp/pose_graph_gbp.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/theseus/optimizer/gbp/pose_graph_gbp.py b/theseus/optimizer/gbp/pose_graph_gbp.py index af1d2a75d..0e05d6ee8 100644 --- a/theseus/optimizer/gbp/pose_graph_gbp.py +++ b/theseus/optimizer/gbp/pose_graph_gbp.py @@ -232,7 +232,8 @@ def local_gaussian( jac = [] # th.exp_map(var, mean_tp, jacobians=jac) - euclidean_jac(var, mean_tp, jacobians=jac) + var.__class__.exp_map(mean_tp, jac) + # euclidean_jac(var, mean_tp, jacobians=jac) jac = jac[0] # lam_tp is lambda matrix in the tangent plane @@ -260,10 +261,12 @@ def retract_gaussian( mean = var.retract(mean_tp) jac = [] - # th.exp_map(var, tau_a, jacobians=jac) - euclidean_jac(var, mean, jacobians=jac) + # th.exp_map(var, mean_tp, jacobians=jac) + var.__class__.exp_map(mean_tp, jac) + # euclidean_jac(var, mean_tp, jacobians=jac) jac = jac[0] - lam = torch.bmm(torch.bmm(jac.transpose(-1, -2), lam_tp), jac) + inv_jac = torch.inverse(jac) + lam = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), lam_tp), inv_jac) out_gauss.mean.update(mean.data) out_gauss.lam = lam From 1a6eae469f3f47c019d6aefc90e2a11beef89cae Mon Sep 17 00:00:00 2001 From: 
joeaortiz Date: Wed, 16 Mar 2022 18:59:48 +0000 Subject: [PATCH 03/64] gaussian for Manifold rather than Variable class --- theseus/optimizer/gbp/pose_graph_gbp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/theseus/optimizer/gbp/pose_graph_gbp.py b/theseus/optimizer/gbp/pose_graph_gbp.py index 0e05d6ee8..0ddf4f30a 100644 --- a/theseus/optimizer/gbp/pose_graph_gbp.py +++ b/theseus/optimizer/gbp/pose_graph_gbp.py @@ -85,7 +85,7 @@ class BackwardMode(Enum): class Gaussian: - def __init__(self, mean: th.Variable): + def __init__(self, mean: th.Manifold): self.name = mean.name + "_gaussian" self.mean = mean self.lam = torch.zeros( From e93aa4950979c1a121ae106191ccc45217d0efb5 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 4 Apr 2022 12:00:22 +0100 Subject: [PATCH 04/64] uses proper theseus exp_map jacobian --- theseus/optimizer/gbp/pose_graph_gbp.py | 326 +++++++++++++----------- 1 file changed, 172 insertions(+), 154 deletions(-) diff --git a/theseus/optimizer/gbp/pose_graph_gbp.py b/theseus/optimizer/gbp/pose_graph_gbp.py index 0ddf4f30a..dd48f2b9e 100644 --- a/theseus/optimizer/gbp/pose_graph_gbp.py +++ b/theseus/optimizer/gbp/pose_graph_gbp.py @@ -11,30 +11,18 @@ # Linear factors connect each node to its adjacent nodes. import abc -from typing import Any, Dict, Optional, Type -from enum import Enum -from dataclasses import dataclass import math - -import torch -import theseus as th +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type import numpy as np -import numdifftools as nd - -from collections import defaultdict -import time - -from typing import List, Callable, Optional, Union, Sequence - -import matplotlib.pylab as plt -import cv2 +import torch -from theseus.core import Objective, CostFunction -from theseus.optimizer import Linearization, Optimizer, OptimizerInfo, VariableOrdering -from theseus.optimizer.linear import LinearSolver +import theseus as th import theseus.constants - +from theseus.core import CostFunction, Objective +from theseus.optimizer import Linearization, Optimizer, OptimizerInfo, VariableOrdering """ TODO @@ -43,7 +31,7 @@ - test jax implementation of message comp functions - add class for message schedule - damping for lie algebra vars - - solving inverse problem to compute message mean + - solving inverse problem to compute message mean """ @@ -88,8 +76,7 @@ class Gaussian: def __init__(self, mean: th.Manifold): self.name = mean.name + "_gaussian" self.mean = mean - self.lam = torch.zeros( - mean.shape[0], mean.dof(), mean.dof(), dtype=mean.dtype) + self.lam = torch.zeros(mean.shape[0], mean.dof(), mean.dof(), dtype=mean.dtype) class CostFunctionOrdering: @@ -152,9 +139,9 @@ def compute_factor(cf, lie=True): lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) optim_vars_stk = torch.cat([v.data for v in cf.optim_vars], dim=-1) - eta = - torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) + eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) if lie is False: - eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) + eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) return eta, lam @@ -167,15 +154,11 @@ def pass_var_to_fac_messages( n_vars, max_dofs, ): - belief_eta = torch.zeros( - n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) - belief_lam = torch.zeros( - n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) + belief_eta = torch.zeros(n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) + belief_lam = 
torch.zeros(n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) - belief_eta = belief_eta.index_add( - 0, var_ix_for_edges, ftov_msgs_eta) - belief_lam = belief_lam.index_add( - 0, var_ix_for_edges, ftov_msgs_lam) + belief_eta = belief_eta.index_add(0, var_ix_for_edges, ftov_msgs_eta) + belief_lam = belief_lam.index_add(0, var_ix_for_edges, ftov_msgs_lam) vtof_msgs_eta = belief_eta[var_ix_for_edges] - ftov_msgs_eta vtof_msgs_lam = belief_lam[var_ix_for_edges] - ftov_msgs_lam @@ -202,43 +185,35 @@ def pass_fac_to_var_messages( adj_var_dofs, potentials_eta[i], potentials_lam[i], - vtof_msgs_eta[start: start + num_optim_vars], - vtof_msgs_lam[start: start + num_optim_vars], + vtof_msgs_eta[start : start + num_optim_vars], + vtof_msgs_lam[start : start + num_optim_vars], ) - ftov_msgs_eta[start: start + num_optim_vars] = torch.cat(ftov_eta) - ftov_msgs_lam[start: start + num_optim_vars] = torch.cat(ftov_lam) + ftov_msgs_eta[start : start + num_optim_vars] = torch.cat(ftov_eta) + ftov_msgs_lam[start : start + num_optim_vars] = torch.cat(ftov_lam) start += num_optim_vars return ftov_msgs_eta, ftov_msgs_lam -def euclidean_jac(var, tau, jacobians=None): - jacobians.extend([torch.eye(2)[None, ...]]) - return tau - - # Transforms message to tangent plane at var # if return_mean is True, return the (mean, lam) else return (eta, lam). # Generalises the local function by transforming the covariance as well as mean. def local_gaussian( gauss: Gaussian, - var: th.Manifold, + var: th.LieGroup, return_mean: bool = True, -) -> [torch.Tensor, torch.Tensor]: +) -> Tuple[torch.Tensor, torch.Tensor]: # mean_tp is message mean in tangent space / plane at var mean_tp = var.local(gauss.mean) - jac = [] - # th.exp_map(var, mean_tp, jacobians=jac) - var.__class__.exp_map(mean_tp, jac) - # euclidean_jac(var, mean_tp, jacobians=jac) - jac = jac[0] + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) # lam_tp is lambda matrix in the tangent plane - lam_tp = torch.bmm(torch.bmm(jac.transpose(-1, -2), gauss.lam), jac) - + lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gauss.lam), jac[0]) + if return_mean: return mean_tp, lam_tp @@ -255,17 +230,14 @@ def local_gaussian( def retract_gaussian( mean_tp: torch.Tensor, lam_tp: torch.Tensor, - var: th.Manifold, + var: th.LieGroup, out_gauss: Gaussian, -) -> [th.Manifold, torch.Tensor]: +): mean = var.retract(mean_tp) - jac = [] - # th.exp_map(var, mean_tp, jacobians=jac) - var.__class__.exp_map(mean_tp, jac) - # euclidean_jac(var, mean_tp, jacobians=jac) - jac = jac[0] - inv_jac = torch.inverse(jac) + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) + inv_jac = torch.inverse(jac[0]) lam = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), lam_tp), inv_jac) out_gauss.mean.update(mean.data) @@ -278,8 +250,6 @@ def pass_var_to_fac_messages_and_update_beliefs_lie( var_ordering, var_ix_for_edges, ): - belief_covs = [] - for i, var in enumerate(var_ordering): # Collect all incoming messages in the tangent space at the current belief @@ -300,8 +270,8 @@ def pass_var_to_fac_messages_and_update_beliefs_lie( ix = 0 for j, msg in enumerate(ftov_msgs): if var_ix_for_edges[j] == i: - taus_inc = torch.cat((taus[:ix], taus[ix + 1:])) - lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1:])) + taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) + lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) lam_a = lams_inc.sum(dim=0) if lam_a.count_nonzero() == 0: @@ -339,13 +309,13 @@ def pass_fac_to_var_messages_lie( adj_var_dofs = 
adj_var_dofs_nested[i] num_optim_vars = len(adj_var_dofs) - new_messages = ftov_comp_mess_lie( + ftov_comp_mess_lie( potentials_eta[i], potentials_lam[i], lin_points[i], - vtof_msgs[start: start + num_optim_vars], - ftov_msgs[start: start + num_optim_vars], - damping[start: start + num_optim_vars], + vtof_msgs[start : start + num_optim_vars], + ftov_msgs[start : start + num_optim_vars], + damping[start : start + num_optim_vars], ) start += num_optim_vars @@ -374,26 +344,47 @@ def ftov_comp_mess( if var != v: eta_mess = vtof_msgs_eta[var] lam_mess = vtof_msgs_lam[var] - eta_factor[start:start + var_dofs] += eta_mess - lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess + eta_factor[start : start + var_dofs] += eta_mess + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess start += var_dofs # Divide up parameters of distribution dofs = adj_var_dofs[v] - eo = eta_factor[sdim:sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] lono = np.concatenate( - (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), - axis=1) + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) lnoo = np.concatenate( - (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), - axis=0) - lnono = np.concatenate(( - np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), - np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) - ), axis=0) + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno @@ -430,27 +421,49 @@ def ftov_comp_mess_lie( var_dofs = lin_points[i].dof() if i != v: eta_mess, lam_mess = local_gaussian( - vtof_msgs[i], lin_points[i], return_mean=False) - eta_factor[start:start + var_dofs] += eta_mess[0] - lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess[0] + vtof_msgs[i], lin_points[i], return_mean=False + ) + eta_factor[start : start + var_dofs] += eta_mess[0] + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess[0] start += var_dofs # Divide up parameters of distribution dofs = lin_points[v].dof() - eo = eta_factor[sdim:sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] lono = np.concatenate( - (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), - axis=1) + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) lnoo = np.concatenate( - (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + 
dofs]), - axis=0) - lnono = np.concatenate(( - np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), - np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) - ), axis=0) + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno @@ -482,6 +495,7 @@ def ftov_comp_mess_lie( return new_messages + # Follows notation from https://arxiv.org/pdf/2202.03314.pdf class GaussianBeliefPropagation(Optimizer, abc.ABC): def __init__( @@ -509,11 +523,11 @@ def __init__( self.max_dofs = max([var.dof() for var in self.ordering]) # create arrays for indexing the messages - var_ixs = [ - [self.ordering.index_of(var.name) for var in cf.optim_vars] for cf in - self.cf_ordering + var_ixs_nested = [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] + for cf in self.cf_ordering ] - var_ixs = [item for sublist in var_ixs for item in sublist] + var_ixs = [item for sublist in var_ixs_nested for item in sublist] self.var_ix_for_edges = torch.tensor(var_ixs).long() self.adj_var_dofs_nested = [ @@ -678,7 +692,7 @@ def _linearize( if do_lin: potential_eta, potential_lam = compute_factor(cf, lie=lie) - + potentials_eta[i] = potential_eta potentials_lam[i] = potential_lam @@ -693,7 +707,7 @@ def _optimize_loop( self, start_iter: int, num_iter: int, - info: OptimizerInfo, + info: NonlinearOptimizerInfo, verbose: bool, truncated_grad_loop: bool, relin_threshold: float = 0.1, @@ -704,9 +718,11 @@ def _optimize_loop( ): # initialise messages with zeros vtof_msgs_eta = torch.zeros( - self.n_edges, self.max_dofs, dtype=self.objective.dtype) + self.n_edges, self.max_dofs, dtype=self.objective.dtype + ) vtof_msgs_lam = torch.zeros( - self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype) + self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype + ) ftov_msgs_eta = vtof_msgs_eta.clone() ftov_msgs_lam = vtof_msgs_lam.clone() @@ -715,16 +731,18 @@ def _optimize_loop( potentials_lam = [None] * self.objective.size_cost_functions() lin_points = [ [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering + for cf in self.cf_ordering ] potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None) + potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None + ) converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None) + potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None + ) msgs_eta, msgs_lam = pass_fac_to_var_messages( potentials_eta, @@ -736,19 +754,25 @@ def _optimize_loop( # damping # damping = self.gbp_settings.get_damping(iters_since_relin) - if isinstance(damping, float): - damping = torch.full([len(msgs_eta)], damping) + damping_arr = torch.full([len(msgs_eta)], damping) # dropout can be implemented through damping if dropout != 0.0: dropout_ixs = torch.rand(len(msgs_eta)) < dropout - 
damping[dropout_ixs] = 1.0 + damping_arr[dropout_ixs] = 1.0 - ftov_msgs_eta = (1 - damping[:, None]) * msgs_eta + damping[:, None] * ftov_msgs_eta - ftov_msgs_lam = (1 - damping[:, None, None]) * msgs_lam + damping[:, None, None] * ftov_msgs_lam + ftov_msgs_eta = (1 - damping_arr[:, None]) * msgs_eta + damping_arr[ + :, None + ] * ftov_msgs_eta + ftov_msgs_lam = (1 - damping_arr[:, None, None]) * msgs_lam + damping_arr[ + :, None, None + ] * ftov_msgs_lam ( - vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam + vtof_msgs_eta, + vtof_msgs_lam, + belief_eta, + belief_lam, ) = pass_var_to_fac_messages( ftov_msgs_eta, ftov_msgs_lam, @@ -768,10 +792,7 @@ def _optimize_loop( err = self.objective.error_squared_norm() / 2 self._update_info(info, it_, err, converged_indices) if verbose: - print( - f"GBP. Iteration: {it_+1}. " - f"Error: {err.mean().item()}" - ) + print(f"GBP. Iteration: {it_+1}. " f"Error: {err.mean().item()}") converged_indices = self._check_convergence(err, info.last_err) info.status[ converged_indices.cpu().numpy() @@ -790,7 +811,7 @@ def _optimize_loop_lie( self, start_iter: int, num_iter: int, - info: OptimizerInfo, + info: NonlinearOptimizerInfo, verbose: bool, truncated_grad_loop: bool, relin_threshold: float = 0.1, @@ -815,28 +836,31 @@ def _optimize_loop_lie( potentials_lam = [None] * self.objective.size_cost_functions() lin_points = [ [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering + for cf in self.cf_ordering ] potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, - lp_dist_thresh=None, lie=True) + potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None, lie=True + ) converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, - lp_dist_thresh=None, lie=True) + potentials_eta, + potentials_lam, + lin_points, + lp_dist_thresh=None, + lie=True, + ) # damping # damping = self.gbp_settings.get_damping(iters_since_relin) - if isinstance(damping, float): - damping = torch.full([self.n_edges], damping) + damping_arr = torch.full([self.n_edges], damping) # dropout can be implemented through damping if dropout != 0.0: dropout_ixs = torch.rand(self.n_edges) < dropout - damping[dropout_ixs] = 1.0 + damping_arr[dropout_ixs] = 1.0 pass_fac_to_var_messages_lie( potentials_eta, @@ -845,10 +869,10 @@ def _optimize_loop_lie( vtof_msgs, ftov_msgs, self.adj_var_dofs_nested, - damping, + damping_arr, ) - belief_covs = pass_var_to_fac_messages_and_update_beliefs_lie( + pass_var_to_fac_messages_and_update_beliefs_lie( ftov_msgs, vtof_msgs, self.ordering, @@ -862,8 +886,7 @@ def _optimize_loop_lie( self._update_info(info, it_, err, converged_indices) if verbose: print( - f"GBP. Iteration: {it_+1}. " - f"Error: {err.mean().item()}" + f"GBP. Iteration: {it_+1}. " f"Error: {err.mean().item()}" ) converged_indices = self._check_convergence(err, info.last_err) info.status[ @@ -890,23 +913,18 @@ def _optimize_impl( damping: float = 0.0, dropout: float = 0.0, **kwargs, - ) -> OptimizerInfo: + ) -> NonlinearOptimizerInfo: if damping > 1.0 or damping < 0.0: - raise NotImplementedError( - "damping must be in between 0 and 1." - ) + raise NotImplementedError("damping must be in between 0 and 1.") if dropout > 1.0 or dropout < 0.0: - raise NotImplementedError( - "dropout probability must be in between 0 and 1." 
- ) + raise NotImplementedError("dropout probability must be in between 0 and 1.") with torch.no_grad(): info = self._init_info(track_best_solution, track_err_history, verbose) if verbose: print( - f"GBP optimizer. Iteration: 0. " - f"Error: {info.last_err.mean().item()}" + f"GBP optimizer. Iteration: 0. " f"Error: {info.last_err.mean().item()}" ) grad = False @@ -915,7 +933,6 @@ def _optimize_impl( with torch.set_grad_enabled(grad): - # if self.lie_groups: info = self._optimize_loop_lie( start_iter=0, @@ -937,7 +954,7 @@ def _optimize_impl( # damping=damping, # dropout=dropout, # **kwargs, - # ) + # ) # If didn't coverge, remove misleading converged_iter value info.converged_iter[ info.status == NonlinearOptimizerStatus.MAX_ITERATIONS @@ -958,7 +975,7 @@ def _optimize_impl( prior_noise_std = 0.2 prior_sigma = np.array([1.3**2, 1.3**2]) - init_noises = np.random.normal(np.zeros([size*size, 2]), prior_noise_std) + init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) # create theseus objective ------------------------------------- @@ -994,10 +1011,11 @@ def _optimize_impl( inputs[f"x{p}"] = init[None, :] inputs[f"prior_{p}"] = init[None, :] - cf = th.eb.VariableDifference( - poses[p], w, prior_target, name=f"prior_cost_{p}") + cf_prior = th.eb.VariableDifference( + poses[p], w, prior_target, name=f"prior_cost_{p}" + ) - objective.add(cf) + objective.add(cf_prior) p += 1 @@ -1010,7 +1028,7 @@ def _optimize_impl( for i in range(size): for j in range(size): if j < size - 1: - measurement = torch.Tensor([1., 0.]) + measurement = torch.Tensor([1.0, 0.0]) # measurement += torch.normal(torch.zeros(2), meas_std) measurement += torch.FloatTensor(meas_noises[m]) ix0 = i * size + j @@ -1019,14 +1037,14 @@ def _optimize_impl( meas = th.Vector(data=measurement, name=f"meas_{m}") inputs[f"meas_{m}"] = measurement[None, :] - cf = th.eb.Between( - poses[ix0], poses[ix1], - meas_w, meas, name=f"meas_cost_{m}") - objective.add(cf) + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + objective.add(cf_meas) m += 1 if i < size - 1: - measurement = torch.Tensor([0., 1.]) + measurement = torch.Tensor([0.0, 1.0]) # measurement += torch.normal(torch.zeros(2), meas_std) measurement += torch.FloatTensor(meas_noises[m]) ix0 = i * size + j @@ -1035,10 +1053,10 @@ def _optimize_impl( meas = th.Vector(data=measurement, name=f"meas_{m}") inputs[f"meas_{m}"] = measurement[None, :] - cf = th.eb.Between( - poses[ix0], poses[ix1], - meas_w, meas, name=f"meas_cost_{m}") - objective.add(cf) + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + objective.add(cf_meas) m += 1 # # objective.update(init_dict) @@ -1075,8 +1093,9 @@ def _optimize_impl( print("updated_inputs", updated_inputs) print("info", info) - import ipdb; ipdb.set_trace() + import ipdb + ipdb.set_trace() # optimizer = th.GaussNewton( # objective, @@ -1092,4 +1111,3 @@ def _optimize_impl( # print("info", info) # import ipdb; ipdb.set_trace() - From aa7062ed5011ae7314b98b0ddca26d269282e0aa Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Tue, 5 Apr 2022 17:26:54 +0100 Subject: [PATCH 05/64] gaussian class plus marginals and message class --- theseus/optimizer/gbp/__init__.py | 6 + .../gbp/{pose_graph_gbp.py => gbp.py} | 263 +++------- theseus/optimizer/gbp/gbp_baseline.py | 259 ++++++---- theseus/optimizer/gbp/jax_torch_test.py | 464 +++++++++++++----- theseus/optimizer/gbp/pgo_test.py | 161 
++++++ 5 files changed, 725 insertions(+), 428 deletions(-) create mode 100644 theseus/optimizer/gbp/__init__.py rename theseus/optimizer/gbp/{pose_graph_gbp.py => gbp.py} (83%) create mode 100644 theseus/optimizer/gbp/pgo_test.py diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py new file mode 100644 index 000000000..b27e0742d --- /dev/null +++ b/theseus/optimizer/gbp/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .gbp import GaussianBeliefPropagation diff --git a/theseus/optimizer/gbp/pose_graph_gbp.py b/theseus/optimizer/gbp/gbp.py similarity index 83% rename from theseus/optimizer/gbp/pose_graph_gbp.py rename to theseus/optimizer/gbp/gbp.py index dd48f2b9e..ef32d5b15 100644 --- a/theseus/optimizer/gbp/pose_graph_gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -1,19 +1,11 @@ -#!/usr/bin/env python3 # Copyright (c) Meta Platforms, Inc. and affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -# -# This example illustrates the Gaussian Belief Propagation (GBP) optimizer -# for a 2D pose graph optimization problem. -# Linear problem where we are estimating the (x, y)position of 9 nodes, -# arranged in a 3x3 grid. -# Linear factors connect each node to its adjacent nodes. import abc import math from dataclasses import dataclass -from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Tuple, Type import numpy as np @@ -22,7 +14,13 @@ import theseus as th import theseus.constants from theseus.core import CostFunction, Objective -from theseus.optimizer import Linearization, Optimizer, OptimizerInfo, VariableOrdering +from theseus.geometry import Manifold +from theseus.optimizer import Linearization, Optimizer, VariableOrdering +from theseus.optimizer.nonlinear.nonlinear_optimizer import ( + BackwardMode, + NonlinearOptimizerInfo, + NonlinearOptimizerStatus, +) """ TODO @@ -35,6 +33,11 @@ """ +""" +Utitily functions +""" + + @dataclass class GBPOptimizerParams: abs_err_tolerance: float @@ -49,34 +52,32 @@ def update(self, params_dict): raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") -class NonlinearOptimizerStatus(Enum): - START = 0 - CONVERGED = 1 - MAX_ITERATIONS = 2 - FAIL = -1 - +# Stores variable beliefs that converge towards the marginals +class Gaussian: + def __init__(self, variable: Manifold): + self.name = variable.name + "_gaussian" + self.mean = variable + self.tot_dof = variable.dof() + + # tot_dof = 0 + # for v in variables: + # tot_dof += v.dof() + # self.tot_dof = tot_dof + + self.precision = torch.zeros( + self.mean.shape[0], self.tot_dof, self.tot_dof, dtype=variable.dtype + ) -# All info information is batched -@dataclass -class NonlinearOptimizerInfo(OptimizerInfo): - converged_iter: torch.Tensor - best_iter: torch.Tensor - err_history: Optional[torch.Tensor] - last_err: torch.Tensor - best_err: torch.Tensor + def dof(self) -> int: + return self.tot_dof -class BackwardMode(Enum): - FULL = 0 - IMPLICIT = 1 - TRUNCATED = 2 +class Marginals(Gaussian): + pass -class Gaussian: - def __init__(self, mean: th.Manifold): - self.name = mean.name + "_gaussian" - self.mean = mean - self.lam = torch.zeros(mean.shape[0], mean.dof(), mean.dof(), dtype=mean.dtype) +class Message(Gaussian): + pass class CostFunctionOrdering: @@ -131,6 +132,11 @@ def complete(self): 
return len(self._cf_order) == self.objective.size_variables() +""" +GBP functions +""" + + # Compute the factor at current adjacent beliefs. def compute_factor(cf, lie=True): J, error = cf.weighted_jacobians_error() @@ -197,22 +203,22 @@ def pass_fac_to_var_messages( return ftov_msgs_eta, ftov_msgs_lam -# Transforms message to tangent plane at var +# Transforms message gaussian to tangent plane at var # if return_mean is True, return the (mean, lam) else return (eta, lam). # Generalises the local function by transforming the covariance as well as mean. def local_gaussian( - gauss: Gaussian, + mess: Message, var: th.LieGroup, return_mean: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: # mean_tp is message mean in tangent space / plane at var - mean_tp = var.local(gauss.mean) + mean_tp = var.local(mess.mean) jac: List[torch.Tensor] = [] th.exp_map(var, mean_tp, jacobians=jac) - # lam_tp is lambda matrix in the tangent plane - lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), gauss.lam), jac[0]) + # lam_tp is the precision matrix in the tangent plane + lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) if return_mean: return mean_tp, lam_tp @@ -229,7 +235,7 @@ def local_gaussian( # out_gauss is the transformed Gaussian that is updated in place. def retract_gaussian( mean_tp: torch.Tensor, - lam_tp: torch.Tensor, + prec_tp: torch.Tensor, var: th.LieGroup, out_gauss: Gaussian, ): @@ -238,10 +244,10 @@ def retract_gaussian( jac: List[torch.Tensor] = [] th.exp_map(var, mean_tp, jacobians=jac) inv_jac = torch.inverse(jac[0]) - lam = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), lam_tp), inv_jac) + precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), prec_tp), inv_jac) out_gauss.mean.update(mean.data) - out_gauss.lam = lam + out_gauss.precision = precision def pass_var_to_fac_messages_and_update_beliefs_lie( @@ -276,7 +282,7 @@ def pass_var_to_fac_messages_and_update_beliefs_lie( lam_a = lams_inc.sum(dim=0) if lam_a.count_nonzero() == 0: vtof_msgs[j].mean.data[:] = 0.0 - vtof_msgs[j].lam = lam_a + vtof_msgs[j].precision = lam_a else: inv_lam_a = torch.inverse(lam_a) sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum(dim=0) @@ -491,12 +497,14 @@ def ftov_comp_mess_lie( # update messages for v in range(num_optim_vars): ftov_msgs[v].mean.update(new_messages[v].mean.data) - ftov_msgs[v].lam = new_messages[v].lam + ftov_msgs[v].precision = new_messages[v].precision return new_messages # Follows notation from https://arxiv.org/pdf/2202.03314.pdf + + class GaussianBeliefPropagation(Optimizer, abc.ABC): def __init__( self, @@ -509,11 +517,12 @@ def __init__( max_iterations: int = 20, ): super().__init__(objective) + + # ordering is required to identify which messages to send where self.ordering = VariableOrdering(objective, default_order=True) self.cf_ordering = CostFunctionOrdering(objective) self.schedule = None - self.damping = None self.params = GBPOptimizerParams( abs_err_tolerance, rel_err_tolerance, max_iterations @@ -541,6 +550,10 @@ def __init__( self.lie_groups = lie_groups print("lie groups:", self.lie_groups) + """ + Copied and slightly modified from nonlinear optimizer class + """ + def set_params(self, **kwargs): self.params.update(kwargs) @@ -665,6 +678,10 @@ def _merge_infos( M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS) ] = -1 + """ + GBP specific functions + """ + # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. 
def _linearize( @@ -826,10 +843,11 @@ def _optimize_loop_lie( for cf in self.cf_ordering: for var in cf.optim_vars: vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") + # mean of initial message doesn't matter as long as precision is zero vtof_msg_mu.data[:] = 0 ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") - vtof_msgs.append(Gaussian(vtof_msg_mu)) - ftov_msgs.append(Gaussian(ftov_msg_mu)) + vtof_msgs.append(Message(vtof_msg_mu)) + ftov_msgs.append(Message(ftov_msg_mu)) # compute factor potentials for the first time potentials_eta = [None] * self.objective.size_cost_functions() @@ -915,9 +933,9 @@ def _optimize_impl( **kwargs, ) -> NonlinearOptimizerInfo: if damping > 1.0 or damping < 0.0: - raise NotImplementedError("damping must be in between 0 and 1.") + raise NotImplementedError("Damping must be in between 0 and 1.") if dropout > 1.0 or dropout < 0.0: - raise NotImplementedError("dropout probability must be in between 0 and 1.") + raise NotImplementedError("Dropout probability must be in between 0 and 1.") with torch.no_grad(): info = self._init_info(track_best_solution, track_err_history, verbose) @@ -960,154 +978,3 @@ def _optimize_impl( info.status == NonlinearOptimizerStatus.MAX_ITERATIONS ] = -1 return info - - -if __name__ == "__main__": - - np.random.seed(1) - torch.manual_seed(0) - - size = 3 - dim = 2 - - noise_cov = np.array([0.01, 0.01]) - - prior_noise_std = 0.2 - prior_sigma = np.array([1.3**2, 1.3**2]) - - init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) - meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) - - # create theseus objective ------------------------------------- - - objective = th.Objective() - inputs = {} - - n_poses = size * size - - # create variables - poses = [] - for i in range(n_poses): - poses.append(th.Vector(data=torch.rand(1, 2), name=f"x{i}")) - - # add prior cost constraints with VariableDifference cost - prior_std = 1.3 - anchor_std = 0.01 - prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") - anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") - - p = 0 - for i in range(size): - for j in range(size): - init = torch.Tensor([j, i]) - if i == 0 and j == 0: - w = anchor_w - else: - # noise_init = torch.normal(torch.zeros(2), prior_noise_std) - init = init + torch.FloatTensor(init_noises[p]) - w = prior_w - - prior_target = th.Vector(data=init, name=f"prior_{p}") - inputs[f"x{p}"] = init[None, :] - inputs[f"prior_{p}"] = init[None, :] - - cf_prior = th.eb.VariableDifference( - poses[p], w, prior_target, name=f"prior_cost_{p}" - ) - - objective.add(cf_prior) - - p += 1 - - # Measurement cost functions - - meas_std = 0.1 - meas_w = th.ScaleCostWeight(1 / meas_std, name="prior_weight") - - m = 0 - for i in range(size): - for j in range(size): - if j < size - 1: - measurement = torch.Tensor([1.0, 0.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = i * size + j + 1 - - meas = th.Vector(data=measurement, name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] - - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 - - if i < size - 1: - measurement = torch.Tensor([0.0, 1.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = (i + 1) * size + j - - meas = th.Vector(data=measurement, 
name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] - - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 - - # # objective.update(init_dict) - # print("Initial cost:", objective.error_squared_norm()) - - # fg.print(brief=True) - - # # for vis --------------- - - # joint = fg.get_joint() - # marg_covs = np.diag(joint.cov())[::2] - # map_soln = fg.MAP().reshape([size * size, 2]) - - # Solve with Gauss Newton --------------- - - # print("inputs", inputs) - - optimizer = GaussianBeliefPropagation( - objective, - max_iterations=100, - ) - theseus_optim = th.TheseusLayer(optimizer) - - optim_arg = { - "track_best_solution": True, - "track_err_history": True, - "verbose": True, - "backward_mode": BackwardMode.FULL, - "damping": 0.6, - "dropout": 0.0, - } - updated_inputs, info = theseus_optim.forward(inputs, optim_arg) - - print("updated_inputs", updated_inputs) - print("info", info) - - import ipdb - - ipdb.set_trace() - - # optimizer = th.GaussNewton( - # objective, - # max_iterations=15, - # step_size=0.5, - # ) - # theseus_optim = th.TheseusLayer(optimizer) - - # with torch.no_grad(): - # optim_args = {"track_best_solution": True, "verbose": True} - # updated_inputs, info = theseus_optim.forward(inputs, optim_args) - # print("updated_inputs", updated_inputs) - # print("info", info) - - # import ipdb; ipdb.set_trace() diff --git a/theseus/optimizer/gbp/gbp_baseline.py b/theseus/optimizer/gbp/gbp_baseline.py index d63e05a1b..fef31b412 100644 --- a/theseus/optimizer/gbp/gbp_baseline.py +++ b/theseus/optimizer/gbp/gbp_baseline.py @@ -1,13 +1,14 @@ -import numpy as np -from typing import Any, Dict, Optional, Type -from typing import List, Callable, Optional, Union +from typing import Callable, List, Optional, Union +import matplotlib.pylab as plt +import numpy as np """ Defines squared loss functions that correspond to Gaussians. Robust losses are implemented by scaling the Gaussian covariance. """ + class Gaussian: def __init__( self, @@ -46,11 +47,11 @@ def set_with_cov_form(self, mean: np.ndarray, cov: np.ndarray) -> None: class GBPSettings: def __init__( self, - damping: float = 0., + damping: float = 0.0, beta: float = 0.1, num_undamped_iters: int = 5, min_linear_iters: int = 10, - dropout: float = 0., + dropout: float = 0.0, reset_iters_since_relin: List[int] = [], ): # Parameters for damping the eta component of the message @@ -73,20 +74,17 @@ def get_damping(self, iters_since_relin: int) -> float: if iters_since_relin > self.num_undamped_iters: return self.damping else: - return 0. 
+ return 0.0 -class SquaredLoss(): - def __init__( - self, - dofs: int, - diag_cov: Union[float, np.ndarray] - ): +class SquaredLoss: + def __init__(self, dofs: int, diag_cov: Union[float, np.ndarray]): """ dofs: dofs of the measurement cov: diagonal elements of covariance matrix """ - assert diag_cov.shape[0] == dofs + if isinstance(diag_cov, np.ndarray): + assert diag_cov.shape[0] == dofs mat = np.zeros([dofs, dofs]) mat[range(dofs), range(dofs)] = diag_cov self.cov = mat @@ -105,10 +103,7 @@ def robust(self) -> bool: class HuberLoss(SquaredLoss): def __init__( - self, - dofs: int, - diag_cov: Union[float, np.ndarray], - stds_transition: float + self, dofs: int, diag_cov: Union[float, np.ndarray], stds_transition: float ): """ stds_transition: num standard deviations from minimum at @@ -121,7 +116,9 @@ def get_effective_cov(self, residual: np.ndarray) -> None: energy = residual @ np.linalg.inv(self.cov) @ residual mahalanobis_dist = np.sqrt(energy) if mahalanobis_dist > self.stds_transition: - denom = (2 * self.stds_transition * mahalanobis_dist - self.stds_transition ** 2) + denom = ( + 2 * self.stds_transition * mahalanobis_dist - self.stds_transition**2 + ) self.effective_cov = self.cov * mahalanobis_dist**2 / denom else: self.effective_cov = self.cov.copy() @@ -176,8 +173,8 @@ def __init__( self, gbp_settings: GBPSettings = GBPSettings(), ): - self.var_nodes = [] - self.factors = [] + self.var_nodes: List[VariableNode] = [] + self.factors: List[Factor] = [] self.gbp_settings = gbp_settings def add_var_node( @@ -202,8 +199,7 @@ def add_factor( ) -> None: factorID = len(self.factors) adj_var_nodes = [self.var_nodes[i] for i in adj_var_ids] - self.factors.append( - Factor(factorID, adj_var_nodes, measurement, meas_model)) + self.factors.append(Factor(factorID, adj_var_nodes, measurement, meas_model)) for var in adj_var_nodes: var.adj_factors.append(self.factors[-1]) @@ -215,8 +211,7 @@ def compute_all_messages(self, apply_dropout: bool = True) -> None: for factor in self.factors: dropout_off = apply_dropout and np.random.rand() > self.gbp_settings.dropout if dropout_off or not apply_dropout: - damping = self.gbp_settings.get_damping( - factor.iters_since_relin) + damping = self.gbp_settings.get_damping(factor.iters_since_relin) factor.compute_messages(damping) def linearise_all_factors(self) -> None: @@ -239,8 +234,13 @@ def jit_linearisation(self) -> None: if not factor.meas_model.linear: adj_belief_means = factor.get_adj_means() factor.iters_since_relin += 1 - diff_cond = np.linalg.norm(factor.linpoint - adj_belief_means) > self.gbp_settings.beta - iters_cond = factor.iters_since_relin >= self.gbp_settings.min_linear_iters + diff_cond = ( + np.linalg.norm(factor.linpoint - adj_belief_means) + > self.gbp_settings.beta + ) + iters_cond = ( + factor.iters_since_relin >= self.gbp_settings.min_linear_iters + ) if diff_cond and iters_cond: factor.compute_factor() @@ -266,7 +266,7 @@ def gbp_solve( self, n_iters: Optional[int] = 20, converged_threshold: Optional[float] = 1e-6, - include_priors: bool = True + include_priors: bool = True, ) -> None: energy_log = [self.energy()] print(f"\nInitial Energy {energy_log[0]:.5f}") @@ -282,10 +282,7 @@ def gbp_solve( f.iters_since_relin = 1 energy_log.append(self.energy(include_priors=include_priors)) - print( - f"Iter {i+1} --- " - f"Energy {energy_log[-1]:.5f} --- " - ) + print(f"Iter {i+1} --- " f"Energy {energy_log[-1]:.5f} --- ") i += 1 if abs(energy_log[-2] - energy_log[-1]) < converged_threshold: count += 1 @@ -295,9 +292,7 @@ def gbp_solve( 
count = 0 def energy( - self, - eval_point: np.ndarray = None, - include_priors: bool = True + self, eval_point: np.ndarray = None, include_priors: bool = True ) -> float: """ Computes the sum of all of the squared errors in the graph @@ -308,9 +303,14 @@ def energy( else: var_dofs = np.ndarray([v.dofs for v in self.var_nodes]) var_ix = np.concatenate([np.ndarray([0]), np.cumsum(var_dofs, axis=0)[:-1]]) - energy = 0. + energy = 0.0 for f in self.factors: - local_eval_point = np.concatenate([eval_point[var_ix[v.variableID]: var_ix[v.variableID] + v.dofs] for v in f.adj_var_nodes]) + local_eval_point = np.concatenate( + [ + eval_point[var_ix[v.variableID] : var_ix[v.variableID] + v.dofs] + for v in f.adj_var_nodes + ] + ) energy += f.get_energy(local_eval_point) if include_priors: prior_energy = sum([var.get_prior_energy() for var in self.var_nodes]) @@ -333,8 +333,10 @@ def get_joint(self) -> Gaussian: counter = 0 for var in self.var_nodes: var_ix[var.variableID] = int(counter) - joint.eta[counter:counter + var.dofs] += var.prior.eta - joint.lam[counter:counter + var.dofs, counter:counter + var.dofs] += var.prior.lam + joint.eta[counter : counter + var.dofs] += var.prior.eta + joint.lam[ + counter : counter + var.dofs, counter : counter + var.dofs + ] += var.prior.lam counter += var.dofs # Other factors @@ -343,19 +345,37 @@ def get_joint(self) -> Gaussian: for adj_var_node in factor.adj_var_nodes: vID = adj_var_node.variableID # Diagonal contribution of factor - joint.eta[var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ - factor.factor.eta[factor_ix:factor_ix + adj_var_node.dofs] - joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ - factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] + joint.eta[ + var_ix[vID] : var_ix[vID] + adj_var_node.dofs + ] += factor.factor.eta[factor_ix : factor_ix + adj_var_node.dofs] + joint.lam[ + var_ix[vID] : var_ix[vID] + adj_var_node.dofs, + var_ix[vID] : var_ix[vID] + adj_var_node.dofs, + ] += factor.factor.lam[ + factor_ix : factor_ix + adj_var_node.dofs, + factor_ix : factor_ix + adj_var_node.dofs, + ] other_factor_ix = 0 for other_adj_var_node in factor.adj_var_nodes: if other_adj_var_node.variableID > adj_var_node.variableID: other_vID = other_adj_var_node.variableID # Off diagonal contributions of factor - joint.lam[var_ix[vID]:var_ix[vID] + adj_var_node.dofs, var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs] += \ - factor.factor.lam[factor_ix:factor_ix + adj_var_node.dofs, other_factor_ix:other_factor_ix + other_adj_var_node.dofs] - joint.lam[var_ix[other_vID]:var_ix[other_vID] + other_adj_var_node.dofs, var_ix[vID]:var_ix[vID] + adj_var_node.dofs] += \ - factor.factor.lam[other_factor_ix:other_factor_ix + other_adj_var_node.dofs, factor_ix:factor_ix + adj_var_node.dofs] + joint.lam[ + var_ix[vID] : var_ix[vID] + adj_var_node.dofs, + var_ix[other_vID] : var_ix[other_vID] + + other_adj_var_node.dofs, + ] += factor.factor.lam[ + factor_ix : factor_ix + adj_var_node.dofs, + other_factor_ix : other_factor_ix + other_adj_var_node.dofs, + ] + joint.lam[ + var_ix[other_vID] : var_ix[other_vID] + + other_adj_var_node.dofs, + var_ix[vID] : var_ix[vID] + adj_var_node.dofs, + ] += factor.factor.lam[ + other_factor_ix : other_factor_ix + other_adj_var_node.dofs, + factor_ix : factor_ix + adj_var_node.dofs, + ] other_factor_ix += other_adj_var_node.dofs factor_ix += adj_var_node.dofs @@ -368,11 +388,11 @@ def dist_from_MAP(self) -> np.ndarray: 
return np.linalg.norm(self.get_joint().mean() - self.belief_means()) def belief_means(self) -> np.ndarray: - """ Get an array containing all current estimates of belief means. """ + """Get an array containing all current estimates of belief means.""" return np.concatenate([var.belief.mean() for var in self.var_nodes]) def belief_covs(self) -> List[np.ndarray]: - """ Get a list of all belief covariances. """ + """Get a list of all belief covariances.""" covs = [var.belief.cov() for var in self.var_nodes] return covs @@ -381,10 +401,14 @@ def print(self, brief=False) -> None: print(f"# Variable nodes: {len(self.var_nodes)}") if not brief: for i, var in enumerate(self.var_nodes): - print(f"Variable {i}: connects to factors {[f.factorID for f in var.adj_factors]}") + print( + f"Variable {i}: connects to factors {[f.factorID for f in var.adj_factors]}" + ) print(f" dofs: {var.dofs}") print(f" prior mean: {var.prior.mean()}") - print(f" prior covariance: diagonal sigma {np.diag(var.prior.cov())}") + print( + f" prior covariance: diagonal sigma {np.diag(var.prior.cov())}" + ) print(f"# Factors: {len(self.factors)}") if not brief: for i, factor in enumerate(self.factors): @@ -406,7 +430,7 @@ class VariableNode: def __init__(self, id: int, dofs: int): self.variableID = id self.dofs = dofs - self.adj_factors = [] + self.adj_factors: List[Factor] = [] # prior factor, implemented as part of variable node self.prior = Gaussian(dofs) self.belief = Gaussian(dofs) @@ -426,8 +450,8 @@ def update_belief(self) -> None: self.belief.lam += factor.messages[message_ix].lam def get_prior_energy(self) -> float: - energy = 0. - if self.prior.lam[0, 0] != 0.: + energy = 0.0 + if self.prior.lam[0, 0] != 0.0: residual = self.belief.mean() - self.prior.mean() energy += 0.5 * residual @ self.prior.lam @ residual return energy @@ -465,13 +489,13 @@ def get_adj_means(self) -> np.ndarray: return np.concatenate(adj_belief_means) def get_residual(self, eval_point: np.ndarray = None) -> np.ndarray: - """ Compute the residual vector. """ + """Compute the residual vector.""" if eval_point is None: eval_point = self.get_adj_means() return self.meas_model.meas_fn(eval_point) - self.measurement def get_energy(self, eval_point: np.ndarray = None) -> float: - """ Computes the squared error using the appropriate loss function. """ + """Computes the squared error using the appropriate loss function.""" residual = self.get_residual(eval_point) inf_mat = np.linalg.inv(self.meas_model.loss.effective_cov) return 0.5 * residual @ inf_mat @ residual @@ -481,9 +505,9 @@ def robust(self) -> bool: def compute_factor(self) -> None: """ - Compute the factor at current adjacente beliefs using robust. - If measurement model is linear then factor will always be - the same regardless of linearisation point. + Compute the factor at current adjacente beliefs using robust. + If measurement model is linear then factor will always be + the same regardless of linearisation point. 
""" self.linpoint = self.get_adj_means() J = self.meas_model.jac_fn(self.linpoint) @@ -491,7 +515,10 @@ def compute_factor(self) -> None: self.meas_model.loss.get_effective_cov(pred_measurement - self.measurement) effective_lam = np.linalg.inv(self.meas_model.loss.effective_cov) self.factor.lam = J.T @ effective_lam @ J - self.factor.eta = ((J.T @ effective_lam) @ (J @ self.linpoint + self.measurement - pred_measurement)).flatten() + self.factor.eta = ( + (J.T @ effective_lam) + @ (J @ self.linpoint + self.measurement - pred_measurement) + ).flatten() self.iters_since_relin = 0 def robustify_loss(self) -> None: @@ -505,8 +532,8 @@ def robustify_loss(self) -> None: self.factor.eta *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] self.factor.lam *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] - def compute_messages(self, damping: float = 0.) -> None: - """ Compute all outgoing messages from the factor. """ + def compute_messages(self, damping: float = 0.0) -> None: + """Compute all outgoing messages from the factor.""" messages_eta, messages_lam = [], [] sdim = 0 @@ -519,33 +546,63 @@ def compute_messages(self, damping: float = 0.) -> None: for var in range(len(self.adj_vIDs)): if var != v: var_dofs = self.adj_var_nodes[var].dofs - eta_mess = self.adj_var_nodes[var].belief.eta - self.messages[var].eta - lam_mess = self.adj_var_nodes[var].belief.lam - self.messages[var].lam - eta_factor[start:start + var_dofs] += eta_mess - lam_factor[start:start + var_dofs, start:start + var_dofs] += lam_mess + eta_mess = ( + self.adj_var_nodes[var].belief.eta - self.messages[var].eta + ) + lam_mess = ( + self.adj_var_nodes[var].belief.lam - self.messages[var].lam + ) + eta_factor[start : start + var_dofs] += eta_mess + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess start += self.adj_var_nodes[var].dofs # Divide up parameters of distribution dofs = self.adj_var_nodes[v].dofs - eo = eta_factor[sdim:sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] lono = np.concatenate( - (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), - axis=1) + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) lnoo = np.concatenate( - (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), - axis=0) - lnono = np.concatenate(( - np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), - np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) - ), axis=0) + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), + axis=1, + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno - messages_eta.append((1 - damping) * new_message_eta + damping * self.messages[v].eta) - messages_lam.append((1 - damping) * new_message_lam + damping * self.messages[v].lam) + messages_eta.append( 
+ (1 - damping) * new_message_eta + damping * self.messages[v].eta + ) + messages_lam.append( + (1 - damping) * new_message_lam + damping * self.messages[v].lam + ) sdim += self.adj_var_nodes[v].dofs for v in range(len(self.adj_vIDs)): @@ -569,7 +626,10 @@ def draw(i): for j, cov in enumerate(fg.belief_covs()): circle = plt.Circle( (means[j, 0], means[j, 1]), - np.sqrt(cov[0, 0]), linewidth=0.5, color='blue', fill=False + np.sqrt(cov[0, 0]), + linewidth=0.5, + color="blue", + fill=False, ) ax.add_patch(circle) @@ -578,27 +638,30 @@ def draw(i): for j, cov in enumerate(marg_covs): circle = plt.Circle( (map_soln[j, 0], map_soln[j, 1]), - np.sqrt(marg_covs[j]), linewidth=0.5, color='g', fill=False + np.sqrt(marg_covs[j]), + linewidth=0.5, + color="g", + fill=False, ) ax.add_patch(circle) # draw lines for factors for f in fg.factors: bels = np.array([means[f.adj_vIDs[0]], means[f.adj_vIDs[1]]]) - plt.plot(bels[:, 0], bels[:, 1], color='black', linewidth=0.3) + plt.plot(bels[:, 0], bels[:, 1], color="black", linewidth=0.3) # draw lines for belief error for i in range(len(means)): xs = [means[i, 0], map_soln[i, 0]] ys = [means[i, 1], map_soln[i, 1]] - plt.plot(xs, ys, color='grey', linewidth=0.3, linestyle='dashed') + plt.plot(xs, ys, color="grey", linewidth=0.3, linestyle="dashed") - plt.axis('scaled') + plt.axis("scaled") plt.xlim([-1, size]) plt.ylim([-1, size]) # convert to image - ax.axis('off') + ax.axis("off") fig.tight_layout(pad=0) ax.margins(0) fig.canvas.draw() @@ -608,6 +671,7 @@ def draw(i): return img + if __name__ == "__main__": np.random.seed(1) @@ -618,7 +682,7 @@ def draw(i): prior_noise_std = 0.2 gbp_settings = GBPSettings( - damping=0., + damping=0.0, beta=0.1, num_undamped_iters=1, min_linear_iters=10, @@ -634,7 +698,7 @@ def draw(i): fg = FactorGraph(gbp_settings) - init_noises = np.random.normal(np.zeros([size*size, 2]), prior_noise_std) + init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) for i in range(size): @@ -653,21 +717,21 @@ def draw(i): for i in range(size): for j in range(size): if j < size - 1: - meas = np.array([1., 0.]) + meas = np.array([1.0, 0.0]) meas += meas_noises[m] fg.add_factor( [i * size + j, i * size + j + 1], meas, - LinearDisplacementModel(SquaredLoss(dim, noise_cov)) + LinearDisplacementModel(SquaredLoss(dim, noise_cov)), ) m += 1 if i < size - 1: - meas = np.array([0., 1.]) + meas = np.array([0.0, 1.0]) meas += meas_noises[m] fg.add_factor( [i * size + j, (i + 1) * size + j], meas, - LinearDisplacementModel(SquaredLoss(dim, noise_cov)) + LinearDisplacementModel(SquaredLoss(dim, noise_cov)), ) m += 1 @@ -682,7 +746,7 @@ def draw(i): # # run gbp --------------- gbp_settings = GBPSettings( - damping=0., + damping=0.0, beta=0.1, num_undamped_iters=1, min_linear_iters=10, @@ -691,7 +755,9 @@ def draw(i): # fg.compute_all_messages() - import ipdb; ipdb.set_trace() + import ipdb + + ipdb.set_trace() # i = 0 n_iters = 5 @@ -706,13 +772,14 @@ def draw(i): fg.synchronous_iteration() i += 1 - for f in fg.factors: - for m in f.messages: - print(np.linalg.inv(m.lam) @ m.eta) + # for f in fg.factors: + # for m in f.messages: + # print(np.linalg.inv(m.lam) @ m.eta) print(fg.belief_means()) - import ipdb; ipdb.set_trace() + import ipdb + ipdb.set_trace() # time.sleep(0.05) diff --git a/theseus/optimizer/gbp/jax_torch_test.py b/theseus/optimizer/gbp/jax_torch_test.py index bf6aecca0..fb4fe3e97 100644 --- a/theseus/optimizer/gbp/jax_torch_test.py +++ 
b/theseus/optimizer/gbp/jax_torch_test.py @@ -1,9 +1,9 @@ -import numpy as np -import torch +import time + import jax import jax.numpy as jnp - -import time +import numpy as np +import torch def pass_fac_to_var_messages( @@ -21,11 +21,6 @@ def pass_fac_to_var_messages( adj_var_dofs = adj_var_dofs_nested[i] num_optim_vars = len(adj_var_dofs) - - inp_msgs_eta = vtof_msgs_eta[start: start + num_optim_vars] - inp_msgs_lam = vtof_msgs_lam[start: start + num_optim_vars] - - num_optim_vars = len(adj_var_dofs) ftov_eta, ftov_lam = [], [] sdim = 0 @@ -40,26 +35,48 @@ def pass_fac_to_var_messages( if var != v: eta_mess = vtof_msgs_eta[var] lam_mess = vtof_msgs_lam[var] - eta_factor[start_in:start_in + var_dofs] += eta_mess - lam_factor[start_in:start_in + var_dofs, start_in:start_in + var_dofs] += lam_mess + eta_factor[start_in : start_in + var_dofs] += eta_mess + lam_factor[ + start_in : start_in + var_dofs, start_in : start_in + var_dofs + ] += lam_mess start_in += var_dofs # Divide up parameters of distribution dofs = adj_var_dofs[v] - eo = eta_factor[sdim:sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] lono = np.concatenate( - (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), - axis=1) + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) lnoo = np.concatenate( - (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), - axis=0) - lnono = np.concatenate(( - np.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), - np.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) - ), axis=0) + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), + axis=1, + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno @@ -69,8 +86,8 @@ def pass_fac_to_var_messages( sdim += dofs - ftov_msgs_eta[start: start + num_optim_vars] = ftov_eta - ftov_msgs_lam[start: start + num_optim_vars] = ftov_lam + ftov_msgs_eta[start : start + num_optim_vars] = ftov_eta + ftov_msgs_lam[start : start + num_optim_vars] = ftov_lam start += num_optim_vars @@ -93,11 +110,6 @@ def pass_fac_to_var_messages_jax( adj_var_dofs = adj_var_dofs_nested[i] num_optim_vars = len(adj_var_dofs) - - inp_msgs_eta = vtof_msgs_eta[start: start + num_optim_vars] - inp_msgs_lam = vtof_msgs_lam[start: start + num_optim_vars] - - num_optim_vars = len(adj_var_dofs) ftov_eta, ftov_lam = [], [] sdim = 0 @@ -112,26 +124,50 @@ def pass_fac_to_var_messages_jax( if var != v: eta_mess = vtof_msgs_eta[var] lam_mess = vtof_msgs_lam[var] - eta_factor = eta_factor.at[start_in:start_in + var_dofs].add(eta_mess) - lam_factor = lam_factor.at[start_in:start_in + var_dofs, start_in:start_in + var_dofs].add(lam_mess) + eta_factor = eta_factor.at[start_in : start_in + var_dofs].add( + eta_mess + ) + lam_factor = lam_factor.at[ + start_in : start_in + var_dofs, start_in : 
start_in + var_dofs + ].add(lam_mess) start_in += var_dofs # Divide up parameters of distribution dofs = adj_var_dofs[v] - eo = eta_factor[sdim:sdim + dofs] - eno = jnp.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs:])) + eo = eta_factor[sdim : sdim + dofs] + eno = jnp.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim:sdim + dofs, sdim:sdim + dofs] + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] lono = jnp.concatenate( - (lam_factor[sdim:sdim + dofs, :sdim], lam_factor[sdim:sdim + dofs, sdim + dofs:]), - axis=1) + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) lnoo = jnp.concatenate( - (lam_factor[:sdim, sdim:sdim + dofs], lam_factor[sdim + dofs:, sdim:sdim + dofs]), - axis=0) - lnono = jnp.concatenate(( - jnp.concatenate((lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs:]), axis=1), - jnp.concatenate((lam_factor[sdim + dofs:, :sdim], lam_factor[sdim + dofs:, sdim + dofs:]), axis=1) - ), axis=0) + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = jnp.concatenate( + ( + jnp.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), + axis=1, + ), + jnp.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) new_message_lam = loo - lono @ jnp.linalg.inv(lnono) @ lnoo new_message_eta = eo - lono @ jnp.linalg.inv(lnono) @ eno @@ -141,105 +177,266 @@ def pass_fac_to_var_messages_jax( sdim += dofs - ftov_msgs_eta[start: start + num_optim_vars] = ftov_eta - ftov_msgs_lam[start: start + num_optim_vars] = ftov_lam + ftov_msgs_eta[start : start + num_optim_vars] = ftov_eta + ftov_msgs_lam[start : start + num_optim_vars] = ftov_lam start += num_optim_vars return ftov_msgs_eta, ftov_msgs_lam - - - if __name__ == "__main__": - adj_var_dofs_nested = [[2], [2], [2], [2], [2], [2], [2], [2], [2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]] - - potentials_eta = [torch.tensor([[0., 0.]]), torch.tensor([[ 0.5292, -0.1270]]), torch.tensor([[ 1.2858, -0.2724]]), torch.tensor([[0.2065, 0.5016]]), torch.tensor([[0.6295, 0.5622]]), torch.tensor([[1.3565, 0.3479]]), torch.tensor([[-0.0382, 1.1380]]), torch.tensor([[0.7259, 1.0533]]), torch.tensor([[1.1630, 1.0795]]), torch.tensor([[-100.4221, -5.8282, 100.4221, 5.8282]]), torch.tensor([[ 11.0062, -111.4472, -11.0062, 111.4472]]), torch.tensor([[-109.0159, -5.0249, 109.0159, 5.0249]]), torch.tensor([[ -9.0086, -93.1627, 9.0086, 93.1627]]), torch.tensor([[ 1.2289, -90.6423, -1.2289, 90.6423]]), torch.tensor([[-97.3211, -5.3036, 97.3211, 5.3036]]), torch.tensor([[ 6.9166, -96.0325, -6.9166, 96.0325]]), torch.tensor([[-93.1283, 8.4521, 93.1283, -8.4521]]), torch.tensor([[ 6.7125, -99.8733, -6.7125, 99.8733]]), torch.tensor([[ 11.1731, -102.3442, -11.1731, 102.3442]]), torch.tensor([[-116.5980, -7.4204, 116.5980, 7.4204]]), torch.tensor([[-98.0816, 8.8763, 98.0816, -8.8763]])] - potentials_lam = [torch.tensor([[[10000., 0.], - [ 0., 10000.]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), 
torch.tensor([[[0.5917, 0.0000], - [0.0000, 0.5917]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]]), torch.tensor([[[ 100., 0., -100., 0.], - [ 0., 100., 0., -100.], - [-100., 0., 100., 0.], - [ 0., -100., 0., 100.]]])] - - vtof_msgs_eta = [torch.tensor([[ 0.8536, -1.5929]]), torch.tensor([[182.3461, 16.7745]]), torch.tensor([[222.8854, 13.1250]]), torch.tensor([[-10.1678, 202.9393]]), torch.tensor([[200.4927, 213.6843]]), torch.tensor([[264.5976, 132.6887]]), torch.tensor([[-17.9007, 222.3988]]), torch.tensor([[127.5813, 277.0478]]), torch.tensor([[191.0187, 201.1600]]), torch.tensor([[ 5.6620, -4.6983]]), torch.tensor([[83.3856, 10.9277]]), torch.tensor([[-4.8085, 3.1053]]), torch.tensor([[ 0.9854, 93.0631]]), torch.tensor([[153.1307, 16.3761]]), torch.tensor([[98.1263, 3.3349]]), torch.tensor([[129.7635, 5.8644]]), torch.tensor([[140.2319, 158.5661]]), torch.tensor([[127.3308, 9.2454]]), torch.tensor([[187.8824, 92.8337]]), torch.tensor([[-16.5414, 145.2973]]), torch.tensor([[152.8149, 148.6686]]), torch.tensor([[ -4.1601, 169.0230]]), torch.tensor([[-12.0344, 99.1287]]), torch.tensor([[153.7062, 168.3496]]), torch.tensor([[149.0974, 72.7772]]), torch.tensor([[157.2429, 167.7175]]), torch.tensor([[ 70.8858, 152.1307]]), torch.tensor([[196.2848, 100.8102]]), torch.tensor([[ 99.5512, 100.5530]]), torch.tensor([[ -5.9426, 125.5461]]), torch.tensor([[ 87.5787, 197.8408]]), torch.tensor([[ 98.8758, 207.2840]]), torch.tensor([[ 93.7936, 102.7661]])] - vtof_msgs_lam = [torch.tensor([[95.7949, 0.0000], - [ 0.0000, 95.7949]]), torch.tensor([[190.3769, 0.0000], - [ 0.0000, 190.3769]]), torch.tensor([[109.9605, 0.0000], - [ 0.0000, 109.9605]]), torch.tensor([[190.3769, 0.0000], - [ 0.0000, 190.3769]]), torch.tensor([[197.8604, 0.0000], - [ 0.0000, 197.8604]]), torch.tensor([[132.5915, 0.0000], - [ 0.0000, 132.5915]]), torch.tensor([[109.9605, 0.0000], - [ 0.0000, 109.9605]]), torch.tensor([[132.5915, 0.0000], - [ 0.0000, 132.5915]]), torch.tensor([[99.8496, 0.0000], - [ 0.0000, 99.8496]]), torch.tensor([[10047.8975, 0.0000], - [ 0.0000, 10047.8975]]), torch.tensor([[91.9540, 0.0000], - [ 0.0000, 91.9540]]), torch.tensor([[10047.8975, 0.0000], - [ 0.0000, 10047.8975]]), torch.tensor([[91.9540, 0.0000], - [ 0.0000, 91.9540]]), torch.tensor([[158.0642, 0.0000], 
- [ 0.0000, 158.0642]]), torch.tensor([[49.3043, 0.0000], - [ 0.0000, 49.3043]]), torch.tensor([[132.5106, 0.0000], - [ 0.0000, 132.5106]]), torch.tensor([[141.4631, 0.0000], - [ 0.0000, 141.4631]]), torch.tensor([[61.8396, 0.0000], - [ 0.0000, 61.8396]]), torch.tensor([[94.9975, 0.0000], - [ 0.0000, 94.9975]]), torch.tensor([[132.5106, 0.0000], - [ 0.0000, 132.5106]]), torch.tensor([[141.4631, 0.0000], - [ 0.0000, 141.4631]]), torch.tensor([[158.0642, 0.0000], - [ 0.0000, 158.0642]]), torch.tensor([[49.3043, 0.0000], - [ 0.0000, 49.3043]]), torch.tensor([[156.5110, 0.0000], - [ 0.0000, 156.5110]]), torch.tensor([[72.2502, 0.0000], - [ 0.0000, 72.2502]]), torch.tensor([[156.5110, 0.0000], - [ 0.0000, 156.5110]]), torch.tensor([[72.2502, 0.0000], - [ 0.0000, 72.2502]]), torch.tensor([[99.7104, 0.0000], - [ 0.0000, 99.7104]]), torch.tensor([[50.5165, 0.0000], - [ 0.0000, 50.5165]]), torch.tensor([[61.8396, 0.0000], - [ 0.0000, 61.8396]]), torch.tensor([[94.9975, 0.0000], - [ 0.0000, 94.9975]]), torch.tensor([[99.7104, 0.0000], - [ 0.0000, 99.7104]]), torch.tensor([[50.5165, 0.0000], - [ 0.0000, 50.5165]])] - vtof_msgs_eta = torch.cat(vtof_msgs_eta) + adj_var_dofs_nested = [ + [2], + [2], + [2], + [2], + [2], + [2], + [2], + [2], + [2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + [2, 2], + ] + + potentials_eta = [ + torch.tensor([[0.0, 0.0]]), + torch.tensor([[0.5292, -0.1270]]), + torch.tensor([[1.2858, -0.2724]]), + torch.tensor([[0.2065, 0.5016]]), + torch.tensor([[0.6295, 0.5622]]), + torch.tensor([[1.3565, 0.3479]]), + torch.tensor([[-0.0382, 1.1380]]), + torch.tensor([[0.7259, 1.0533]]), + torch.tensor([[1.1630, 1.0795]]), + torch.tensor([[-100.4221, -5.8282, 100.4221, 5.8282]]), + torch.tensor([[11.0062, -111.4472, -11.0062, 111.4472]]), + torch.tensor([[-109.0159, -5.0249, 109.0159, 5.0249]]), + torch.tensor([[-9.0086, -93.1627, 9.0086, 93.1627]]), + torch.tensor([[1.2289, -90.6423, -1.2289, 90.6423]]), + torch.tensor([[-97.3211, -5.3036, 97.3211, 5.3036]]), + torch.tensor([[6.9166, -96.0325, -6.9166, 96.0325]]), + torch.tensor([[-93.1283, 8.4521, 93.1283, -8.4521]]), + torch.tensor([[6.7125, -99.8733, -6.7125, 99.8733]]), + torch.tensor([[11.1731, -102.3442, -11.1731, 102.3442]]), + torch.tensor([[-116.5980, -7.4204, 116.5980, 7.4204]]), + torch.tensor([[-98.0816, 8.8763, 98.0816, -8.8763]]), + ] + potentials_lam = [ + torch.tensor([[[10000.0, 0.0], [0.0, 10000.0]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + 
torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + torch.tensor( + [ + [ + [100.0, 0.0, -100.0, 0.0], + [0.0, 100.0, 0.0, -100.0], + [-100.0, 0.0, 100.0, 0.0], + [0.0, -100.0, 0.0, 100.0], + ] + ] + ), + ] + + vtof_msgs_eta_list = [ + torch.tensor([[0.8536, -1.5929]]), + torch.tensor([[182.3461, 16.7745]]), + torch.tensor([[222.8854, 13.1250]]), + torch.tensor([[-10.1678, 202.9393]]), + torch.tensor([[200.4927, 213.6843]]), + torch.tensor([[264.5976, 132.6887]]), + torch.tensor([[-17.9007, 222.3988]]), + torch.tensor([[127.5813, 277.0478]]), + torch.tensor([[191.0187, 201.1600]]), + torch.tensor([[5.6620, -4.6983]]), + torch.tensor([[83.3856, 10.9277]]), + torch.tensor([[-4.8085, 3.1053]]), + torch.tensor([[0.9854, 93.0631]]), + torch.tensor([[153.1307, 16.3761]]), + torch.tensor([[98.1263, 3.3349]]), + torch.tensor([[129.7635, 5.8644]]), + torch.tensor([[140.2319, 158.5661]]), + torch.tensor([[127.3308, 9.2454]]), + torch.tensor([[187.8824, 92.8337]]), + torch.tensor([[-16.5414, 145.2973]]), + torch.tensor([[152.8149, 148.6686]]), + torch.tensor([[-4.1601, 169.0230]]), + torch.tensor([[-12.0344, 99.1287]]), + torch.tensor([[153.7062, 168.3496]]), + torch.tensor([[149.0974, 72.7772]]), + torch.tensor([[157.2429, 167.7175]]), + torch.tensor([[70.8858, 152.1307]]), + torch.tensor([[196.2848, 100.8102]]), + torch.tensor([[99.5512, 100.5530]]), + torch.tensor([[-5.9426, 125.5461]]), + torch.tensor([[87.5787, 197.8408]]), + torch.tensor([[98.8758, 207.2840]]), + torch.tensor([[93.7936, 102.7661]]), + ] + vtof_msgs_lam = [ + torch.tensor([[95.7949, 0.0000], [0.0000, 95.7949]]), + torch.tensor([[190.3769, 0.0000], [0.0000, 190.3769]]), + torch.tensor([[109.9605, 0.0000], [0.0000, 109.9605]]), + torch.tensor([[190.3769, 0.0000], [0.0000, 190.3769]]), + torch.tensor([[197.8604, 0.0000], [0.0000, 197.8604]]), + torch.tensor([[132.5915, 0.0000], [0.0000, 132.5915]]), + torch.tensor([[109.9605, 0.0000], [0.0000, 109.9605]]), + torch.tensor([[132.5915, 0.0000], [0.0000, 132.5915]]), + torch.tensor([[99.8496, 0.0000], [0.0000, 99.8496]]), + torch.tensor([[10047.8975, 0.0000], [0.0000, 10047.8975]]), + torch.tensor([[91.9540, 0.0000], [0.0000, 91.9540]]), + torch.tensor([[10047.8975, 0.0000], [0.0000, 10047.8975]]), + torch.tensor([[91.9540, 0.0000], [0.0000, 91.9540]]), + torch.tensor([[158.0642, 0.0000], [0.0000, 158.0642]]), + torch.tensor([[49.3043, 0.0000], [0.0000, 49.3043]]), + torch.tensor([[132.5106, 0.0000], [0.0000, 132.5106]]), + torch.tensor([[141.4631, 0.0000], [0.0000, 141.4631]]), + 
torch.tensor([[61.8396, 0.0000], [0.0000, 61.8396]]), + torch.tensor([[94.9975, 0.0000], [0.0000, 94.9975]]), + torch.tensor([[132.5106, 0.0000], [0.0000, 132.5106]]), + torch.tensor([[141.4631, 0.0000], [0.0000, 141.4631]]), + torch.tensor([[158.0642, 0.0000], [0.0000, 158.0642]]), + torch.tensor([[49.3043, 0.0000], [0.0000, 49.3043]]), + torch.tensor([[156.5110, 0.0000], [0.0000, 156.5110]]), + torch.tensor([[72.2502, 0.0000], [0.0000, 72.2502]]), + torch.tensor([[156.5110, 0.0000], [0.0000, 156.5110]]), + torch.tensor([[72.2502, 0.0000], [0.0000, 72.2502]]), + torch.tensor([[99.7104, 0.0000], [0.0000, 99.7104]]), + torch.tensor([[50.5165, 0.0000], [0.0000, 50.5165]]), + torch.tensor([[61.8396, 0.0000], [0.0000, 61.8396]]), + torch.tensor([[94.9975, 0.0000], [0.0000, 94.9975]]), + torch.tensor([[99.7104, 0.0000], [0.0000, 99.7104]]), + torch.tensor([[50.5165, 0.0000], [0.0000, 50.5165]]), + ] + vtof_msgs_eta = torch.cat(vtof_msgs_eta_list) # vtof_msgs_lam = torch.cat([m[None, ...] for m in vtof_msgs_lam]) t1 = time.time() @@ -263,7 +460,6 @@ def pass_fac_to_var_messages_jax( # print(ftov_msgs_eta) # print(ftov_msgs_lam) - potentials_eta_jax = [jnp.array(pe) for pe in potentials_eta] potentials_lam_jax = [jnp.array(pe) for pe in potentials_lam] vtof_msgs_eta_jax = jnp.array(vtof_msgs_eta) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py new file mode 100644 index 000000000..a5dd550e4 --- /dev/null +++ b/theseus/optimizer/gbp/pgo_test.py @@ -0,0 +1,161 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import theseus as th +from theseus.optimizer.gbp import GaussianBeliefPropagation + +# This example illustrates the Gaussian Belief Propagation (GBP) optimizer +# for a 2D pose graph optimization problem. +# Linear problem where we are estimating the (x, y)position of 9 nodes, +# arranged in a 3x3 grid. +# Linear factors connect each node to its adjacent nodes. 
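#
# For reference, node (i, j) of the grid maps to linear index i * size + j, and the
# measurement loops further down create one "between" factor to the right neighbour
# (i, j + 1) and one to the bottom neighbour (i + 1, j). A minimal sketch of that
# edge enumeration (the `edges` name is illustrative, not part of this file):
#
#     edges = []
#     for i in range(size):
#         for j in range(size):
#             if j < size - 1:
#                 edges.append((i * size + j, i * size + j + 1))    # horizontal
#             if i < size - 1:
#                 edges.append((i * size + j, (i + 1) * size + j))  # vertical
#
# With size = 3 this gives 12 between factors which, together with the 9 prior
# factors, makes 21 cost functions in the objective.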
+ +np.random.seed(1) +torch.manual_seed(0) + +size = 3 +dim = 2 + +noise_cov = np.array([0.01, 0.01]) + +prior_noise_std = 0.2 +prior_sigma = np.array([1.3**2, 1.3**2]) + +init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) +meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) + +# create theseus objective ------------------------------------- + +objective = th.Objective() +inputs = {} + +n_poses = size * size + +# create variables +poses = [] +for i in range(n_poses): + poses.append(th.Vector(data=torch.rand(1, 2), name=f"x{i}")) + +# add prior cost constraints with VariableDifference cost +prior_std = 1.3 +anchor_std = 0.01 +prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") +anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") + +p = 0 +for i in range(size): + for j in range(size): + init = torch.Tensor([j, i]) + if i == 0 and j == 0: + w = anchor_w + else: + # noise_init = torch.normal(torch.zeros(2), prior_noise_std) + init = init + torch.FloatTensor(init_noises[p]) + w = prior_w + + prior_target = th.Vector(data=init, name=f"prior_{p}") + inputs[f"x{p}"] = init[None, :] + inputs[f"prior_{p}"] = init[None, :] + + cf_prior = th.eb.VariableDifference( + poses[p], w, prior_target, name=f"prior_cost_{p}" + ) + + objective.add(cf_prior) + + p += 1 + +# Measurement cost functions + +meas_std = 0.1 +meas_w = th.ScaleCostWeight(1 / meas_std, name="prior_weight") + +m = 0 +for i in range(size): + for j in range(size): + if j < size - 1: + measurement = torch.Tensor([1.0, 0.0]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = i * size + j + 1 + + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] + + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + objective.add(cf_meas) + m += 1 + + if i < size - 1: + measurement = torch.Tensor([0.0, 1.0]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = (i + 1) * size + j + + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] + + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + objective.add(cf_meas) + m += 1 + +# # objective.update(init_dict) +# print("Initial cost:", objective.error_squared_norm()) + +# fg.print(brief=True) + +# # for vis --------------- + +# joint = fg.get_joint() +# marg_covs = np.diag(joint.cov())[::2] +# map_soln = fg.MAP().reshape([size * size, 2]) + +# Solve with Gauss Newton --------------- + +# print("inputs", inputs) + +optimizer = GaussianBeliefPropagation( + objective, + max_iterations=100, +) +theseus_optim = th.TheseusLayer(optimizer) + +optim_arg = { + "track_best_solution": True, + "track_err_history": True, + "verbose": True, + "backward_mode": th.BackwardMode.FULL, + "damping": 0.6, + "dropout": 0.0, +} +updated_inputs, info = theseus_optim.forward(inputs, optim_arg) + +print("updated_inputs", updated_inputs) +print("info", info) + + +# optimizer = th.GaussNewton( +# objective, +# max_iterations=15, +# step_size=0.5, +# ) +# theseus_optim = th.TheseusLayer(optimizer) + +# with torch.no_grad(): +# optim_args = {"track_best_solution": True, "verbose": True} +# updated_inputs, info = theseus_optim.forward(inputs, optim_args) +# print("updated_inputs", updated_inputs) +# print("info", info) + +# import 
ipdb; ipdb.set_trace() From 4479f3b5c5262dcd2f06ddc0087a3f1d3abb1c9a Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Thu, 7 Apr 2022 17:04:05 +0100 Subject: [PATCH 06/64] updated gaussian class --- theseus/optimizer/gbp/gbp.py | 145 +++++++++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 30 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index ef32d5b15..a5c4c1839 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -6,6 +6,7 @@ import abc import math from dataclasses import dataclass +from itertools import count from typing import Any, Dict, List, Optional, Sequence, Tuple, Type import numpy as np @@ -24,9 +25,6 @@ """ TODO - - Parallelise factor to variable message comp - - Benchmark speed - - test jax implementation of message comp functions - add class for message schedule - damping for lie algebra vars - solving inverse problem to compute message mean @@ -52,27 +50,114 @@ def update(self, params_dict): raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") -# Stores variable beliefs that converge towards the marginals class Gaussian: - def __init__(self, variable: Manifold): - self.name = variable.name + "_gaussian" - self.mean = variable - self.tot_dof = variable.dof() - - # tot_dof = 0 - # for v in variables: - # tot_dof += v.dof() - # self.tot_dof = tot_dof - - self.precision = torch.zeros( - self.mean.shape[0], self.tot_dof, self.tot_dof, dtype=variable.dtype + _ids = count(0) + + def __init__( + self, + mean: Sequence[Manifold], + precision: Optional[torch.Tensor] = None, + name: Optional[str] = None, + ): + self._id = next(Gaussian._ids) + if name: + self.name = name + else: + self.name = f"{self.__class__.__name__}__{self._id}" + + dof = 0 + for v in mean: + dof += v.dof() + self._dof = dof + + self.mean = mean + self.precision = torch.zeros(mean[0].shape[0], self.dof, self.dof).to( + dtype=mean[0].dtype, device=mean[0].device ) + @property def dof(self) -> int: - return self.tot_dof + return self._dof + + @property + def device(self) -> torch.device: + return self.precision[0].device + + @property + def dtype(self) -> torch.dtype: + return self.precision[0].dtype + + # calls to() on the internal tensors + def to(self, *args, **kwargs): + for var in self.mean: + var = var.to(*args, **kwargs) + self.precision = self.precision.to(*args, **kwargs) + + def copy(self, new_name: Optional[str] = None) -> "Gaussian": + if not new_name: + new_name = f"{self.name}_copy" + mean_copy = [var.copy() for var in self.mean] + return Gaussian(mean_copy, name=new_name) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + the_copy = self.copy() + memo[id(self)] = the_copy + return the_copy + + def update( + self, + mean: Optional[Sequence[Manifold]] = None, + precision: Optional[torch.Tensor] = None, + ): + if mean is not None: + if len(mean) != len(self.mean): + raise ValueError( + f"Tried to update mean with sequence of different" + f"lenght to original mean sequence. Given {len(mean)}. " + f"Expected: {len(self.mean)}" + ) + for i in range(len(self.mean)): + self.mean[i].update(mean[i]) + + if precision is not None: + if precision.shape != self.precision.shape: + raise ValueError( + f"Tried to update precision with data " + f"incompatible with original tensor shape. Given {precision.shape}. 
" + f"Expected: {self.precision.shape}" + ) + if precision.dtype != self.dtype: + raise ValueError( + f"Tried to update using tensor of dtype {precision.dtype} but precision " + f"has dtype {self.dtype}." + ) + + self.precision = precision -class Marginals(Gaussian): +# # Stores variable beliefs that converge towards the marginals +# class Gaussian: +# def __init__(self, variable: Manifold): +# self.name = variable.name + "_gaussian" +# self.mean = variable +# self.tot_dof = variable.dof() + +# # tot_dof = 0 +# # for v in variables: +# # tot_dof += v.dof() +# # self.tot_dof = tot_dof + +# self.precision = torch.zeros( +# self.mean.shape[0], self.tot_dof, self.tot_dof, dtype=variable.dtype +# ) + +# def dof(self) -> int: +# return self.tot_dof + + +class Marginal(Gaussian): pass @@ -212,7 +297,7 @@ def local_gaussian( return_mean: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor]: # mean_tp is message mean in tangent space / plane at var - mean_tp = var.local(mess.mean) + mean_tp = var.local(mess.mean[0]) jac: List[torch.Tensor] = [] th.exp_map(var, mean_tp, jacobians=jac) @@ -246,8 +331,7 @@ def retract_gaussian( inv_jac = torch.inverse(jac[0]) precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), prec_tp), inv_jac) - out_gauss.mean.update(mean.data) - out_gauss.precision = precision + out_gauss.update(mean=[mean], precision=precision) def pass_var_to_fac_messages_and_update_beliefs_lie( @@ -281,7 +365,7 @@ def pass_var_to_fac_messages_and_update_beliefs_lie( lam_a = lams_inc.sum(dim=0) if lam_a.count_nonzero() == 0: - vtof_msgs[j].mean.data[:] = 0.0 + vtof_msgs[j].mean[0].data[:] = 0.0 vtof_msgs[j].precision = lam_a else: inv_lam_a = torch.inverse(lam_a) @@ -297,7 +381,7 @@ def pass_var_to_fac_messages_and_update_beliefs_lie( sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) - belief = Gaussian(var) + belief = Gaussian([var]) retract_gaussian(tau, lam_tau, var, belief) @@ -481,14 +565,14 @@ def ftov_comp_mess_lie( # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] if new_mess_lam.count_nonzero() == 0: - new_mess = Gaussian(lin_points[v].copy()) - new_mess.mean.data[:] = 0.0 + new_mess = Gaussian([lin_points[v].copy()]) + new_mess.mean[0].data[:] = 0.0 else: new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) new_mess_mean = new_mess_mean[None, ...] new_mess_lam = new_mess_lam[None, ...] 
- new_mess = Gaussian(lin_points[v].copy()) + new_mess = Gaussian([lin_points[v].copy()]) retract_gaussian(new_mess_mean, new_mess_lam, lin_points[v], new_mess) new_messages.append(new_mess) @@ -496,8 +580,9 @@ def ftov_comp_mess_lie( # update messages for v in range(num_optim_vars): - ftov_msgs[v].mean.update(new_messages[v].mean.data) - ftov_msgs[v].precision = new_messages[v].precision + ftov_msgs[v].update( + mean=new_messages[v].mean, precision=new_messages[v].precision + ) return new_messages @@ -846,8 +931,8 @@ def _optimize_loop_lie( # mean of initial message doesn't matter as long as precision is zero vtof_msg_mu.data[:] = 0 ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") - vtof_msgs.append(Message(vtof_msg_mu)) - ftov_msgs.append(Message(ftov_msg_mu)) + vtof_msgs.append(Message([vtof_msg_mu])) + ftov_msgs.append(Message([ftov_msg_mu])) # compute factor potentials for the first time potentials_eta = [None] * self.objective.size_cost_functions() From fc12a98ff0990b5083e96211a6d7b29c1d545ab7 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 8 Apr 2022 10:03:10 +0100 Subject: [PATCH 07/64] message scheduler --- theseus/optimizer/gbp/__init__.py | 2 +- theseus/optimizer/gbp/gbp.py | 876 +++++++------------- theseus/optimizer/gbp/gbp_euclidean.py | 1047 ++++++++++++++++++++++++ theseus/optimizer/gbp/pgo_test.py | 7 +- 4 files changed, 1369 insertions(+), 563 deletions(-) create mode 100644 theseus/optimizer/gbp/gbp_euclidean.py diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py index b27e0742d..5a20612f2 100644 --- a/theseus/optimizer/gbp/__init__.py +++ b/theseus/optimizer/gbp/__init__.py @@ -3,4 +3,4 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
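# With this commit the schedule helpers are exported alongside the optimizer, so a
# caller can build the message schedule explicitly. A minimal usage sketch, assuming
# the same forward-kwargs plumbing used in pgo_test.py (the kwargs dict contents
# shown here are illustrative):
#
#     from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule
#
#     optimizer = GaussianBeliefPropagation(objective, max_iterations=100)
#     schedule = synchronous_schedule(100, optimizer.n_edges)
#     layer = th.TheseusLayer(optimizer)
#     updated_inputs, info = layer.forward(
#         inputs, {"damping": 0.6, "dropout": 0.0, "schedule": schedule}
#     )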
-from .gbp import GaussianBeliefPropagation +from .gbp import GaussianBeliefPropagation, random_schedule, synchronous_schedule diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index a5c4c1839..5c79349be 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from itertools import count -from typing import Any, Dict, List, Optional, Sequence, Tuple, Type +from typing import Dict, List, Optional, Sequence, Tuple import numpy as np import torch @@ -16,7 +16,7 @@ import theseus.constants from theseus.core import CostFunction, Objective from theseus.geometry import Manifold -from theseus.optimizer import Linearization, Optimizer, VariableOrdering +from theseus.optimizer import Optimizer, VariableOrdering from theseus.optimizer.nonlinear.nonlinear_optimizer import ( BackwardMode, NonlinearOptimizerInfo, @@ -36,6 +36,7 @@ """ +# Same of NonlinearOptimizerParams but without step size @dataclass class GBPOptimizerParams: abs_err_tolerance: float @@ -50,7 +51,18 @@ def update(self, params_dict): raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") -class Gaussian: +def synchronous_schedule(max_iters, n_edges) -> torch.Tensor: + return torch.full([max_iters, n_edges], True) + + +def random_schedule(max_iters, n_edges) -> torch.Tensor: + schedule = torch.full([max_iters, n_edges], False) + ixs = torch.randint(0, n_edges, [max_iters]) + schedule[torch.arange(max_iters), ixs] = True + return schedule + + +class ManifoldGaussian: _ids = count(0) def __init__( @@ -59,7 +71,7 @@ def __init__( precision: Optional[torch.Tensor] = None, name: Optional[str] = None, ): - self._id = next(Gaussian._ids) + self._id = next(ManifoldGaussian._ids) if name: self.name = name else: @@ -93,11 +105,11 @@ def to(self, *args, **kwargs): var = var.to(*args, **kwargs) self.precision = self.precision.to(*args, **kwargs) - def copy(self, new_name: Optional[str] = None) -> "Gaussian": + def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian": if not new_name: new_name = f"{self.name}_copy" mean_copy = [var.copy() for var in self.mean] - return Gaussian(mean_copy, name=new_name) + return ManifoldGaussian(mean_copy, name=new_name) def __deepcopy__(self, memo): if id(self) in memo: @@ -137,32 +149,64 @@ def update( self.precision = precision -# # Stores variable beliefs that converge towards the marginals -# class Gaussian: -# def __init__(self, variable: Manifold): -# self.name = variable.name + "_gaussian" -# self.mean = variable -# self.tot_dof = variable.dof() +class Marginal(ManifoldGaussian): + pass -# # tot_dof = 0 -# # for v in variables: -# # tot_dof += v.dof() -# # self.tot_dof = tot_dof -# self.precision = torch.zeros( -# self.mean.shape[0], self.tot_dof, self.tot_dof, dtype=variable.dtype -# ) +class Message(ManifoldGaussian): + pass -# def dof(self) -> int: -# return self.tot_dof +""" +Local and retract +These could be implemented as methods in Manifold class +""" -class Marginal(Gaussian): - pass +# Transforms message gaussian to tangent plane at var +# if return_mean is True, return the (mean, lam) else return (eta, lam). +# Generalises the local function by transforming the covariance as well as mean. 
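# A first-order sketch of the transformation implemented below: with
#     delta = var.local(mess.mean[0])            (message mean in the tangent plane at var)
#     J     = Jacobian of the exponential map at delta, as returned by th.exp_map,
# the message precision Lam = mess.precision is mapped to the tangent plane at var as
#     lam_tp = J^T @ Lam @ J
# and, when the information vector is requested instead of the mean,
#     eta_tp = lam_tp @ delta.
# For th.Vector variables the exponential-map Jacobian should reduce to the identity,
# so this falls back to the ordinary local() with an unchanged precision matrix.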
+def local_gaussian( + mess: Message, + var: th.LieGroup, + return_mean: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + # mean_tp is message mean in tangent space / plane at var + mean_tp = var.local(mess.mean[0]) + + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) -class Message(Gaussian): - pass + # lam_tp is the precision matrix in the tangent plane + lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) + + if return_mean: + return mean_tp, lam_tp + + else: + eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) + return eta_tp, lam_tp + + +# Transforms Gaussian in the tangent plane at var to Gaussian where the mean +# is a group element and the precision matrix is defined in the tangent plane +# at the mean. +# Generalises the retract function by transforming the covariance as well as mean. +# out_gauss is the transformed Gaussian that is updated in place. +def retract_gaussian( + mean_tp: torch.Tensor, + precision_tp: torch.Tensor, + var: th.LieGroup, + out_gauss: ManifoldGaussian, +): + mean = var.retract(mean_tp) + + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) + inv_jac = torch.inverse(jac[0]) + precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac) + + out_gauss.update(mean=[mean], precision=precision) class CostFunctionOrdering: @@ -222,381 +266,181 @@ def complete(self): """ -# Compute the factor at current adjacent beliefs. -def compute_factor(cf, lie=True): - J, error = cf.weighted_jacobians_error() - J_stk = torch.cat(J, dim=-1) - - lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) - - optim_vars_stk = torch.cat([v.data for v in cf.optim_vars], dim=-1) - eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) - if lie is False: - eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) - eta = eta.squeeze(-1) - - return eta, lam - - -def pass_var_to_fac_messages( - ftov_msgs_eta, - ftov_msgs_lam, - var_ix_for_edges, - n_vars, - max_dofs, -): - belief_eta = torch.zeros(n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) - belief_lam = torch.zeros(n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) - - belief_eta = belief_eta.index_add(0, var_ix_for_edges, ftov_msgs_eta) - belief_lam = belief_lam.index_add(0, var_ix_for_edges, ftov_msgs_lam) - - vtof_msgs_eta = belief_eta[var_ix_for_edges] - ftov_msgs_eta - vtof_msgs_lam = belief_lam[var_ix_for_edges] - ftov_msgs_lam +class Factor: + _ids = count(0) - return vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam + def __init__( + self, + cf: CostFunction, + name: Optional[str] = None, + ): + self._id = next(Factor._ids) + if name: + self.name = name + else: + self.name = f"{self.__class__.__name__}__{self._id}" + self.cf = cf -def pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - adj_var_dofs_nested: List[List], -): - ftov_msgs_eta = torch.zeros_like(vtof_msgs_eta) - ftov_msgs_lam = torch.zeros_like(vtof_msgs_lam) - - start = 0 - for i in range(len(adj_var_dofs_nested)): - adj_var_dofs = adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - ftov_eta, ftov_lam = ftov_comp_mess( - adj_var_dofs, - potentials_eta[i], - potentials_lam[i], - vtof_msgs_eta[start : start + num_optim_vars], - vtof_msgs_lam[start : start + num_optim_vars], + batch_size = cf.optim_var_at(0).shape[0] + self._dof = sum([var.dof() for var in cf.optim_vars]) + self.potential_eta = torch.zeros(batch_size, self.dof).to( + dtype=cf.optim_var_at(0).dtype, 
device=cf.optim_var_at(0).device ) + self.potential_lam = torch.zeros(batch_size, self.dof, self.dof).to( + dtype=cf.optim_var_at(0).dtype, device=cf.optim_var_at(0).device + ) + self.lin_point = [ + var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars + ] - ftov_msgs_eta[start : start + num_optim_vars] = torch.cat(ftov_eta) - ftov_msgs_lam[start : start + num_optim_vars] = torch.cat(ftov_lam) - - start += num_optim_vars - - return ftov_msgs_eta, ftov_msgs_lam - - -# Transforms message gaussian to tangent plane at var -# if return_mean is True, return the (mean, lam) else return (eta, lam). -# Generalises the local function by transforming the covariance as well as mean. -def local_gaussian( - mess: Message, - var: th.LieGroup, - return_mean: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - # mean_tp is message mean in tangent space / plane at var - mean_tp = var.local(mess.mean[0]) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - - # lam_tp is the precision matrix in the tangent plane - lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) - - if return_mean: - return mean_tp, lam_tp - - else: - eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) - return eta_tp, lam_tp - - -# Transforms Gaussian in the tangent plane at var to Gaussian where the mean -# is a group element and the precision matrix is defined in the tangent plane -# at the mean. -# Generalises the retract function by transforming the covariance as well as mean. -# out_gauss is the transformed Gaussian that is updated in place. -def retract_gaussian( - mean_tp: torch.Tensor, - prec_tp: torch.Tensor, - var: th.LieGroup, - out_gauss: Gaussian, -): - mean = var.retract(mean_tp) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - inv_jac = torch.inverse(jac[0]) - precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), prec_tp), inv_jac) - - out_gauss.update(mean=[mean], precision=precision) + self.linearize() + # Linearizes factors at current belief if beliefs have deviated + # from the linearization point by more than the threshold. 
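    # A sketch of the quantities computed here: with the weighted Jacobian J (all
    # optimization variables stacked) and weighted error r returned by
    # cf.weighted_jacobians_error() at the linearization point x0, the factor
    # potential in information form is the Gauss-Newton approximation
    #     potential_lam = J^T @ J
    #     potential_eta = -J^T @ r        (plus potential_lam @ x0 when lie=False)
    # When a relin_threshold is given, relinearization is skipped if every current
    # variable lies within that distance of its stored linearization point
    # (distance measured as the norm of local() coordinates).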
+ def linearize( + self, + relin_threshold: float = None, + lie=True, + ): + do_lin = False + if relin_threshold is None: + do_lin = True + else: + lp_dists = [ + lp.local(self.cf.optim_var_at(j)).norm() + for j, lp in enumerate(self.lin_point) + ] + do_lin = np.max(lp_dists) > relin_threshold -def pass_var_to_fac_messages_and_update_beliefs_lie( - ftov_msgs, - vtof_msgs, - var_ordering, - var_ix_for_edges, -): - for i, var in enumerate(var_ordering): - - # Collect all incoming messages in the tangent space at the current belief - taus = [] # message means - lams_tp = [] # message lams - for j, msg in enumerate(ftov_msgs): - if var_ix_for_edges[j] == i: - tau, lam_tp = local_gaussian(msg, var, return_mean=True) - taus.append(tau[None, ...]) - lams_tp.append(lam_tp[None, ...]) - - taus = torch.cat(taus) - lams_tp = torch.cat(lams_tp) - - lam_tau = lams_tp.sum(dim=0) - - # Compute outgoing messages - ix = 0 - for j, msg in enumerate(ftov_msgs): - if var_ix_for_edges[j] == i: - taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) - lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) - - lam_a = lams_inc.sum(dim=0) - if lam_a.count_nonzero() == 0: - vtof_msgs[j].mean[0].data[:] = 0.0 - vtof_msgs[j].precision = lam_a - else: - inv_lam_a = torch.inverse(lam_a) - sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum(dim=0) - tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) - retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) - ix += 1 - - # update belief mean and variance - # if no incoming messages then leave current belief unchanged - if lam_tau.count_nonzero() != 0: - inv_lam_tau = torch.inverse(lam_tau) - sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) - tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) - - belief = Gaussian([var]) - retract_gaussian(tau, lam_tau, var, belief) - - -def pass_fac_to_var_messages_lie( - potentials_eta, - potentials_lam, - lin_points, - vtof_msgs, - ftov_msgs, - adj_var_dofs_nested: List[List], - damping: torch.Tensor, -): - start = 0 - for i in range(len(adj_var_dofs_nested)): - adj_var_dofs = adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - ftov_comp_mess_lie( - potentials_eta[i], - potentials_lam[i], - lin_points[i], - vtof_msgs[start : start + num_optim_vars], - ftov_msgs[start : start + num_optim_vars], - damping[start : start + num_optim_vars], - ) + if do_lin: + J, error = self.cf.weighted_jacobians_error() + J_stk = torch.cat(J, dim=-1) - start += num_optim_vars + lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) + optim_vars_stk = torch.cat([v.data for v in self.cf.optim_vars], dim=-1) + eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) + if lie is False: + eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) + eta = eta.squeeze(-1) -# Compute all outgoing messages from the factor. 
-def ftov_comp_mess( - adj_var_dofs, - potential_eta, - potential_lam, - vtof_msgs_eta, - vtof_msgs_lam, -): - num_optim_vars = len(adj_var_dofs) - messages_eta, messages_lam = [], [] + self.potential_eta = eta + self.potential_lam = lam - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potential_eta.clone()[0] - lam_factor = potential_lam.clone()[0] + for j, var in enumerate(self.cf.optim_vars): + self.lin_point[j].update(var.data) - # Take product of factor with incoming messages - start = 0 - for var in range(num_optim_vars): - var_dofs = adj_var_dofs[var] - if var != v: - eta_mess = vtof_msgs_eta[var] - lam_mess = vtof_msgs_lam[var] - eta_factor[start : start + var_dofs] += eta_mess - lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess - start += var_dofs - - # Divide up parameters of distribution - dofs = adj_var_dofs[v] - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 + # Compute all outgoing messages from the factor. + def comp_mess( + self, + vtof_msgs, + ftov_msgs, + damping, + ): + num_optim_vars = self.cf.num_optim_vars() + new_messages = [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = self.potential_eta.clone()[0] + lam_factor = self.potential_lam.clone()[0] + + # Take product of factor with incoming messages. + # Convert mesages to tangent space at linearisation point. 
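            # In outline (a sketch of the steps below): the incoming messages are
            # added to the factor potential in information form, giving a joint
            # (eta_factor, lam_factor) over all adjacent variables. Partitioning
            # this joint into the target variable "o" and the remaining variables
            # "no",
            #     eta_factor = [eo; eno],   lam_factor = [[loo, lono], [lnoo, lnono]],
            # the outgoing message is the marginal over "o", obtained via the
            # Schur complement:
            #     lam_msg = loo - lono @ inv(lnono) @ lnoo
            #     eta_msg = eo  - lono @ inv(lnono) @ eno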
+ start = 0 + for i in range(num_optim_vars): + var_dofs = self.cf.optim_var_at(i).dof() + if i != v: + eta_mess, lam_mess = local_gaussian( + vtof_msgs[i], self.lin_point[i], return_mean=False + ) + eta_factor[start : start + var_dofs] += eta_mess[0] + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess[0] + start += var_dofs + + # Divide up parameters of distribution + dofs = self.cf.optim_var_at(v).dof() + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) + + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] + lono = np.concatenate( + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], + axis=1, + ) + lnoo = np.concatenate( + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), + axis=1, + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, ), - axis=1, ), - ), - axis=0, - ) - - new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno - - messages_eta.append(new_message_eta[None, :]) - messages_lam.append(new_message_lam[None, :]) - - sdim += dofs + axis=0, + ) - return messages_eta, messages_lam + new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno + # damping in tangent space at linearisation point + # prev_mess_eta, prev_mess_lam = local_gaussian( + # vtof_msgs[v], lin_points[v], return_mean=False) + # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] + # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] -# Compute all outgoing messages from the factor. -def ftov_comp_mess_lie( - potential_eta, - potential_lam, - lin_points, - vtof_msgs, - ftov_msgs, - damping, -): - num_optim_vars = len(lin_points) - new_messages = [] - - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potential_eta.clone()[0] - lam_factor = potential_lam.clone()[0] + if new_mess_lam.count_nonzero() == 0: + new_mess = ManifoldGaussian([self.cf.optim_var_at(v).copy()]) + new_mess.mean[0].data[:] = 0.0 + else: + new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) + new_mess_mean = new_mess_mean[None, ...] + new_mess_lam = new_mess_lam[None, ...] - # Take product of factor with incoming messages. - # Convert mesages to tangent space at linearisation point. 
- start = 0 - for i in range(num_optim_vars): - var_dofs = lin_points[i].dof() - if i != v: - eta_mess, lam_mess = local_gaussian( - vtof_msgs[i], lin_points[i], return_mean=False + new_mess = ManifoldGaussian([self.cf.optim_var_at(v).copy()]) + retract_gaussian( + new_mess_mean, new_mess_lam, self.lin_point[v], new_mess ) - eta_factor[start : start + var_dofs] += eta_mess[0] - lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess[0] - start += var_dofs - - # Divide up parameters of distribution - dofs = lin_points[v].dof() - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 - ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno + new_messages.append(new_mess) - # damping in tangent space at linearisation point - # prev_mess_eta, prev_mess_lam = local_gaussian( - # vtof_msgs[v], lin_points[v], return_mean=False) - # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] - # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] - - if new_mess_lam.count_nonzero() == 0: - new_mess = Gaussian([lin_points[v].copy()]) - new_mess.mean[0].data[:] = 0.0 - else: - new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) - new_mess_mean = new_mess_mean[None, ...] - new_mess_lam = new_mess_lam[None, ...] 
+ sdim += dofs - new_mess = Gaussian([lin_points[v].copy()]) - retract_gaussian(new_mess_mean, new_mess_lam, lin_points[v], new_mess) - new_messages.append(new_mess) - - sdim += dofs + # update messages + for v in range(num_optim_vars): + ftov_msgs[v].update( + mean=new_messages[v].mean, precision=new_messages[v].precision + ) - # update messages - for v in range(num_optim_vars): - ftov_msgs[v].update( - mean=new_messages[v].mean, precision=new_messages[v].precision - ) + return new_messages - return new_messages + @property + def dof(self) -> int: + return self._dof # Follows notation from https://arxiv.org/pdf/2202.03314.pdf - - class GaussianBeliefPropagation(Optimizer, abc.ABC): def __init__( self, objective: Objective, - *args, - linearization_cls: Optional[Type[Linearization]] = None, - linearization_kwargs: Optional[Dict[str, Any]] = None, abs_err_tolerance: float = 1e-10, rel_err_tolerance: float = 1e-8, max_iterations: int = 20, @@ -607,16 +451,13 @@ def __init__( self.ordering = VariableOrdering(objective, default_order=True) self.cf_ordering = CostFunctionOrdering(objective) - self.schedule = None - self.params = GBPOptimizerParams( abs_err_tolerance, rel_err_tolerance, max_iterations ) self.n_edges = sum([cf.num_optim_vars() for cf in self.cf_ordering]) - self.max_dofs = max([var.dof() for var in self.ordering]) - # create arrays for indexing the messages + # create array for indexing the messages var_ixs_nested = [ [self.ordering.index_of(var.name) for var in cf.optim_vars] for cf in self.cf_ordering @@ -624,17 +465,6 @@ def __init__( var_ixs = [item for sublist in var_ixs_nested for item in sublist] self.var_ix_for_edges = torch.tensor(var_ixs).long() - self.adj_var_dofs_nested = [ - [var.shape[1] for var in cf.optim_vars] for cf in self.cf_ordering - ] - - lie_groups = False - for v in self.ordering: - if isinstance(v, th.LieGroup) and not isinstance(v, th.Vector): - lie_groups = True - self.lie_groups = lie_groups - print("lie groups:", self.lie_groups) - """ Copied and slightly modified from nonlinear optimizer class """ @@ -764,45 +594,84 @@ def _merge_infos( ] = -1 """ - GBP specific functions + GBP functions """ - # Linearizes factors at current belief if beliefs have deviated - # from the linearization point by more than the threshold. 
- def _linearize( + def _pass_var_to_fac_messages( self, - potentials_eta, - potentials_lam, - lin_points, - lp_dist_thresh: float = None, - lie=False, + ftov_msgs, + vtof_msgs, + update_belief=True, ): - do_lins = [] - for i, cf in enumerate(self.cf_ordering): + for i, var in enumerate(self.ordering): + + # Collect all incoming messages in the tangent space at the current belief + taus = [] # message means + lams_tp = [] # message lams + for j, msg in enumerate(ftov_msgs): + if self.var_ix_for_edges[j] == i: + tau, lam_tp = local_gaussian(msg, var, return_mean=True) + taus.append(tau[None, ...]) + lams_tp.append(lam_tp[None, ...]) + + taus = torch.cat(taus) + lams_tp = torch.cat(lams_tp) + + lam_tau = lams_tp.sum(dim=0) + + # Compute outgoing messages + ix = 0 + for j, msg in enumerate(ftov_msgs): + if self.var_ix_for_edges[j] == i: + taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) + lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) + + lam_a = lams_inc.sum(dim=0) + if lam_a.count_nonzero() == 0: + vtof_msgs[j].mean[0].data[:] = 0.0 + vtof_msgs[j].precision = lam_a + else: + inv_lam_a = torch.inverse(lam_a) + sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( + dim=0 + ) + tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) + retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) + ix += 1 - do_lin = False - if lp_dist_thresh is None: - do_lin = True - else: - lp_dists = [ - lp.local(cf.optim_var_at(j)).norm() - for j, lp in enumerate(lin_points[i]) - ] - do_lin = np.max(lp_dists) > lp_dist_thresh + # update belief mean and variance + # if no incoming messages then leave current belief unchanged + if update_belief and lam_tau.count_nonzero() != 0: + inv_lam_tau = torch.inverse(lam_tau) + sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) + tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) - do_lins.append(do_lin) + retract_gaussian(tau, lam_tau, var, self.beliefs[i]) - if do_lin: - potential_eta, potential_lam = compute_factor(cf, lie=lie) + def _pass_fac_to_var_messages( + self, + vtof_msgs, + ftov_msgs, + schedule: torch.Tensor, + damping: torch.Tensor, + ): + start = 0 + for factor in self.factors: + num_optim_vars = factor.cf.num_optim_vars() - potentials_eta[i] = potential_eta - potentials_lam[i] = potential_lam + factor.linearize(relin_threshold=None) - for j, var in enumerate(cf.optim_vars): - lin_points[i][j].update(var.data) + factor.comp_mess( + vtof_msgs[start : start + num_optim_vars], + ftov_msgs[start : start + num_optim_vars], + damping[start : start + num_optim_vars], + ) + + start += num_optim_vars - # print(f"Linearised {np.sum(do_lins)} out of {len(do_lins)} factors.") - return potentials_eta, potentials_lam, lin_points + """ + Optimization loop functions + """ # loop for the iterative optimizer def _optimize_loop( @@ -815,116 +684,31 @@ def _optimize_loop( relin_threshold: float = 0.1, damping: float = 0.0, dropout: float = 0.0, - lp_dist_thresh: float = 0.1, + schedule: torch.Tensor = None, **kwargs, ): - # initialise messages with zeros - vtof_msgs_eta = torch.zeros( - self.n_edges, self.max_dofs, dtype=self.objective.dtype - ) - vtof_msgs_lam = torch.zeros( - self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype - ) - ftov_msgs_eta = vtof_msgs_eta.clone() - ftov_msgs_lam = vtof_msgs_lam.clone() - - # compute factor potentials for the first time - potentials_eta = [None] * self.objective.size_cost_functions() - potentials_lam = [None] * self.objective.size_cost_functions() - lin_points = [ - 
[var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering - ] - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None - ) - - converged_indices = torch.zeros_like(info.last_err).bool() - for it_ in range(start_iter, start_iter + num_iter): - - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None + if damping > 1.0 or damping < 0.0: + raise ValueError(f"Damping must be in between 0 and 1. Got {damping}.") + if dropout > 1.0 or dropout < 0.0: + raise ValueError( + f"Dropout probability must be in between 0 and 1. Got {dropout}." ) - - msgs_eta, msgs_lam = pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - self.adj_var_dofs_nested, + if schedule is None: + schedule = random_schedule(self.params.max_iterations, self.n_edges) + elif schedule.dtype != torch.bool: + raise ValueError( + f"Schedule must be of dtype {torch.bool} but has dtype {schedule.dtype}." ) - - # damping - # damping = self.gbp_settings.get_damping(iters_since_relin) - damping_arr = torch.full([len(msgs_eta)], damping) - - # dropout can be implemented through damping - if dropout != 0.0: - dropout_ixs = torch.rand(len(msgs_eta)) < dropout - damping_arr[dropout_ixs] = 1.0 - - ftov_msgs_eta = (1 - damping_arr[:, None]) * msgs_eta + damping_arr[ - :, None - ] * ftov_msgs_eta - ftov_msgs_lam = (1 - damping_arr[:, None, None]) * msgs_lam + damping_arr[ - :, None, None - ] * ftov_msgs_lam - - ( - vtof_msgs_eta, - vtof_msgs_lam, - belief_eta, - belief_lam, - ) = pass_var_to_fac_messages( - ftov_msgs_eta, - ftov_msgs_lam, - self.var_ix_for_edges, - len(self.ordering._var_order), - self.max_dofs, + elif schedule.shape != torch.Size([self.params.max_iterations, self.n_edges]): + raise ValueError( + f"Schedule must have shape [max_iterations, num_edges]. " + f"Should be {torch.Size([self.params.max_iterations, self.n_edges])} " + f"but got {schedule.shape}." ) - # update beliefs - belief_cov = torch.inverse(belief_lam) - belief_mean = torch.matmul(belief_cov, belief_eta.unsqueeze(-1)).squeeze() - for i, var in enumerate(self.ordering): - var.update(data=belief_mean[i][None, :]) - - # check for convergence - with torch.no_grad(): - err = self.objective.error_squared_norm() / 2 - self._update_info(info, it_, err, converged_indices) - if verbose: - print(f"GBP. Iteration: {it_+1}. 
" f"Error: {err.mean().item()}") - converged_indices = self._check_convergence(err, info.last_err) - info.status[ - converged_indices.cpu().numpy() - ] = NonlinearOptimizerStatus.CONVERGED - if converged_indices.all(): - break # nothing else will happen at this point - info.last_err = err - - info.status[ - info.status == NonlinearOptimizerStatus.START - ] = NonlinearOptimizerStatus.MAX_ITERATIONS - return info - - # loop for the iterative optimizer - def _optimize_loop_lie( - self, - start_iter: int, - num_iter: int, - info: NonlinearOptimizerInfo, - verbose: bool, - truncated_grad_loop: bool, - relin_threshold: float = 0.1, - damping: float = 0.0, - dropout: float = 0.0, - lp_dist_thresh: float = 0.1, - **kwargs, - ): # initialise messages with zeros - vtof_msgs = [] - ftov_msgs = [] + vtof_msgs: List[Message] = [] + ftov_msgs: List[Message] = [] for cf in self.cf_ordering: for var in cf.optim_vars: vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") @@ -934,28 +718,19 @@ def _optimize_loop_lie( vtof_msgs.append(Message([vtof_msg_mu])) ftov_msgs.append(Message([ftov_msg_mu])) + # initialise Marginal for belief + self.beliefs: List[Marginal] = [] + for var in self.ordering: + self.beliefs.append(Marginal([var])) + # compute factor potentials for the first time - potentials_eta = [None] * self.objective.size_cost_functions() - potentials_lam = [None] * self.objective.size_cost_functions() - lin_points = [ - [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering - ] - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, lp_dist_thresh=None, lie=True - ) + self.factors: List[Factor] = [] + for cf in self.cf_ordering: + self.factors.append(Factor(cf)) converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, - potentials_lam, - lin_points, - lp_dist_thresh=None, - lie=True, - ) - # damping # damping = self.gbp_settings.get_damping(iters_since_relin) damping_arr = torch.full([self.n_edges], damping) @@ -965,21 +740,17 @@ def _optimize_loop_lie( dropout_ixs = torch.rand(self.n_edges) < dropout damping_arr[dropout_ixs] = 1.0 - pass_fac_to_var_messages_lie( - potentials_eta, - potentials_lam, - lin_points, + self._pass_fac_to_var_messages( vtof_msgs, ftov_msgs, - self.adj_var_dofs_nested, + schedule[it_], damping_arr, ) - pass_var_to_fac_messages_and_update_beliefs_lie( + self._pass_var_to_fac_messages( ftov_msgs, vtof_msgs, - self.ordering, - self.var_ix_for_edges, + update_belief=True, ) # check for convergence @@ -1015,13 +786,9 @@ def _optimize_impl( backward_mode: BackwardMode = BackwardMode.FULL, damping: float = 0.0, dropout: float = 0.0, + schedule: torch.Tensor = None, **kwargs, ) -> NonlinearOptimizerInfo: - if damping > 1.0 or damping < 0.0: - raise NotImplementedError("Damping must be in between 0 and 1.") - if dropout > 1.0 or dropout < 0.0: - raise NotImplementedError("Dropout probability must be in between 0 and 1.") - with torch.no_grad(): info = self._init_info(track_best_solution, track_err_history, verbose) @@ -1035,9 +802,7 @@ def _optimize_impl( grad = True with torch.set_grad_enabled(grad): - - # if self.lie_groups: - info = self._optimize_loop_lie( + info = self._optimize_loop( start_iter=0, num_iter=self.params.max_iterations, info=info, @@ -1045,19 +810,10 @@ def _optimize_impl( truncated_grad_loop=False, damping=damping, 
dropout=dropout, + schedule=schedule, **kwargs, ) - # else: - # info = self._optimize_loop( - # start_iter=0, - # num_iter=self.params.max_iterations, - # info=info, - # verbose=verbose, - # truncated_grad_loop=False, - # damping=damping, - # dropout=dropout, - # **kwargs, - # ) + # If didn't coverge, remove misleading converged_iter value info.converged_iter[ info.status == NonlinearOptimizerStatus.MAX_ITERATIONS diff --git a/theseus/optimizer/gbp/gbp_euclidean.py b/theseus/optimizer/gbp/gbp_euclidean.py new file mode 100644 index 000000000..8505fbb7d --- /dev/null +++ b/theseus/optimizer/gbp/gbp_euclidean.py @@ -0,0 +1,1047 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import abc +import math +from dataclasses import dataclass +from itertools import count +from typing import Dict, List, Optional, Sequence, Tuple + +import numpy as np +import torch + +import theseus as th +import theseus.constants +from theseus.core import CostFunction, Objective +from theseus.geometry import Manifold +from theseus.optimizer import Optimizer, VariableOrdering +from theseus.optimizer.nonlinear.nonlinear_optimizer import ( + BackwardMode, + NonlinearOptimizerInfo, + NonlinearOptimizerStatus, +) + +""" +TODO + - add class for message schedule + - damping for lie algebra vars + - solving inverse problem to compute message mean +""" + + +""" +Utitily functions +""" + + +@dataclass +class GBPOptimizerParams: + abs_err_tolerance: float + rel_err_tolerance: float + max_iterations: int + + def update(self, params_dict): + for param, value in params_dict.items(): + if hasattr(self, param): + setattr(self, param, value) + else: + raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") + + +class ManifoldGaussian: + _ids = count(0) + + def __init__( + self, + mean: Sequence[Manifold], + precision: Optional[torch.Tensor] = None, + name: Optional[str] = None, + ): + self._id = next(ManifoldGaussian._ids) + if name: + self.name = name + else: + self.name = f"{self.__class__.__name__}__{self._id}" + + dof = 0 + for v in mean: + dof += v.dof() + self._dof = dof + + self.mean = mean + self.precision = torch.zeros(mean[0].shape[0], self.dof, self.dof).to( + dtype=mean[0].dtype, device=mean[0].device + ) + + @property + def dof(self) -> int: + return self._dof + + @property + def device(self) -> torch.device: + return self.precision[0].device + + @property + def dtype(self) -> torch.dtype: + return self.precision[0].dtype + + # calls to() on the internal tensors + def to(self, *args, **kwargs): + for var in self.mean: + var = var.to(*args, **kwargs) + self.precision = self.precision.to(*args, **kwargs) + + def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian": + if not new_name: + new_name = f"{self.name}_copy" + mean_copy = [var.copy() for var in self.mean] + return ManifoldGaussian(mean_copy, name=new_name) + + def __deepcopy__(self, memo): + if id(self) in memo: + return memo[id(self)] + the_copy = self.copy() + memo[id(self)] = the_copy + return the_copy + + def update( + self, + mean: Optional[Sequence[Manifold]] = None, + precision: Optional[torch.Tensor] = None, + ): + if mean is not None: + if len(mean) != len(self.mean): + raise ValueError( + f"Tried to update mean with sequence of different" + f"lenght to original mean sequence. Given {len(mean)}. 
" + f"Expected: {len(self.mean)}" + ) + for i in range(len(self.mean)): + self.mean[i].update(mean[i]) + + if precision is not None: + if precision.shape != self.precision.shape: + raise ValueError( + f"Tried to update precision with data " + f"incompatible with original tensor shape. Given {precision.shape}. " + f"Expected: {self.precision.shape}" + ) + if precision.dtype != self.dtype: + raise ValueError( + f"Tried to update using tensor of dtype {precision.dtype} but precision " + f"has dtype {self.dtype}." + ) + + self.precision = precision + + +class Marginal(ManifoldGaussian): + pass + + +class Message(ManifoldGaussian): + pass + + +class CostFunctionOrdering: + def __init__(self, objective: Objective, default_order: bool = True): + self.objective = objective + self._cf_order: List[CostFunction] = [] + self._cf_name_to_index: Dict[str, int] = {} + if default_order: + self._compute_default_order(objective) + + def _compute_default_order(self, objective: Objective): + assert not self._cf_order and not self._cf_name_to_index + cur_idx = 0 + for cf_name, cf in objective.cost_functions.items(): + if cf_name in self._cf_name_to_index: + continue + self._cf_order.append(cf) + self._cf_name_to_index[cf_name] = cur_idx + cur_idx += 1 + + def index_of(self, key: str) -> int: + return self._cf_name_to_index[key] + + def __getitem__(self, index) -> CostFunction: + return self._cf_order[index] + + def __iter__(self): + return iter(self._cf_order) + + def append(self, cf: CostFunction): + if cf in self._cf_order: + raise ValueError( + f"Cost Function {cf.name} has already been added to the order." + ) + if cf.name not in self.objective.cost_functions: + raise ValueError( + f"Cost Function {cf.name} is not a cost function for the objective." + ) + self._cf_order.append(cf) + self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 + + def remove(self, cf: CostFunction): + self._cf_order.remove(cf) + del self._cf_name_to_index[cf.name] + + def extend(self, cfs: Sequence[CostFunction]): + for cf in cfs: + self.append(cf) + + @property + def complete(self): + return len(self._cf_order) == self.objective.size_variables() + + +""" +GBP functions +""" + + +# Compute the factor at current adjacent beliefs. 
+def compute_factor(cf, lie=True): + J, error = cf.weighted_jacobians_error() + J_stk = torch.cat(J, dim=-1) + + lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) + + optim_vars_stk = torch.cat([v.data for v in cf.optim_vars], dim=-1) + eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) + if lie is False: + eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) + eta = eta.squeeze(-1) + + return eta, lam + + +def pass_var_to_fac_messages( + ftov_msgs_eta, + ftov_msgs_lam, + var_ix_for_edges, + n_vars, + max_dofs, +): + belief_eta = torch.zeros(n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) + belief_lam = torch.zeros(n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) + + belief_eta = belief_eta.index_add(0, var_ix_for_edges, ftov_msgs_eta) + belief_lam = belief_lam.index_add(0, var_ix_for_edges, ftov_msgs_lam) + + vtof_msgs_eta = belief_eta[var_ix_for_edges] - ftov_msgs_eta + vtof_msgs_lam = belief_lam[var_ix_for_edges] - ftov_msgs_lam + + return vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam + + +def pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + adj_var_dofs_nested: List[List], +): + ftov_msgs_eta = torch.zeros_like(vtof_msgs_eta) + ftov_msgs_lam = torch.zeros_like(vtof_msgs_lam) + + start = 0 + for i in range(len(adj_var_dofs_nested)): + adj_var_dofs = adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + ftov_eta, ftov_lam = ftov_comp_mess( + adj_var_dofs, + potentials_eta[i], + potentials_lam[i], + vtof_msgs_eta[start : start + num_optim_vars], + vtof_msgs_lam[start : start + num_optim_vars], + ) + + ftov_msgs_eta[start : start + num_optim_vars] = torch.cat(ftov_eta) + ftov_msgs_lam[start : start + num_optim_vars] = torch.cat(ftov_lam) + + start += num_optim_vars + + return ftov_msgs_eta, ftov_msgs_lam + + +# Transforms message gaussian to tangent plane at var +# if return_mean is True, return the (mean, lam) else return (eta, lam). +# Generalises the local function by transforming the covariance as well as mean. +def local_gaussian( + mess: Message, + var: th.LieGroup, + return_mean: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + # mean_tp is message mean in tangent space / plane at var + mean_tp = var.local(mess.mean[0]) + + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) + + # lam_tp is the precision matrix in the tangent plane + lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) + + if return_mean: + return mean_tp, lam_tp + + else: + eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) + return eta_tp, lam_tp + + +# Transforms Gaussian in the tangent plane at var to Gaussian where the mean +# is a group element and the precision matrix is defined in the tangent plane +# at the mean. +# Generalises the retract function by transforming the covariance as well as mean. +# out_gauss is the transformed Gaussian that is updated in place. +def retract_gaussian( + mean_tp: torch.Tensor, + precision_tp: torch.Tensor, + var: th.LieGroup, + out_gauss: ManifoldGaussian, +): + mean = var.retract(mean_tp) + + jac: List[torch.Tensor] = [] + th.exp_map(var, mean_tp, jacobians=jac) + inv_jac = torch.inverse(jac[0]) + precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac) + + out_gauss.update(mean=[mean], precision=precision) + + +# Compute all outgoing messages from the factor. 
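The message computation defined next absorbs the incoming messages from every other adjacent variable into the joint factor Gaussian and then marginalises those variables out, which is the Schur-complement step visible in the loo/lono/lnoo/lnono blocks. A compact sketch of that marginalisation for a joint (eta, lam) split into a kept block and a marginalised block (marginalise is a hypothetical helper, not part of this patch):

import torch

def marginalise(eta, lam, keep):
    # keep: boolean mask over the joint dimensions, True for the receiving variable
    o, no = keep, ~keep
    lam_oo = lam[o][:, o]
    lam_ono = lam[o][:, no]
    lam_noo = lam[no][:, o]
    lam_nono_inv = torch.inverse(lam[no][:, no])
    new_lam = lam_oo - lam_ono @ lam_nono_inv @ lam_noo
    new_eta = eta[o] - lam_ono @ lam_nono_inv @ eta[no]
    return new_eta, new_lam

A = torch.randn(4, 4, dtype=torch.float64)
lam = A @ A.T + torch.eye(4, dtype=torch.float64)   # a valid (positive definite) joint precision
eta = torch.randn(4, dtype=torch.float64)
msg_eta, msg_lam = marginalise(eta, lam, torch.tensor([True, True, False, False]))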
+def ftov_comp_mess( + adj_var_dofs, + potential_eta, + potential_lam, + vtof_msgs_eta, + vtof_msgs_lam, +): + num_optim_vars = len(adj_var_dofs) + messages_eta, messages_lam = [], [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potential_eta.clone()[0] + lam_factor = potential_lam.clone()[0] + + # Take product of factor with incoming messages + start = 0 + for var in range(num_optim_vars): + var_dofs = adj_var_dofs[var] + if var != v: + eta_mess = vtof_msgs_eta[var] + lam_mess = vtof_msgs_lam[var] + eta_factor[start : start + var_dofs] += eta_mess + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess + start += var_dofs + + # Divide up parameters of distribution + dofs = adj_var_dofs[v] + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) + + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] + lono = np.concatenate( + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) + lnoo = np.concatenate( + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) + + new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno + + messages_eta.append(new_message_eta[None, :]) + messages_lam.append(new_message_lam[None, :]) + + sdim += dofs + + return messages_eta, messages_lam + + +# Follows notation from https://arxiv.org/pdf/2202.03314.pdf + + +class GaussianBeliefPropagation(Optimizer, abc.ABC): + def __init__( + self, + objective: Objective, + abs_err_tolerance: float = 1e-10, + rel_err_tolerance: float = 1e-8, + max_iterations: int = 20, + ): + super().__init__(objective) + + # ordering is required to identify which messages to send where + self.ordering = VariableOrdering(objective, default_order=True) + self.cf_ordering = CostFunctionOrdering(objective) + + self.schedule = None + + self.params = GBPOptimizerParams( + abs_err_tolerance, rel_err_tolerance, max_iterations + ) + + self.n_edges = sum([cf.num_optim_vars() for cf in self.cf_ordering]) + self.max_dofs = max([var.dof() for var in self.ordering]) + + # create arrays for indexing the messages + var_ixs_nested = [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] + for cf in self.cf_ordering + ] + var_ixs = [item for sublist in var_ixs_nested for item in sublist] + self.var_ix_for_edges = torch.tensor(var_ixs).long() + + self.adj_var_dofs_nested = [ + [var.shape[1] for var in cf.optim_vars] for cf in self.cf_ordering + ] + + lie_groups = False + for v in self.ordering: + if isinstance(v, th.LieGroup) and not isinstance(v, th.Vector): + lie_groups = True + self.lie_groups = lie_groups + print("lie groups:", self.lie_groups) + + """ + Copied and slightly modified from nonlinear optimizer class + """ + + def set_params(self, **kwargs): + self.params.update(kwargs) + + def _check_convergence(self, err: torch.Tensor, last_err: torch.Tensor): + assert not torch.is_grad_enabled() + if err.abs().mean() < theseus.constants.EPS: + return torch.ones_like(err).bool() + + abs_error = (last_err - err).abs() + rel_error = abs_error / last_err + return (abs_error < self.params.abs_err_tolerance).logical_or( + 
rel_error < self.params.rel_err_tolerance + ) + + def _maybe_init_best_solution( + self, do_init: bool = False + ) -> Optional[Dict[str, torch.Tensor]]: + if not do_init: + return None + solution_dict = {} + for var in self.ordering: + solution_dict[var.name] = var.data.detach().clone().cpu() + return solution_dict + + def _init_info( + self, track_best_solution: bool, track_err_history: bool, verbose: bool + ) -> NonlinearOptimizerInfo: + with torch.no_grad(): + last_err = self.objective.error_squared_norm() / 2 + best_err = last_err.clone() if track_best_solution else None + if track_err_history: + err_history = ( + torch.ones(self.objective.batch_size, self.params.max_iterations + 1) + * math.inf + ) + assert last_err.grad_fn is None + err_history[:, 0] = last_err.clone().cpu() + else: + err_history = None + return NonlinearOptimizerInfo( + best_solution=self._maybe_init_best_solution(do_init=track_best_solution), + last_err=last_err, + best_err=best_err, + status=np.array( + [NonlinearOptimizerStatus.START] * self.objective.batch_size + ), + converged_iter=torch.zeros_like(last_err, dtype=torch.long), + best_iter=torch.zeros_like(last_err, dtype=torch.long), + err_history=err_history, + ) + + def _update_info( + self, + info: NonlinearOptimizerInfo, + current_iter: int, + err: torch.Tensor, + converged_indices: torch.Tensor, + ): + info.converged_iter += 1 - converged_indices.long() + if info.err_history is not None: + assert err.grad_fn is None + info.err_history[:, current_iter + 1] = err.clone().cpu() + + if info.best_solution is not None: + # Only copy best solution if needed (None means track_best_solution=False) + assert info.best_err is not None + good_indices = err < info.best_err + info.best_iter[good_indices] = current_iter + for var in self.ordering: + info.best_solution[var.name][good_indices] = ( + var.data.detach().clone()[good_indices].cpu() + ) + + info.best_err = torch.minimum(info.best_err, err) + + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + np.array(converged_indices.detach().cpu()) + ] = NonlinearOptimizerStatus.CONVERGED + + # Modifies the (no grad) info in place to add data of grad loop info + def _merge_infos( + self, + grad_loop_info: NonlinearOptimizerInfo, + num_no_grad_iter: int, + backward_num_iterations: int, + info: NonlinearOptimizerInfo, + ): + # Concatenate error histories + if info.err_history is not None: + info.err_history[:, num_no_grad_iter:] = grad_loop_info.err_history[ + :, : backward_num_iterations + 1 + ] + # Merge best solution and best error + if info.best_solution is not None: + best_solution = {} + best_err_no_grad = info.best_err + best_err_grad = grad_loop_info.best_err + idx_no_grad = best_err_no_grad < best_err_grad + best_err = torch.minimum(best_err_no_grad, best_err_grad) + for var_name in info.best_solution: + sol_no_grad = info.best_solution[var_name] + sol_grad = grad_loop_info.best_solution[var_name] + best_solution[var_name] = torch.where( + idx_no_grad, sol_no_grad, sol_grad + ) + info.best_solution = best_solution + info.best_err = best_err + + # Merge the converged status into the info from the detached loop, + M = info.status == NonlinearOptimizerStatus.MAX_ITERATIONS + assert np.all( + (grad_loop_info.status[M] == NonlinearOptimizerStatus.MAX_ITERATIONS) + | (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED) + ) + info.status[M] = grad_loop_info.status[M] + info.converged_iter[M] = ( + info.converged_iter[M] + grad_loop_info.converged_iter[M] + ) + # If didn't coverge in 
either loop, remove misleading converged_iter value + info.converged_iter[ + M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS) + ] = -1 + + """ + GBP specific functions + """ + + # Linearizes factors at current belief if beliefs have deviated + # from the linearization point by more than the threshold. + def _linearize( + self, + potentials_eta, + potentials_lam, + lin_points, + relin_threshold: float = None, + lie=False, + ): + do_lins = [] + for i, cf in enumerate(self.cf_ordering): + + do_lin = False + if relin_threshold is None: + do_lin = True + else: + lp_dists = [ + lp.local(cf.optim_var_at(j)).norm() + for j, lp in enumerate(lin_points[i]) + ] + do_lin = np.max(lp_dists) > relin_threshold + + do_lins.append(do_lin) + + if do_lin: + potential_eta, potential_lam = compute_factor(cf, lie=lie) + + potentials_eta[i] = potential_eta + potentials_lam[i] = potential_lam + + for j, var in enumerate(cf.optim_vars): + lin_points[i][j].update(var.data) + + # print(f"Linearised {np.sum(do_lins)} out of {len(do_lins)} factors.") + return potentials_eta, potentials_lam, lin_points + + def _pass_var_to_fac_messages( + self, + ftov_msgs, + vtof_msgs, + update_belief=True, + ): + for i, var in enumerate(self.ordering): + + # Collect all incoming messages in the tangent space at the current belief + taus = [] # message means + lams_tp = [] # message lams + for j, msg in enumerate(ftov_msgs): + if self.var_ix_for_edges[j] == i: + tau, lam_tp = local_gaussian(msg, var, return_mean=True) + taus.append(tau[None, ...]) + lams_tp.append(lam_tp[None, ...]) + + taus = torch.cat(taus) + lams_tp = torch.cat(lams_tp) + + lam_tau = lams_tp.sum(dim=0) + + # Compute outgoing messages + ix = 0 + for j, msg in enumerate(ftov_msgs): + if self.var_ix_for_edges[j] == i: + taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) + lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) + + lam_a = lams_inc.sum(dim=0) + if lam_a.count_nonzero() == 0: + vtof_msgs[j].mean[0].data[:] = 0.0 + vtof_msgs[j].precision = lam_a + else: + inv_lam_a = torch.inverse(lam_a) + sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( + dim=0 + ) + tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) + retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) + ix += 1 + + # update belief mean and variance + # if no incoming messages then leave current belief unchanged + if update_belief and lam_tau.count_nonzero() != 0: + inv_lam_tau = torch.inverse(lam_tau) + sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) + tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) + + retract_gaussian(tau, lam_tau, var, self.beliefs[i]) + + def _pass_fac_to_var_messages( + self, + potentials_eta, + potentials_lam, + lin_points, + vtof_msgs, + ftov_msgs, + damping: torch.Tensor, + ): + start = 0 + for i in range(len(self.adj_var_dofs_nested)): + adj_var_dofs = self.adj_var_dofs_nested[i] + num_optim_vars = len(adj_var_dofs) + + self._ftov_comp_mess( + potentials_eta[i], + potentials_lam[i], + lin_points[i], + vtof_msgs[start : start + num_optim_vars], + ftov_msgs[start : start + num_optim_vars], + damping[start : start + num_optim_vars], + ) + + start += num_optim_vars + + # Compute all outgoing messages from the factor. 
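In _pass_var_to_fac_messages above, a variable's belief is the product of all its incoming factor-to-variable messages, and the message it returns to a factor is that same product with the factor's own contribution left out. In information form both are plain sums, as in this toy sketch with fixed tensors (the actual code additionally maps every message into the tangent plane at the current belief before summing):

import torch

# three incoming factor-to-variable messages for one variable, in information form
etas = torch.tensor([[0.5, 0.0], [0.1, 0.2], [0.0, -0.3]])
lams = torch.stack([torch.eye(2), 2 * torch.eye(2), 0.5 * torch.eye(2)])

belief_eta, belief_lam = etas.sum(dim=0), lams.sum(dim=0)

# message back to the factor behind edge 1: leave out that edge's own contribution
vtof_eta = belief_eta - etas[1]
vtof_lam = belief_lam - lams[1]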
+ def _ftov_comp_mess( + self, + potential_eta, + potential_lam, + lin_points, + vtof_msgs, + ftov_msgs, + damping, + ): + num_optim_vars = len(lin_points) + new_messages = [] + + sdim = 0 + for v in range(num_optim_vars): + eta_factor = potential_eta.clone()[0] + lam_factor = potential_lam.clone()[0] + + # Take product of factor with incoming messages. + # Convert mesages to tangent space at linearisation point. + start = 0 + for i in range(num_optim_vars): + var_dofs = lin_points[i].dof() + if i != v: + eta_mess, lam_mess = local_gaussian( + vtof_msgs[i], lin_points[i], return_mean=False + ) + eta_factor[start : start + var_dofs] += eta_mess[0] + lam_factor[ + start : start + var_dofs, start : start + var_dofs + ] += lam_mess[0] + start += var_dofs + + # Divide up parameters of distribution + dofs = lin_points[v].dof() + eo = eta_factor[sdim : sdim + dofs] + eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) + + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] + lono = np.concatenate( + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + axis=1, + ) + lnoo = np.concatenate( + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], + ), + axis=0, + ) + lnono = np.concatenate( + ( + np.concatenate( + (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), + axis=1, + ), + np.concatenate( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + axis=1, + ), + ), + axis=0, + ) + + new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno + + # damping in tangent space at linearisation point + # prev_mess_eta, prev_mess_lam = local_gaussian( + # vtof_msgs[v], lin_points[v], return_mean=False) + # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] + # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] + + if new_mess_lam.count_nonzero() == 0: + new_mess = ManifoldGaussian([lin_points[v].copy()]) + new_mess.mean[0].data[:] = 0.0 + else: + new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) + new_mess_mean = new_mess_mean[None, ...] + new_mess_lam = new_mess_lam[None, ...] 
+ + new_mess = ManifoldGaussian([lin_points[v].copy()]) + retract_gaussian(new_mess_mean, new_mess_lam, lin_points[v], new_mess) + new_messages.append(new_mess) + + sdim += dofs + + # update messages + for v in range(num_optim_vars): + ftov_msgs[v].update( + mean=new_messages[v].mean, precision=new_messages[v].precision + ) + + return new_messages + + """ + Optimization loop functions + """ + + # loop for the iterative optimizer + def _optimize_loop( + self, + start_iter: int, + num_iter: int, + info: NonlinearOptimizerInfo, + verbose: bool, + truncated_grad_loop: bool, + relin_threshold: float = 0.1, + damping: float = 0.0, + dropout: float = 0.0, + **kwargs, + ): + # initialise messages with zeros + vtof_msgs_eta = torch.zeros( + self.n_edges, self.max_dofs, dtype=self.objective.dtype + ) + vtof_msgs_lam = torch.zeros( + self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype + ) + ftov_msgs_eta = vtof_msgs_eta.clone() + ftov_msgs_lam = vtof_msgs_lam.clone() + + # compute factor potentials for the first time + potentials_eta = [None] * self.objective.size_cost_functions() + potentials_lam = [None] * self.objective.size_cost_functions() + lin_points = [ + [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] + for cf in self.cf_ordering + ] + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, relin_threshold=None + ) + + converged_indices = torch.zeros_like(info.last_err).bool() + for it_ in range(start_iter, start_iter + num_iter): + + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, relin_threshold=None + ) + + msgs_eta, msgs_lam = pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + vtof_msgs_eta, + vtof_msgs_lam, + self.adj_var_dofs_nested, + ) + + # damping + # damping = self.gbp_settings.get_damping(iters_since_relin) + damping_arr = torch.full([len(msgs_eta)], damping) + + # dropout can be implemented through damping + if dropout != 0.0: + dropout_ixs = torch.rand(len(msgs_eta)) < dropout + damping_arr[dropout_ixs] = 1.0 + + ftov_msgs_eta = (1 - damping_arr[:, None]) * msgs_eta + damping_arr[ + :, None + ] * ftov_msgs_eta + ftov_msgs_lam = (1 - damping_arr[:, None, None]) * msgs_lam + damping_arr[ + :, None, None + ] * ftov_msgs_lam + + ( + vtof_msgs_eta, + vtof_msgs_lam, + belief_eta, + belief_lam, + ) = pass_var_to_fac_messages( + ftov_msgs_eta, + ftov_msgs_lam, + self.var_ix_for_edges, + len(self.ordering._var_order), + self.max_dofs, + ) + + # update beliefs + belief_cov = torch.inverse(belief_lam) + belief_mean = torch.matmul(belief_cov, belief_eta.unsqueeze(-1)).squeeze() + for i, var in enumerate(self.ordering): + var.update(data=belief_mean[i][None, :]) + + # check for convergence + with torch.no_grad(): + err = self.objective.error_squared_norm() / 2 + self._update_info(info, it_, err, converged_indices) + if verbose: + print(f"GBP. Iteration: {it_+1}. 
" f"Error: {err.mean().item()}") + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + converged_indices.cpu().numpy() + ] = NonlinearOptimizerStatus.CONVERGED + if converged_indices.all(): + break # nothing else will happen at this point + info.last_err = err + + info.status[ + info.status == NonlinearOptimizerStatus.START + ] = NonlinearOptimizerStatus.MAX_ITERATIONS + return info + + # loop for the iterative optimizer + def _optimize_loop_lie( + self, + start_iter: int, + num_iter: int, + info: NonlinearOptimizerInfo, + verbose: bool, + truncated_grad_loop: bool, + relin_threshold: float = 0.1, + damping: float = 0.0, + dropout: float = 0.0, + **kwargs, + ): + # initialise messages with zeros + vtof_msgs: List[Message] = [] + ftov_msgs: List[Message] = [] + for cf in self.cf_ordering: + for var in cf.optim_vars: + vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") + # mean of initial message doesn't matter as long as precision is zero + vtof_msg_mu.data[:] = 0 + ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") + vtof_msgs.append(Message([vtof_msg_mu])) + ftov_msgs.append(Message([ftov_msg_mu])) + + # initialise gaussian for belief + self.beliefs: List[Marginal] = [] + for var in self.ordering: + self.beliefs.append(Marginal([var])) + + # compute factor potentials for the first time + potentials_eta = [None] * self.objective.size_cost_functions() + potentials_lam = [None] * self.objective.size_cost_functions() + lin_points = [ + [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] + for cf in self.cf_ordering + ] + potentials_eta, potentials_lam, lin_points = self._linearize( + potentials_eta, potentials_lam, lin_points, relin_threshold=None, lie=True + ) + + converged_indices = torch.zeros_like(info.last_err).bool() + for it_ in range(start_iter, start_iter + num_iter): + + potentials_eta, potentials_lam, self.lin_points = self._linearize( + potentials_eta, + potentials_lam, + lin_points, + relin_threshold=None, + lie=True, + ) + + # damping + # damping = self.gbp_settings.get_damping(iters_since_relin) + damping_arr = torch.full([self.n_edges], damping) + + # dropout can be implemented through damping + if dropout != 0.0: + dropout_ixs = torch.rand(self.n_edges) < dropout + damping_arr[dropout_ixs] = 1.0 + + self._pass_fac_to_var_messages( + potentials_eta, + potentials_lam, + lin_points, + vtof_msgs, + ftov_msgs, + damping_arr, + ) + + self._pass_var_to_fac_messages( + ftov_msgs, + vtof_msgs, + update_belief=True, + ) + + # check for convergence + if it_ > 0: + with torch.no_grad(): + err = self.objective.error_squared_norm() / 2 + self._update_info(info, it_, err, converged_indices) + if verbose: + print( + f"GBP. Iteration: {it_+1}. 
" f"Error: {err.mean().item()}" + ) + converged_indices = self._check_convergence(err, info.last_err) + info.status[ + converged_indices.cpu().numpy() + ] = NonlinearOptimizerStatus.CONVERGED + if converged_indices.all() and it_ > 1: + break # nothing else will happen at this point + info.last_err = err + + info.status[ + info.status == NonlinearOptimizerStatus.START + ] = NonlinearOptimizerStatus.MAX_ITERATIONS + return info + + # `track_best_solution` keeps a **detached** copy (as in no gradient info) + # of the best variables found, but it is optional to avoid unnecessary copying + # if this is not needed + def _optimize_impl( + self, + track_best_solution: bool = False, + track_err_history: bool = False, + verbose: bool = False, + backward_mode: BackwardMode = BackwardMode.FULL, + damping: float = 0.0, + dropout: float = 0.0, + **kwargs, + ) -> NonlinearOptimizerInfo: + if damping > 1.0 or damping < 0.0: + raise NotImplementedError("Damping must be in between 0 and 1.") + if dropout > 1.0 or dropout < 0.0: + raise NotImplementedError("Dropout probability must be in between 0 and 1.") + + with torch.no_grad(): + info = self._init_info(track_best_solution, track_err_history, verbose) + + if verbose: + print( + f"GBP optimizer. Iteration: 0. " f"Error: {info.last_err.mean().item()}" + ) + + grad = False + if backward_mode == BackwardMode.FULL: + grad = True + + with torch.set_grad_enabled(grad): + + # if self.lie_groups: + info = self._optimize_loop_lie( + start_iter=0, + num_iter=self.params.max_iterations, + info=info, + verbose=verbose, + truncated_grad_loop=False, + damping=damping, + dropout=dropout, + **kwargs, + ) + # else: + # info = self._optimize_loop( + # start_iter=0, + # num_iter=self.params.max_iterations, + # info=info, + # verbose=verbose, + # truncated_grad_loop=False, + # damping=damping, + # dropout=dropout, + # **kwargs, + # ) + # If didn't coverge, remove misleading converged_iter value + info.converged_iter[ + info.status == NonlinearOptimizerStatus.MAX_ITERATIONS + ] = -1 + return info diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index a5dd550e4..dc1825ea4 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -7,7 +7,7 @@ import torch import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation +from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule # This example illustrates the Gaussian Belief Propagation (GBP) optimizer # for a 2D pose graph optimization problem. 
@@ -125,12 +125,14 @@ # print("inputs", inputs) +max_iterations = 100 optimizer = GaussianBeliefPropagation( objective, - max_iterations=100, + max_iterations=max_iterations, ) theseus_optim = th.TheseusLayer(optimizer) + optim_arg = { "track_best_solution": True, "track_err_history": True, @@ -138,6 +140,7 @@ "backward_mode": th.BackwardMode.FULL, "damping": 0.6, "dropout": 0.0, + "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), } updated_inputs, info = theseus_optim.forward(inputs, optim_arg) From 264ed22640b9a2a09201c88ce8507e73d6506433 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 13 Apr 2022 16:55:52 +0100 Subject: [PATCH 08/64] added mean damping in lin point space --- theseus/optimizer/gbp/gbp.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 5c79349be..903167cd1 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -25,9 +25,9 @@ """ TODO - - add class for message schedule - damping for lie algebra vars - solving inverse problem to compute message mean + - handle batch dim """ @@ -401,11 +401,19 @@ def comp_mess( new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno - # damping in tangent space at linearisation point - # prev_mess_eta, prev_mess_lam = local_gaussian( - # vtof_msgs[v], lin_points[v], return_mean=False) - # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] - # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] + # damping in tangent space at linearisation point as message + # is already in this tangent space. Could equally do damping + # in the tangent space of the new or old message mean. + prev_mess_mean, prev_mess_lam = local_gaussian( + ftov_msgs[v], self.lin_point[v], return_mean=True + ) + # mean damping + if new_mess_lam.count_nonzero() != 0: + new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) + new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ + v + ] * prev_mess_mean[0] + new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) if new_mess_lam.count_nonzero() == 0: new_mess = ManifoldGaussian([self.cf.optim_var_at(v).copy()]) From 8092a031c212922069fc2c034ae27c1da21cbadc Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 20 Apr 2022 16:55:03 +0100 Subject: [PATCH 09/64] fix in linearise and add ba exmple --- theseus/optimizer/gbp/ba_test.py | 192 ++++++++++++++++++++++++++++++ theseus/optimizer/gbp/gbp.py | 4 +- theseus/optimizer/gbp/pgo_test.py | 26 +--- 3 files changed, 194 insertions(+), 28 deletions(-) create mode 100644 theseus/optimizer/gbp/ba_test.py diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py new file mode 100644 index 000000000..da1731457 --- /dev/null +++ b/theseus/optimizer/gbp/ba_test.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import random +from typing import Dict, List + +import numpy as np +import omegaconf +import torch + +import theseus as th +import theseus.utils.examples as theg + +# from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule + +# Smaller values} result in error +th.SO3.SO3_EPS = 1e-6 + + +def print_histogram( + ba: theg.BundleAdjustmentDataset, var_dict: Dict[str, torch.Tensor], msg: str +): + print(msg) + histogram = theg.ba_histogram( + cameras=[ + theg.Camera( + th.SE3(data=var_dict[c.pose.name]), + c.focal_length, + c.calib_k1, + c.calib_k2, + ) + for c in ba.cameras + ], + points=[th.Point3(data=var_dict[pt.name]) for pt in ba.points], + observations=ba.observations, + ) + for line in histogram.split("\n"): + print(line) + + +def camera_loss( + ba: theg.BundleAdjustmentDataset, camera_pose_vars: List[th.LieGroup] +) -> torch.Tensor: + loss: torch.Tensor = 0 # type:ignore + for i in range(len(ba.cameras)): + camera_loss = th.local(camera_pose_vars[i], ba.gt_cameras[i].pose).norm(dim=1) + loss += camera_loss + return loss + + +def run(cfg: omegaconf.OmegaConf): + # create (or load) dataset + ba = theg.BundleAdjustmentDataset.generate_synthetic( + num_cameras=cfg["num_cameras"], + num_points=cfg["num_points"], + average_track_length=cfg["average_track_length"], + track_locality=cfg["track_locality"], + feat_random=1.5, + outlier_feat_random=70, + ) + # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") + + # param that control transition from squared loss to huber + radius_tensor = torch.tensor([1.0], dtype=torch.float64) + log_loss_radius = th.Vector(data=radius_tensor, name="log_loss_radius") + + # Set up objective + objective = th.Objective(dtype=torch.float64) + + for obs in ba.observations: + cam = ba.cameras[obs.camera_index] + cost_function = theg.Reprojection( + camera_pose=cam.pose, + world_point=ba.points[obs.point_index], + focal_length=cam.focal_length, + calib_k1=cam.calib_k1, + calib_k2=cam.calib_k2, + log_loss_radius=log_loss_radius, + image_feature_point=obs.image_feature_point, + ) + objective.add(cost_function) + dtype = objective.dtype + + # Add regularization + if cfg["inner_optim"]["regularize"]: + zero_point3 = th.Point3(dtype=dtype, name="zero_point") + identity_se3 = th.SE3(dtype=dtype, name="zero_se3") + w = np.sqrt(cfg["inner_optim"]["reg_w"]) + damping_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) + for name, var in objective.optim_vars.items(): + target: th.Manifold + if isinstance(var, th.SE3): + target = identity_se3 + elif isinstance(var, th.Point3): + target = zero_point3 + else: + assert False + objective.add( + th.eb.VariableDifference( + var, damping_weight, target, name=f"reg_{name}" + ) + ) + + camera_pose_vars: List[th.LieGroup] = [ + objective.optim_vars[c.pose.name] for c in ba.cameras # type: ignore + ] + if cfg["inner_optim"]["ratio_known_cameras"] > 0.0: + w = 100.0 + camera_weight = th.ScaleCostWeight(100 * torch.ones(1, dtype=dtype)) + for i in range(len(ba.cameras)): + if np.random.rand() > cfg["inner_optim"]["ratio_known_cameras"]: + continue + objective.add( + th.eb.VariableDifference( + camera_pose_vars[i], + camera_weight, + ba.gt_cameras[i].pose, + name=f"camera_diff_{i}", + ) + ) + + # Create optimizer and theseus layer + optimizer = th.GaussNewton( + objective, + max_iterations=cfg["inner_optim"]["max_iters"], + step_size=0.1, + ) + # optimizer = GaussianBeliefPropagation( + # objective, + # max_iterations=cfg["inner_optim"]["max_iters"], + # ) + theseus_optim = 
th.TheseusLayer(optimizer) + + optim_arg = { + "track_best_solution": True, + "track_err_history": True, + "verbose": True, + "backward_mode": th.BackwardMode.FULL, + # "damping": 0.6, + # "dropout": 0.0, + # "schedule": synchronous_schedule(cfg["inner_optim"]["max_iters"], optimizer.n_edges), + } + + theseus_inputs = {} + for cam in ba.cameras: + theseus_inputs[cam.pose.name] = cam.pose.data.clone() + for pt in ba.points: + theseus_inputs[pt.name] = pt.data.clone() + theseus_inputs["log_loss_radius"] = log_loss_radius.data.clone() + + with torch.no_grad(): + camera_loss_ref = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS: {camera_loss_ref: .3f}") + # print_histogram(ba, theseus_inputs, "Input histogram:") + + objective.update(theseus_inputs) + print("squred err:", objective.error_squared_norm().item()) + + theseus_outputs, info = theseus_optim.forward( + input_data=theseus_inputs, + optimizer_kwargs=optim_arg, + ) + + loss = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") + + +if __name__ == "__main__": + + cfg = { + "seed": 1, + "num_cameras": 10, + "num_points": 200, + "average_track_length": 8, + "track_locality": 0.2, + "inner_optim": { + "max_iters": 10, + "verbose": True, + "track_err_history": True, + "keep_step_size": True, + "regularize": True, + "ratio_known_cameras": 0.1, + "reg_w": 1e-4, + }, + } + + torch.manual_seed(cfg["seed"]) + np.random.seed(cfg["seed"]) + random.seed(cfg["seed"]) + + run(cfg) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 903167cd1..9e3d3a0a0 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -25,7 +25,6 @@ """ TODO - - damping for lie algebra vars - solving inverse problem to compute message mean - handle batch dim """ @@ -318,10 +317,9 @@ def linearize( J_stk = torch.cat(J, dim=-1) lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) - - optim_vars_stk = torch.cat([v.data for v in self.cf.optim_vars], dim=-1) eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) if lie is False: + optim_vars_stk = torch.cat([v.data for v in self.cf.optim_vars], dim=-1) eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index dc1825ea4..2498e5802 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -110,21 +110,13 @@ objective.add(cf_meas) m += 1 -# # objective.update(init_dict) +# objective.update(init_dict) # print("Initial cost:", objective.error_squared_norm()) -# fg.print(brief=True) - -# # for vis --------------- - # joint = fg.get_joint() # marg_covs = np.diag(joint.cov())[::2] # map_soln = fg.MAP().reshape([size * size, 2]) -# Solve with Gauss Newton --------------- - -# print("inputs", inputs) - max_iterations = 100 optimizer = GaussianBeliefPropagation( objective, @@ -144,21 +136,5 @@ } updated_inputs, info = theseus_optim.forward(inputs, optim_arg) -print("updated_inputs", updated_inputs) -print("info", info) - - -# optimizer = th.GaussNewton( -# objective, -# max_iterations=15, -# step_size=0.5, -# ) -# theseus_optim = th.TheseusLayer(optimizer) - -# with torch.no_grad(): -# optim_args = {"track_best_solution": True, "verbose": True} -# updated_inputs, info = theseus_optim.forward(inputs, optim_args) # print("updated_inputs", updated_inputs) # print("info", info) - -# import ipdb; ipdb.set_trace() From 410034225a7e74b0f01680f1034eb22ab0049910 Mon Sep 17 
00:00:00 2001 From: joeaortiz Date: Fri, 22 Apr 2022 12:00:03 +0100 Subject: [PATCH 10/64] use th.Manifold gaussian --- theseus/optimizer/gbp/gbp.py | 180 +++++------------------------- theseus/optimizer/gbp/pgo_test.py | 2 +- 2 files changed, 26 insertions(+), 156 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 9e3d3a0a0..0a6954b51 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -7,7 +7,7 @@ import math from dataclasses import dataclass from itertools import count -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence import numpy as np import torch @@ -61,151 +61,21 @@ def random_schedule(max_iters, n_edges) -> torch.Tensor: return schedule -class ManifoldGaussian: - _ids = count(0) - +# Initialises message precision to zero +class Message(th.ManifoldGaussian): def __init__( self, mean: Sequence[Manifold], precision: Optional[torch.Tensor] = None, name: Optional[str] = None, ): - self._id = next(ManifoldGaussian._ids) - if name: - self.name = name - else: - self.name = f"{self.__class__.__name__}__{self._id}" - - dof = 0 - for v in mean: - dof += v.dof() - self._dof = dof - - self.mean = mean - self.precision = torch.zeros(mean[0].shape[0], self.dof, self.dof).to( - dtype=mean[0].dtype, device=mean[0].device - ) - - @property - def dof(self) -> int: - return self._dof - - @property - def device(self) -> torch.device: - return self.precision[0].device - - @property - def dtype(self) -> torch.dtype: - return self.precision[0].dtype - - # calls to() on the internal tensors - def to(self, *args, **kwargs): - for var in self.mean: - var = var.to(*args, **kwargs) - self.precision = self.precision.to(*args, **kwargs) - - def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian": - if not new_name: - new_name = f"{self.name}_copy" - mean_copy = [var.copy() for var in self.mean] - return ManifoldGaussian(mean_copy, name=new_name) - - def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - the_copy = self.copy() - memo[id(self)] = the_copy - return the_copy - - def update( - self, - mean: Optional[Sequence[Manifold]] = None, - precision: Optional[torch.Tensor] = None, - ): - if mean is not None: - if len(mean) != len(self.mean): - raise ValueError( - f"Tried to update mean with sequence of different" - f"lenght to original mean sequence. Given {len(mean)}. " - f"Expected: {len(self.mean)}" - ) - for i in range(len(self.mean)): - self.mean[i].update(mean[i]) - - if precision is not None: - if precision.shape != self.precision.shape: - raise ValueError( - f"Tried to update precision with data " - f"incompatible with original tensor shape. Given {precision.shape}. " - f"Expected: {self.precision.shape}" - ) - if precision.dtype != self.dtype: - raise ValueError( - f"Tried to update using tensor of dtype {precision.dtype} but precision " - f"has dtype {self.dtype}." - ) - - self.precision = precision - - -class Marginal(ManifoldGaussian): - pass - - -class Message(ManifoldGaussian): - pass - - -""" -Local and retract -These could be implemented as methods in Manifold class -""" - - -# Transforms message gaussian to tangent plane at var -# if return_mean is True, return the (mean, lam) else return (eta, lam). -# Generalises the local function by transforming the covariance as well as mean. 
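The helpers removed below now live in theseus proper as th.local_gaussian and th.retract_gaussian; they carry a Gaussian between the group and the tangent plane at a reference variable, pushing the precision through the exp-map Jacobian in each direction. A rough round-trip usage sketch, assuming the th.SE3, th.ManifoldGaussian, local_gaussian and retract_gaussian calls behave as they are used elsewhere in this patch:

import torch
import theseus as th

dtype = torch.float64
var = th.SE3(dtype=dtype)                                      # reference point (identity, batch size 1)
mess_mean = var.retract(0.1 * torch.ones(1, 6, dtype=dtype))   # a nearby group element for the message mean
mess = th.ManifoldGaussian([mess_mean], precision=torch.eye(6, dtype=dtype)[None])

# express the message in the tangent plane at var ...
mean_tp, lam_tp = th.local_gaussian(var, mess, return_mean=True)
# ... and map the tangent-plane Gaussian back onto the manifold
roundtrip = th.retract_gaussian(var, mean_tp, lam_tp)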
-def local_gaussian( - mess: Message, - var: th.LieGroup, - return_mean: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - # mean_tp is message mean in tangent space / plane at var - mean_tp = var.local(mess.mean[0]) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - - # lam_tp is the precision matrix in the tangent plane - lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) - - if return_mean: - return mean_tp, lam_tp - - else: - eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) - return eta_tp, lam_tp - - -# Transforms Gaussian in the tangent plane at var to Gaussian where the mean -# is a group element and the precision matrix is defined in the tangent plane -# at the mean. -# Generalises the retract function by transforming the covariance as well as mean. -# out_gauss is the transformed Gaussian that is updated in place. -def retract_gaussian( - mean_tp: torch.Tensor, - precision_tp: torch.Tensor, - var: th.LieGroup, - out_gauss: ManifoldGaussian, -): - mean = var.retract(mean_tp) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - inv_jac = torch.inverse(jac[0]) - precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac) - - out_gauss.update(mean=[mean], precision=precision) + if precision is None: + dof = sum([v.dof() for v in mean]) + precision = torch.zeros(mean[0].shape[0], dof, dof).to( + dtype=mean[0].dtype, device=mean[0].device + ) + super(Message, self).__init__(mean, precision=precision, name=name) + assert dof == self.dof class CostFunctionOrdering: @@ -350,8 +220,8 @@ def comp_mess( for i in range(num_optim_vars): var_dofs = self.cf.optim_var_at(i).dof() if i != v: - eta_mess, lam_mess = local_gaussian( - vtof_msgs[i], self.lin_point[i], return_mean=False + eta_mess, lam_mess = th.local_gaussian( + self.lin_point[i], vtof_msgs[i], return_mean=False ) eta_factor[start : start + var_dofs] += eta_mess[0] lam_factor[ @@ -402,8 +272,8 @@ def comp_mess( # damping in tangent space at linearisation point as message # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. - prev_mess_mean, prev_mess_lam = local_gaussian( - ftov_msgs[v], self.lin_point[v], return_mean=True + prev_mess_mean, prev_mess_lam = th.local_gaussian( + self.lin_point[v], ftov_msgs[v], return_mean=True ) # mean damping if new_mess_lam.count_nonzero() != 0: @@ -414,16 +284,14 @@ def comp_mess( new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) if new_mess_lam.count_nonzero() == 0: - new_mess = ManifoldGaussian([self.cf.optim_var_at(v).copy()]) + new_mess = Message([self.cf.optim_var_at(v).copy()]) new_mess.mean[0].data[:] = 0.0 else: new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) new_mess_mean = new_mess_mean[None, ...] new_mess_lam = new_mess_lam[None, ...] 
- - new_mess = ManifoldGaussian([self.cf.optim_var_at(v).copy()]) - retract_gaussian( - new_mess_mean, new_mess_lam, self.lin_point[v], new_mess + new_mess = th.retract_gaussian( + self.lin_point[v], new_mess_mean, new_mess_lam ) new_messages.append(new_mess) @@ -616,7 +484,7 @@ def _pass_var_to_fac_messages( lams_tp = [] # message lams for j, msg in enumerate(ftov_msgs): if self.var_ix_for_edges[j] == i: - tau, lam_tp = local_gaussian(msg, var, return_mean=True) + tau, lam_tp = th.local_gaussian(var, msg, return_mean=True) taus.append(tau[None, ...]) lams_tp.append(lam_tp[None, ...]) @@ -642,7 +510,8 @@ def _pass_var_to_fac_messages( dim=0 ) tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) - retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) + new_mess = th.retract_gaussian(var, tau_a, lam_a) + vtof_msgs[j].update(new_mess.mean, new_mess.precision) ix += 1 # update belief mean and variance @@ -652,7 +521,8 @@ def _pass_var_to_fac_messages( sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) - retract_gaussian(tau, lam_tau, var, self.beliefs[i]) + new_belief = th.retract_gaussian(var, tau, lam_tau) + self.beliefs[i].update(new_belief.mean, new_belief.precision) def _pass_fac_to_var_messages( self, @@ -724,10 +594,10 @@ def _optimize_loop( vtof_msgs.append(Message([vtof_msg_mu])) ftov_msgs.append(Message([ftov_msg_mu])) - # initialise Marginal for belief - self.beliefs: List[Marginal] = [] + # initialise ManifoldGaussian for belief + self.beliefs: List[th.ManifoldGaussian] = [] for var in self.ordering: - self.beliefs.append(Marginal([var])) + self.beliefs.append(th.ManifoldGaussian([var])) # compute factor potentials for the first time self.factors: List[Factor] = [] diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index 2498e5802..78591ab03 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -130,7 +130,7 @@ "track_err_history": True, "verbose": True, "backward_mode": th.BackwardMode.FULL, - "damping": 0.6, + "damping": 0.0, "dropout": 0.0, "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), } From b26e1a3a605a366cd0d0f1a74e2ab8535cf83fef Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 25 Apr 2022 14:08:07 +0100 Subject: [PATCH 11/64] ba tests, fixing numerical issues --- theseus/optimizer/gbp/ba_test.py | 30 ++++---- theseus/optimizer/gbp/gbp.py | 123 +++++++++++++++++++++---------- 2 files changed, 101 insertions(+), 52 deletions(-) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index da1731457..e429cfb56 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -11,8 +11,7 @@ import theseus as th import theseus.utils.examples as theg - -# from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule +from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule # Smaller values} result in error th.SO3.SO3_EPS = 1e-6 @@ -121,15 +120,15 @@ def run(cfg: omegaconf.OmegaConf): ) # Create optimizer and theseus layer - optimizer = th.GaussNewton( - objective, - max_iterations=cfg["inner_optim"]["max_iters"], - step_size=0.1, - ) - # optimizer = GaussianBeliefPropagation( + # optimizer = th.GaussNewton( # objective, # max_iterations=cfg["inner_optim"]["max_iters"], + # step_size=0.1, # ) + optimizer = GaussianBeliefPropagation( + objective, + max_iterations=cfg["inner_optim"]["max_iters"], + ) theseus_optim = 
th.TheseusLayer(optimizer) optim_arg = { @@ -137,9 +136,12 @@ def run(cfg: omegaconf.OmegaConf): "track_err_history": True, "verbose": True, "backward_mode": th.BackwardMode.FULL, - # "damping": 0.6, - # "dropout": 0.0, - # "schedule": synchronous_schedule(cfg["inner_optim"]["max_iters"], optimizer.n_edges), + "relin_threshold": 0.1, + "damping": 0.9, + "dropout": 0.0, + "schedule": synchronous_schedule( + cfg["inner_optim"]["max_iters"], optimizer.n_edges + ), } theseus_inputs = {} @@ -170,8 +172,8 @@ def run(cfg: omegaconf.OmegaConf): cfg = { "seed": 1, - "num_cameras": 10, - "num_points": 200, + "num_cameras": 2, # 10 + "num_points": 20, # 200 "average_track_length": 8, "track_locality": 0.2, "inner_optim": { @@ -181,7 +183,7 @@ def run(cfg: omegaconf.OmegaConf): "keep_step_size": True, "regularize": True, "ratio_known_cameras": 0.1, - "reg_w": 1e-4, + "reg_w": 1e-3, }, } diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 0a6954b51..e06f93740 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -77,6 +77,18 @@ def __init__( super(Message, self).__init__(mean, precision=precision, name=name) assert dof == self.dof + # sets mean to the group identity and zero precision matrix + def zero_message(self): + new_mean = [] + for var in self.mean: + new_mean_i = var.__class__() + new_mean_i.to(dtype=self.dtype, device=self.device) + new_mean.append(new_mean_i) + new_precision = torch.zeros(self.mean[0].shape[0], self.dof, self.dof).to( + dtype=self.dtype, device=self.device + ) + self.update(mean=new_mean, precision=new_precision) + class CostFunctionOrdering: def __init__(self, objective: Objective, default_order: bool = True): @@ -163,6 +175,7 @@ def __init__( var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars ] + self.steps_since_lin = 0 self.linearize() # Linearizes factors at current belief if beliefs have deviated @@ -172,6 +185,7 @@ def linearize( relin_threshold: float = None, lie=True, ): + self.steps_since_lin += 1 do_lin = False if relin_threshold is None: do_lin = True @@ -199,6 +213,8 @@ def linearize( for j, var in enumerate(self.cf.optim_vars): self.lin_point[j].update(var.data) + self.steps_since_lin = 0 + # Compute all outgoing messages from the factor. 
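The comp_mess hunk that follows applies damping to the message mean rather than to (eta, lam) directly: the previous factor-to-variable message is pulled into the tangent space at the linearisation point, its mean is blended with the newly computed mean, and the information vector is rebuilt from the blend while the precision is left untouched. A toy sketch of that blend with plain tensors (both precisions assumed non-zero, matching the guard in the patch):

import torch

def damp_message_mean(new_eta, new_lam, prev_mean, damping):
    new_mean = torch.matmul(torch.inverse(new_lam), new_eta)
    damped_mean = (1 - damping) * new_mean + damping * prev_mean
    return torch.matmul(new_lam, damped_mean)   # damped eta; new_lam is kept as-is

new_lam = torch.eye(3, dtype=torch.float64)
new_eta = torch.tensor([1.0, 0.0, 0.0], dtype=torch.float64)
prev_mean = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float64)
damped_eta = damp_message_mean(new_eta, new_lam, prev_mean, damping=0.5)   # [0.5, 0.5, 0.0]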
def comp_mess( self, @@ -223,71 +239,94 @@ def comp_mess( eta_mess, lam_mess = th.local_gaussian( self.lin_point[i], vtof_msgs[i], return_mean=False ) + eta_factor[start : start + var_dofs] += eta_mess[0] lam_factor[ start : start + var_dofs, start : start + var_dofs ] += lam_mess[0] + + # if self.name == "Factor__0": + # print('from adj variable') + # print(eta_mess) + # print(lam_mess) + start += var_dofs # Divide up parameters of distribution dofs = self.cf.optim_var_at(v).dof() eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) + eno = torch.cat((eta_factor[:sdim], eta_factor[sdim + dofs :])) loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( + lono = torch.cat( ( lam_factor[sdim : sdim + dofs, :sdim], lam_factor[sdim : sdim + dofs, sdim + dofs :], ), - axis=1, + dim=1, ) - lnoo = np.concatenate( + lnoo = torch.cat( ( lam_factor[:sdim, sdim : sdim + dofs], lam_factor[sdim + dofs :, sdim : sdim + dofs], ), - axis=0, + dim=0, ) - lnono = np.concatenate( + lnono = torch.cat( ( - np.concatenate( + torch.cat( (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - axis=1, + dim=1, ), - np.concatenate( + torch.cat( ( lam_factor[sdim + dofs :, :sdim], lam_factor[sdim + dofs :, sdim + dofs :], ), - axis=1, + dim=1, ), ), - axis=0, + dim=0, ) - new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno + # print('det', lnono.det()) + new_mess_lam = loo - lono @ torch.linalg.inv(lnono) @ lnoo + new_mess_eta = eo - lono @ torch.linalg.inv(lnono) @ eno # damping in tangent space at linearisation point as message # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. - prev_mess_mean, prev_mess_lam = th.local_gaussian( - self.lin_point[v], ftov_msgs[v], return_mean=True - ) # mean damping - if new_mess_lam.count_nonzero() != 0: - new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) - new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ - v - ] * prev_mess_mean[0] - new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) + if damping[v] != 0: + if ( + new_mess_lam.count_nonzero() != 0 + and ftov_msgs[v].precision.count_nonzero() != 0 + ): + prev_mess_mean, prev_mess_lam = th.local_gaussian( + self.lin_point[v], ftov_msgs[v], return_mean=True + ) + + new_mess_mean = torch.matmul( + torch.inverse(new_mess_lam), new_mess_eta + ) + new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ + v + ] * prev_mess_mean[0] + new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) if new_mess_lam.count_nonzero() == 0: + # print(self.cf.__class__, 'not updating new message as lam is all zeros') new_mess = Message([self.cf.optim_var_at(v).copy()]) - new_mess.mean[0].data[:] = 0.0 + new_mess.zero_message() + elif not torch.allclose(new_mess_lam, new_mess_lam.transpose(0, 1)): + # print(self.cf.__class__, 'not updating new message as lam is not symmetric') + new_mess = Message([self.cf.optim_var_at(v).copy()]) + new_mess.zero_message() else: - new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) + # print(self.cf.__class__, 'sending message') + new_mess_mean = torch.matmul( + torch.linalg.pinv(new_mess_lam), new_mess_eta + ) new_mess_mean = new_mess_mean[None, ...] new_mess_lam = new_mess_lam[None, ...] 
new_mess = th.retract_gaussian( @@ -295,6 +334,9 @@ def comp_mess( ) new_messages.append(new_mess) + # if self.name == "Factor__0": + # import ipdb; ipdb.set_trace() + sdim += dofs # update messages @@ -502,10 +544,9 @@ def _pass_var_to_fac_messages( lam_a = lams_inc.sum(dim=0) if lam_a.count_nonzero() == 0: - vtof_msgs[j].mean[0].data[:] = 0.0 - vtof_msgs[j].precision = lam_a + vtof_msgs[j].zero_message() else: - inv_lam_a = torch.inverse(lam_a) + inv_lam_a = torch.linalg.pinv(lam_a) sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( dim=0 ) @@ -530,12 +571,13 @@ def _pass_fac_to_var_messages( ftov_msgs, schedule: torch.Tensor, damping: torch.Tensor, + relin_threshold: float, ): start = 0 for factor in self.factors: num_optim_vars = factor.cf.num_optim_vars() - factor.linearize(relin_threshold=None) + # factor.linearize(relin_threshold=relin_threshold) factor.comp_mess( vtof_msgs[start : start + num_optim_vars], @@ -557,10 +599,10 @@ def _optimize_loop( info: NonlinearOptimizerInfo, verbose: bool, truncated_grad_loop: bool, - relin_threshold: float = 0.1, - damping: float = 0.0, - dropout: float = 0.0, - schedule: torch.Tensor = None, + relin_threshold: float, + damping: float, + dropout: float, + schedule: torch.Tensor, **kwargs, ): if damping > 1.0 or damping < 0.0: @@ -587,12 +629,14 @@ def _optimize_loop( ftov_msgs: List[Message] = [] for cf in self.cf_ordering: for var in cf.optim_vars: - vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") - # mean of initial message doesn't matter as long as precision is zero - vtof_msg_mu.data[:] = 0 - ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") - vtof_msgs.append(Message([vtof_msg_mu])) - ftov_msgs.append(Message([ftov_msg_mu])) + # Set mean of initial message to identity of the group + # doesn't matter what it is as long as precision is zero + vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") + ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg.zero_message() + ftov_msg.zero_message() + vtof_msgs.append(vtof_msg) + ftov_msgs.append(ftov_msg) # initialise ManifoldGaussian for belief self.beliefs: List[th.ManifoldGaussian] = [] @@ -621,6 +665,7 @@ def _optimize_loop( ftov_msgs, schedule[it_], damping_arr, + relin_threshold, ) self._pass_var_to_fac_messages( @@ -660,6 +705,7 @@ def _optimize_impl( track_err_history: bool = False, verbose: bool = False, backward_mode: BackwardMode = BackwardMode.FULL, + relin_threshold: float = 0.1, damping: float = 0.0, dropout: float = 0.0, schedule: torch.Tensor = None, @@ -684,6 +730,7 @@ def _optimize_impl( info=info, verbose=verbose, truncated_grad_loop=False, + relin_threshold=relin_threshold, damping=damping, dropout=dropout, schedule=schedule, From b4c387aff7cd8b39aab9e2181a411b495259fc43 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Thu, 5 May 2022 14:43:49 +0100 Subject: [PATCH 12/64] ba viewer --- theseus/optimizer/gbp/ba_viewer.py | 137 +++++++++++++++++++++++++++++ 1 file changed, 137 insertions(+) create mode 100644 theseus/optimizer/gbp/ba_viewer.py diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py new file mode 100644 index 000000000..7943a5710 --- /dev/null +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -0,0 +1,137 @@ +import threading + +import pyglet +import torch +import trimesh +import trimesh.viewer + +import theseus as th + + +def draw_camera( + transform, fov, resolution, color=(0.0, 1.0, 0.0, 0.8), marker_height=12.0 +): + camera = 
trimesh.scene.Camera(fov=fov, resolution=resolution) + marker = trimesh.creation.camera_marker(camera, marker_height=marker_height) + marker[0].apply_transform(transform) + marker[1].apply_transform(transform) + marker[1].colors = (color,) * len(marker[1].entities) + + return marker + + +class BAViewer(trimesh.viewer.SceneViewer): + def __init__(self, belief_history): + self._it = 0 + self.belief_history = belief_history + self.lock = threading.Lock() + + scene = trimesh.Scene() + self.scene = scene + self.next_iteration() + scene.set_camera() + super(BAViewer, self).__init__(scene=scene, resolution=(1080, 720)) + + def on_key_press(self, symbol, modifiers): + """ + Call appropriate functions given key presses. + """ + magnitude = 10 + if symbol == pyglet.window.key.W: + self.toggle_wireframe() + elif symbol == pyglet.window.key.Z: + self.reset_view() + elif symbol == pyglet.window.key.C: + self.toggle_culling() + elif symbol == pyglet.window.key.A: + self.toggle_axis() + elif symbol == pyglet.window.key.G: + self.toggle_grid() + elif symbol == pyglet.window.key.Q: + self.on_close() + elif symbol == pyglet.window.key.M: + self.maximize() + elif symbol == pyglet.window.key.F: + self.toggle_fullscreen() + elif symbol == pyglet.window.key.P: + print(self.scene.camera_transform) + elif symbol == pyglet.window.key.N: + if self._it + 1 in self.belief_history: + self._it += 1 + print("Iteration", self._it) + self.next_iteration() + else: + print("No more iterations to view") + + if symbol in [ + pyglet.window.key.LEFT, + pyglet.window.key.RIGHT, + pyglet.window.key.DOWN, + pyglet.window.key.UP, + ]: + self.view["ball"].down([0, 0]) + if symbol == pyglet.window.key.LEFT: + self.view["ball"].drag([-magnitude, 0]) + elif symbol == pyglet.window.key.RIGHT: + self.view["ball"].drag([magnitude, 0]) + elif symbol == pyglet.window.key.DOWN: + self.view["ball"].drag([0, -magnitude]) + elif symbol == pyglet.window.key.UP: + self.view["ball"].drag([0, magnitude]) + self.scene.camera_transform[...] 
= self.view["ball"].pose + + def next_iteration(self): + with self.lock: + points = [] + n_cams, n_pts = 0, 0 + for belief in self.belief_history[self._it]: + if isinstance(belief.mean[0], th.SE3): + T = torch.vstack( + ( + belief.mean[0].data[0], + torch.tensor( + [[0.0, 0.0, 0.0, 1.0]], dtype=belief.mean[0].dtype + ), + ) + ) + camera = draw_camera( + T, self.scene.camera.fov, self.scene.camera.resolution + ) + self.scene.delete_geometry(f"cam_{n_cams}") + self.scene.add_geometry(camera, geom_name=f"cam_{n_cams}") + n_cams += 1 + elif isinstance(belief.mean[0], th.Point3): + point = belief.mean[0].data + points.append(point) + + cov = torch.linalg.inv(belief.precision[0]) + ellipse = make_ellipse(point[0], cov) + self.scene.delete_geometry(f"ellipse_{n_pts}") + self.scene.add_geometry(ellipse, geom_name=f"ellipse_{n_pts}") + + points = torch.cat(points) + points_tm = trimesh.PointCloud(points) + self.scene.delete_geometry("points") + self.scene.add_geometry(points_tm, geom_name="points") + if self._it != 0: + self._update_vertex_list() + + +def make_ellipse(mean, cov): + eigvals, eigvecs = torch.linalg.eigh(cov) + + # rescale eigvals into range that fits in scene + print(eigvals) + eigvals = eigvals / 10 + eigvals = torch.maximum(torch.tensor(0.7), eigvals) + eigvals = torch.minimum(torch.tensor(60.0), eigvals) + + rotation = torch.eye(4) + rotation[:3, :3] = eigvecs + + ellipse = trimesh.creation.icosphere() + ellipse.apply_scale(eigvals.numpy()) + ellipse.apply_transform(rotation) + ellipse.apply_translation(mean) + + return ellipse From b75cfca145a094fcf32b3eab655ff9e6c7b99696 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 1 Jun 2022 15:42:56 +0100 Subject: [PATCH 13/64] bundle adjustment trimesh vis --- theseus/optimizer/gbp/__init__.py | 1 + theseus/optimizer/gbp/ba_viewer.py | 69 +++++++++++++++++++++++------- 2 files changed, 55 insertions(+), 15 deletions(-) diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py index 5a20612f2..5d308bde2 100644 --- a/theseus/optimizer/gbp/__init__.py +++ b/theseus/optimizer/gbp/__init__.py @@ -3,4 +3,5 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .ba_viewer import BAViewer from .gbp import GaussianBeliefPropagation, random_schedule, synchronous_schedule diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py index 7943a5710..d4b8e8622 100644 --- a/theseus/optimizer/gbp/ba_viewer.py +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -1,5 +1,6 @@ import threading +import numpy as np import pyglet import torch import trimesh @@ -21,9 +22,14 @@ def draw_camera( class BAViewer(trimesh.viewer.SceneViewer): - def __init__(self, belief_history): + def __init__( + self, belief_history, msg_history=None, cam_to_world=False, flip_z=True + ): self._it = 0 self.belief_history = belief_history + self.msg_history = msg_history + self.cam_to_world = cam_to_world + self.flip_z = flip_z self.lock = threading.Lock() scene = trimesh.Scene() @@ -94,6 +100,10 @@ def next_iteration(self): ), ) ) + if not self.cam_to_world: + T = np.linalg.inv(T) + if self.flip_z: + T[:3, 2] *= -1.0 camera = draw_camera( T, self.scene.camera.fov, self.scene.camera.resolution ) @@ -106,6 +116,8 @@ def next_iteration(self): cov = torch.linalg.inv(belief.precision[0]) ellipse = make_ellipse(point[0], cov) + ellipse.visual.vertex_colors[:] = [255, 0, 0, 100] + self.scene.delete_geometry(f"ellipse_{n_pts}") self.scene.add_geometry(ellipse, geom_name=f"ellipse_{n_pts}") @@ -113,25 +125,52 @@ def next_iteration(self): points_tm = trimesh.PointCloud(points) self.scene.delete_geometry("points") self.scene.add_geometry(points_tm, geom_name="points") + + if self.msg_history: + for msg in self.msg_history[self._it]: + if msg.precision.count_nonzero() != 0: + if msg.mean[0].dof() == 3 and "Reprojection" in msg.name: + ellipse = make_ellipse( + msg.mean[0][0], torch.linalg.inv(msg.precision[0]) + ) + if f"ellipse_{msg.name}" in self.scene.geometry: + self.scene.delete_geometry(f"ellipse_{msg.name}") + self.scene.add_geometry( + ellipse, geom_name=f"ellipse_{msg.name}" + ) + if self._it != 0: self._update_vertex_list() -def make_ellipse(mean, cov): - eigvals, eigvecs = torch.linalg.eigh(cov) - - # rescale eigvals into range that fits in scene - print(eigvals) +def make_ellipse(mean, cov, do_lines=False): + # eigvals_torch, eigvecs_torch = torch.linalg.eigh(cov) + eigvals, eigvecs = np.linalg.eigh(cov) # eigenvecs are columns + # print("eigvals", eigvals) # , eigvals_torch.numpy()) eigvals = eigvals / 10 - eigvals = torch.maximum(torch.tensor(0.7), eigvals) - eigvals = torch.minimum(torch.tensor(60.0), eigvals) + signs = np.sign(eigvals) + eigvals = np.clip(np.abs(eigvals), 1.0, 100, eigvals) * signs + + if do_lines: + points = [] + for i, eigvalue in enumerate(eigvals): + disp = eigvalue * eigvecs[:, i] + points.extend([mean + disp, mean - disp]) + + paths = torch.cat(points).reshape(3, 2, 3) + lines = trimesh.load_path(paths) + + return lines - rotation = torch.eye(4) - rotation[:3, :3] = eigvecs + else: + rotation = np.eye(4) + rotation[:3, :3] = eigvecs - ellipse = trimesh.creation.icosphere() - ellipse.apply_scale(eigvals.numpy()) - ellipse.apply_transform(rotation) - ellipse.apply_translation(mean) + ellipse = trimesh.creation.icosphere() + ellipse.apply_scale(eigvals) + ellipse.apply_transform(rotation) + ellipse.apply_translation(mean) + ellipse.visual.vertex_colors = trimesh.visual.random_color() + ellipse.visual.vertex_colors[:, 3] = 100 - return ellipse + return ellipse From 40575f4e1ddda18e5b1e40c12efc0ad62cbce4c2 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Wed, 1 Jun 2022 15:44:22 +0100 Subject: [PATCH 14/64] soft huber-like loss on norm of x,y 
error --- .../bundle_adjustment/reprojection_error.py | 12 ++++------ .../utils/examples/bundle_adjustment/util.py | 23 +++++++++++++++++++ 2 files changed, 27 insertions(+), 8 deletions(-) diff --git a/theseus/utils/examples/bundle_adjustment/reprojection_error.py b/theseus/utils/examples/bundle_adjustment/reprojection_error.py index 5f32092ae..cf0983a92 100644 --- a/theseus/utils/examples/bundle_adjustment/reprojection_error.py +++ b/theseus/utils/examples/bundle_adjustment/reprojection_error.py @@ -9,7 +9,7 @@ import theseus as th -from .util import soft_loss_huber_like +from .util import soft_loss_huber_like_reprojection class Reprojection(th.CostFunction): @@ -66,10 +66,9 @@ def error(self) -> torch.Tensor: err = point_projection - self.image_feature_point.data - err_norm = torch.norm(err, dim=1).unsqueeze(1) loss_radius = torch.exp(self.log_loss_radius.data) - val, _ = soft_loss_huber_like(err_norm, loss_radius) + val, _ = soft_loss_huber_like_reprojection(err, loss_radius) return val def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: @@ -100,13 +99,10 @@ def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: ) + proj_sqn_jac * d_proj_factor.unsqueeze(2) err = point_projection - self.image_feature_point.data - err_norm = torch.norm(err, dim=1).unsqueeze(1) - err_dir = err / err_norm - norm_jac = torch.bmm(err_dir.unsqueeze(1), point_projection_jac) loss_radius = torch.exp(self.log_loss_radius.data) + val, der = soft_loss_huber_like_reprojection(err, loss_radius) - val, der = soft_loss_huber_like(err_norm, loss_radius) - soft_jac = norm_jac * der.unsqueeze(1) + soft_jac = torch.bmm(der, point_projection_jac) return [soft_jac[:, :, :6], soft_jac[:, :, 6:]], val diff --git a/theseus/utils/examples/bundle_adjustment/util.py b/theseus/utils/examples/bundle_adjustment/util.py index bd30d7447..e670e91a6 100644 --- a/theseus/utils/examples/bundle_adjustment/util.py +++ b/theseus/utils/examples/bundle_adjustment/util.py @@ -31,6 +31,29 @@ def soft_loss_huber_like( return val, der +# For reprojection cost functions where the loss is 2 dimensional, +# x and y pixel loss, but the robust loss region is determined +# by the norm of the (x, y) pixel loss. 
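# Writing r = ||x|| for the pixel residual x and s(r) = soft_loss_huber_like(r) / r,
# the robustified residual computed below is f(x) = s(r) * x. A short chain-rule
# sketch of its 2x2 Jacobian, which the function assembles as term1 + term2:
#
#   df/dx = s(r) * I + x x^T * s'(r) / r
#         = s(r) * I + x x^T * (dval/dr - s(r)) / r^2
#
# where dval/dr is the derivative returned by soft_loss_huber_like.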
+def soft_loss_huber_like_reprojection( + x: torch.Tensor, radius: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + x_norm = torch.norm(x, dim=1).unsqueeze(1) + val, der = soft_loss_huber_like(x_norm, radius) + scaling = val / x_norm + x_loss = x * scaling + + term1 = scaling[..., None] * torch.eye(2, dtype=x.dtype, device=x.device).reshape( + 1, 2, 2 + ).repeat(x.shape[0], 1, 1) + term2 = ( + torch.bmm(x.unsqueeze(2), x.unsqueeze(1)) + * ((der - scaling) / (x_norm**2))[..., None] + ) + der = term1 + term2 + + return x_loss, der + + # ------------------------------------------------------------ # # ----------------------------- RNG -------------------------- # # ------------------------------------------------------------ # From a91754ec8b35190c6133d4a634673eeb39b6084e Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 6 Jun 2022 12:51:21 +0100 Subject: [PATCH 15/64] remove prints and fix viewer --- theseus/optimizer/gbp/ba_test.py | 52 +++++++--- theseus/optimizer/gbp/ba_viewer.py | 19 ++-- theseus/optimizer/gbp/gbp.py | 161 +++++++++++++++-------------- 3 files changed, 135 insertions(+), 97 deletions(-) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index e429cfb56..1e2e795bb 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -11,9 +11,13 @@ import theseus as th import theseus.utils.examples as theg -from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule +from theseus.optimizer.gbp import ( + BAViewer, + GaussianBeliefPropagation, + synchronous_schedule, +) -# Smaller values} result in error +# Smaller values result in error th.SO3.SO3_EPS = 1e-6 @@ -48,6 +52,13 @@ def camera_loss( return loss +def average_repojection_error(objective) -> float: + + are = 0.0 + + return are + + def run(cfg: omegaconf.OmegaConf): # create (or load) dataset ba = theg.BundleAdjustmentDataset.generate_synthetic( @@ -55,9 +66,16 @@ def run(cfg: omegaconf.OmegaConf): num_points=cfg["num_points"], average_track_length=cfg["average_track_length"], track_locality=cfg["track_locality"], - feat_random=1.5, + feat_random=0.0, + prob_feat_is_outlier=0.0, outlier_feat_random=70, ) + + # cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( + # "/home/joe/Downloads/riku/fr3stf.txt") + # ba = theg.BundleAdjustmentDataset(cams, points, obs) + # import ipdb; ipdb.set_trace() + # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") # param that control transition from squared loss to huber @@ -84,14 +102,16 @@ def run(cfg: omegaconf.OmegaConf): # Add regularization if cfg["inner_optim"]["regularize"]: zero_point3 = th.Point3(dtype=dtype, name="zero_point") - identity_se3 = th.SE3(dtype=dtype, name="zero_se3") + # identity_se3 = th.SE3(dtype=dtype, name="zero_se3") w = np.sqrt(cfg["inner_optim"]["reg_w"]) damping_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) for name, var in objective.optim_vars.items(): target: th.Manifold if isinstance(var, th.SE3): - target = identity_se3 + target = var.copy(new_name="target_" + var.name) + # target = identity_se3 elif isinstance(var, th.Point3): + # target = var.copy(new_name="target_" + var.name) target = zero_point3 else: assert False @@ -105,8 +125,8 @@ def run(cfg: omegaconf.OmegaConf): objective.optim_vars[c.pose.name] for c in ba.cameras # type: ignore ] if cfg["inner_optim"]["ratio_known_cameras"] > 0.0: - w = 100.0 - camera_weight = th.ScaleCostWeight(100 * torch.ones(1, dtype=dtype)) + w = 1000.0 + camera_weight = 
th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) for i in range(len(ba.cameras)): if np.random.rand() > cfg["inner_optim"]["ratio_known_cameras"]: continue @@ -119,6 +139,8 @@ def run(cfg: omegaconf.OmegaConf): ) ) + print("Factors:\n", objective.cost_functions.keys(), "\n") + # Create optimizer and theseus layer # optimizer = th.GaussNewton( # objective, @@ -136,8 +158,8 @@ def run(cfg: omegaconf.OmegaConf): "track_err_history": True, "verbose": True, "backward_mode": th.BackwardMode.FULL, - "relin_threshold": 0.1, - "damping": 0.9, + "relin_threshold": 0.0001, + "damping": 0.0, "dropout": 0.0, "schedule": synchronous_schedule( cfg["inner_optim"]["max_iters"], optimizer.n_edges @@ -167,23 +189,25 @@ def run(cfg: omegaconf.OmegaConf): loss = camera_loss(ba, camera_pose_vars).item() print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") + BAViewer(optimizer.belief_history, msg_history=optimizer.ftov_msgs_history) + if __name__ == "__main__": cfg = { "seed": 1, - "num_cameras": 2, # 10 - "num_points": 20, # 200 + "num_cameras": 4, + "num_points": 10, "average_track_length": 8, "track_locality": 0.2, "inner_optim": { - "max_iters": 10, + "max_iters": 50, "verbose": True, "track_err_history": True, "keep_step_size": True, "regularize": True, - "ratio_known_cameras": 0.1, - "reg_w": 1e-3, + "ratio_known_cameras": 1.0, + "reg_w": 1e-4, }, } diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py index d4b8e8622..0429ed72c 100644 --- a/theseus/optimizer/gbp/ba_viewer.py +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -108,18 +108,19 @@ def next_iteration(self): T, self.scene.camera.fov, self.scene.camera.resolution ) self.scene.delete_geometry(f"cam_{n_cams}") - self.scene.add_geometry(camera, geom_name=f"cam_{n_cams}") + self.scene.add_geometry(camera[1], geom_name=f"cam_{n_cams}") n_cams += 1 elif isinstance(belief.mean[0], th.Point3): point = belief.mean[0].data points.append(point) - cov = torch.linalg.inv(belief.precision[0]) - ellipse = make_ellipse(point[0], cov) - ellipse.visual.vertex_colors[:] = [255, 0, 0, 100] + # cov = torch.linalg.inv(belief.precision[0]) + # ellipse = make_ellipse(point[0], cov) + # ellipse.visual.vertex_colors[:] = [255, 0, 0, 100] - self.scene.delete_geometry(f"ellipse_{n_pts}") - self.scene.add_geometry(ellipse, geom_name=f"ellipse_{n_pts}") + # self.scene.delete_geometry(f"ellipse_{n_pts}") + # self.scene.add_geometry(ellipse, geom_name=f"ellipse_{n_pts}") + n_pts += 1 points = torch.cat(points) points_tm = trimesh.PointCloud(points) @@ -143,7 +144,7 @@ def next_iteration(self): self._update_vertex_list() -def make_ellipse(mean, cov, do_lines=False): +def make_ellipse(mean, cov, do_lines=False, color=None): # eigvals_torch, eigvecs_torch = torch.linalg.eigh(cov) eigvals, eigvecs = np.linalg.eigh(cov) # eigenvecs are columns # print("eigvals", eigvals) # , eigvals_torch.numpy()) @@ -170,7 +171,9 @@ def make_ellipse(mean, cov, do_lines=False): ellipse.apply_scale(eigvals) ellipse.apply_transform(rotation) ellipse.apply_translation(mean) - ellipse.visual.vertex_colors = trimesh.visual.random_color() + if color is None: + color = trimesh.visual.random_color() + ellipse.visual.vertex_colors = color ellipse.visual.vertex_colors[:, 3] = 100 return ellipse diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index e06f93740..ccf8ff885 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -198,6 +198,7 @@ def linearize( if do_lin: J, error = self.cf.weighted_jacobians_error() + 
J_stk = torch.cat(J, dim=-1) lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) @@ -229,6 +230,7 @@ def comp_mess( for v in range(num_optim_vars): eta_factor = self.potential_eta.clone()[0] lam_factor = self.potential_lam.clone()[0] + lam_factor_copy = lam_factor.clone() # Take product of factor with incoming messages. # Convert mesages to tangent space at linearisation point. @@ -239,91 +241,85 @@ def comp_mess( eta_mess, lam_mess = th.local_gaussian( self.lin_point[i], vtof_msgs[i], return_mean=False ) - eta_factor[start : start + var_dofs] += eta_mess[0] lam_factor[ start : start + var_dofs, start : start + var_dofs ] += lam_mess[0] - # if self.name == "Factor__0": - # print('from adj variable') - # print(eta_mess) - # print(lam_mess) - start += var_dofs - # Divide up parameters of distribution dofs = self.cf.optim_var_at(v).dof() - eo = eta_factor[sdim : sdim + dofs] - eno = torch.cat((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = torch.cat( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - dim=1, - ) - lnoo = torch.cat( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - dim=0, - ) - lnono = torch.cat( - ( - torch.cat( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - dim=1, + + if torch.allclose(lam_factor, lam_factor_copy) and num_optim_vars > 1: + # print(self.cf.name, '---> not updating as incoming message lams are zeros') + new_mess = Message([self.cf.optim_var_at(v).copy()]) + new_mess.zero_message() + + else: + # print(self.cf.name, '---> sending message') + # Divide up parameters of distribution + eo = eta_factor[sdim : sdim + dofs] + eno = torch.cat((eta_factor[:sdim], eta_factor[sdim + dofs :])) + + loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] + lono = torch.cat( + ( + lam_factor[sdim : sdim + dofs, :sdim], + lam_factor[sdim : sdim + dofs, sdim + dofs :], + ), + dim=1, + ) + lnoo = torch.cat( + ( + lam_factor[:sdim, sdim : sdim + dofs], + lam_factor[sdim + dofs :, sdim : sdim + dofs], ), - torch.cat( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], + dim=0, + ) + lnono = torch.cat( + ( + torch.cat( + ( + lam_factor[:sdim, :sdim], + lam_factor[:sdim, sdim + dofs :], + ), + dim=1, + ), + torch.cat( + ( + lam_factor[sdim + dofs :, :sdim], + lam_factor[sdim + dofs :, sdim + dofs :], + ), + dim=1, ), - dim=1, ), - ), - dim=0, - ) + dim=0, + ) - # print('det', lnono.det()) - new_mess_lam = loo - lono @ torch.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - lono @ torch.linalg.inv(lnono) @ eno - - # damping in tangent space at linearisation point as message - # is already in this tangent space. Could equally do damping - # in the tangent space of the new or old message mean. - # mean damping - if damping[v] != 0: - if ( - new_mess_lam.count_nonzero() != 0 - and ftov_msgs[v].precision.count_nonzero() != 0 - ): - prev_mess_mean, prev_mess_lam = th.local_gaussian( - self.lin_point[v], ftov_msgs[v], return_mean=True - ) + new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_mess_eta = eo - lono @ torch.linalg.inv(lnono) @ eno + + # damping in tangent space at linearisation point as message + # is already in this tangent space. Could equally do damping + # in the tangent space of the new or old message mean. 
+ # mean damping + if damping[v] != 0: # and steps_since_lin > 0: + if ( + new_mess_lam.count_nonzero() != 0 + and ftov_msgs[v].precision.count_nonzero() != 0 + ): + prev_mess_mean, prev_mess_lam = th.local_gaussian( + self.lin_point[v], ftov_msgs[v], return_mean=True + ) - new_mess_mean = torch.matmul( - torch.inverse(new_mess_lam), new_mess_eta - ) - new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ - v - ] * prev_mess_mean[0] - new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) + new_mess_mean = torch.matmul( + torch.inverse(new_mess_lam), new_mess_eta + ) + new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ + v + ] * prev_mess_mean[0] + new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) - if new_mess_lam.count_nonzero() == 0: - # print(self.cf.__class__, 'not updating new message as lam is all zeros') - new_mess = Message([self.cf.optim_var_at(v).copy()]) - new_mess.zero_message() - elif not torch.allclose(new_mess_lam, new_mess_lam.transpose(0, 1)): - # print(self.cf.__class__, 'not updating new message as lam is not symmetric') - new_mess = Message([self.cf.optim_var_at(v).copy()]) - new_mess.zero_message() - else: - # print(self.cf.__class__, 'sending message') new_mess_mean = torch.matmul( torch.linalg.pinv(new_mess_lam), new_mess_eta ) @@ -332,10 +328,8 @@ def comp_mess( new_mess = th.retract_gaussian( self.lin_point[v], new_mess_mean, new_mess_lam ) - new_messages.append(new_mess) - # if self.name == "Factor__0": - # import ipdb; ipdb.set_trace() + new_messages.append(new_mess) sdim += dofs @@ -526,6 +520,7 @@ def _pass_var_to_fac_messages( lams_tp = [] # message lams for j, msg in enumerate(ftov_msgs): if self.var_ix_for_edges[j] == i: + # print(msg.mean, msg.precision) tau, lam_tp = th.local_gaussian(var, msg, return_mean=True) taus.append(tau[None, ...]) lams_tp.append(lam_tp[None, ...]) @@ -573,11 +568,18 @@ def _pass_fac_to_var_messages( damping: torch.Tensor, relin_threshold: float, ): + relins = 0 + did_relin = [] start = 0 for factor in self.factors: num_optim_vars = factor.cf.num_optim_vars() - # factor.linearize(relin_threshold=relin_threshold) + factor.linearize(relin_threshold=relin_threshold) + if factor.steps_since_lin == 0: + relins += 1 + did_relin += [1] + else: + did_relin += [0] factor.comp_mess( vtof_msgs[start : start + num_optim_vars], @@ -587,6 +589,9 @@ def _pass_fac_to_var_messages( start += num_optim_vars + # print(f"Factor relinearisations: {relins} / {len(self.factors)}") + # print(did_relin) + """ Optimization loop functions """ @@ -648,9 +653,15 @@ def _optimize_loop( for cf in self.cf_ordering: self.factors.append(Factor(cf)) + self.belief_history = {} + self.ftov_msgs_history = {} + converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): + self.ftov_msgs_history[it_] = [msg.copy() for msg in ftov_msgs] + self.belief_history[it_] = [belief.copy() for belief in self.beliefs] + # damping # damping = self.gbp_settings.get_damping(iters_since_relin) damping_arr = torch.full([self.n_edges], damping) From 32380fbfeb8a0ab4679ee8c173c3ad135114d8ad Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 6 Jun 2022 12:51:40 +0100 Subject: [PATCH 16/64] remove symmetric check --- theseus/optimizer/manifold_gaussian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/theseus/optimizer/manifold_gaussian.py b/theseus/optimizer/manifold_gaussian.py index 8cf1d980c..2dafc0501 100644 --- a/theseus/optimizer/manifold_gaussian.py +++ 
b/theseus/optimizer/manifold_gaussian.py @@ -101,8 +101,8 @@ def update( f"Tried to update using tensor on device: {precision.dtype} but precision " f"is on device: {self.device}." ) - if not torch.allclose(precision, precision.transpose(1, 2)): - raise ValueError("Tried to update precision with non-symmetric matrix.") + # if not torch.allclose(precision, precision.transpose(1, 2)): + # raise ValueError("Tried to update precision with non-symmetric matrix.") self.precision = precision From 2aa58a119019e16c6398b274ec0908c0ee90853d Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Mon, 6 Jun 2022 21:52:03 +0100 Subject: [PATCH 17/64] ba visualisation and derivates setup for pgo --- theseus/optimizer/gbp/ba_test.py | 37 ++++++++++++++------ theseus/optimizer/gbp/ba_viewer.py | 54 +++++++++++++++++++++--------- theseus/optimizer/gbp/gbp.py | 15 +++++---- theseus/optimizer/gbp/pgo_test.py | 38 +++++++++++++++------ 4 files changed, 102 insertions(+), 42 deletions(-) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 1e2e795bb..1a2fc4fe8 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -54,8 +54,13 @@ def camera_loss( def average_repojection_error(objective) -> float: - are = 0.0 + reproj_norms = [] + for k in objective.cost_functions.keys(): + if "Reprojection" in k: + err = objective.cost_functions[k].error().norm(dim=1) + reproj_norms.append(err) + are = torch.tensor(reproj_norms).mean().item() return are @@ -72,10 +77,8 @@ def run(cfg: omegaconf.OmegaConf): ) # cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( - # "/home/joe/Downloads/riku/fr3stf.txt") + # "/media/joe/3.0TB Hard Disk/bal_data/problem-21-11315-pre.txt") # ba = theg.BundleAdjustmentDataset(cams, points, obs) - # import ipdb; ipdb.set_trace() - # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") # param that control transition from squared loss to huber @@ -83,9 +86,12 @@ def run(cfg: omegaconf.OmegaConf): log_loss_radius = th.Vector(data=radius_tensor, name="log_loss_radius") # Set up objective + print("Setting up objective") objective = th.Objective(dtype=torch.float64) - for obs in ba.observations: + print("obs") + for i, obs in enumerate(ba.observations): + # print(i, len(ba.observations)) cam = ba.cameras[obs.camera_index] cost_function = theg.Reprojection( camera_pose=cam.pose, @@ -100,6 +106,7 @@ def run(cfg: omegaconf.OmegaConf): dtype = objective.dtype # Add regularization + print("reg") if cfg["inner_optim"]["regularize"]: zero_point3 = th.Point3(dtype=dtype, name="zero_point") # identity_se3 = th.SE3(dtype=dtype, name="zero_se3") @@ -130,6 +137,7 @@ def run(cfg: omegaconf.OmegaConf): for i in range(len(ba.cameras)): if np.random.rand() > cfg["inner_optim"]["ratio_known_cameras"]: continue + print("fixing cam", i) objective.add( th.eb.VariableDifference( camera_pose_vars[i], @@ -139,7 +147,7 @@ def run(cfg: omegaconf.OmegaConf): ) ) - print("Factors:\n", objective.cost_functions.keys(), "\n") + # print("Factors:\n", objective.cost_functions.keys(), "\n") # Create optimizer and theseus layer # optimizer = th.GaussNewton( @@ -158,7 +166,7 @@ def run(cfg: omegaconf.OmegaConf): "track_err_history": True, "verbose": True, "backward_mode": th.BackwardMode.FULL, - "relin_threshold": 0.0001, + "relin_threshold": 0.001, "damping": 0.0, "dropout": 0.0, "schedule": synchronous_schedule( @@ -176,7 +184,7 @@ def run(cfg: omegaconf.OmegaConf): with torch.no_grad(): camera_loss_ref = camera_loss(ba, 
camera_pose_vars).item() print(f"CAMERA LOSS: {camera_loss_ref: .3f}") - # print_histogram(ba, theseus_inputs, "Input histogram:") + print_histogram(ba, theseus_inputs, "Input histogram:") objective.update(theseus_inputs) print("squred err:", objective.error_squared_norm().item()) @@ -189,15 +197,22 @@ def run(cfg: omegaconf.OmegaConf): loss = camera_loss(ba, camera_pose_vars).item() print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") - BAViewer(optimizer.belief_history, msg_history=optimizer.ftov_msgs_history) + are = average_repojection_error(objective) + print("Average reprojection error (pixels): ", are) + + print_histogram(ba, theseus_inputs, "Final histogram:") + + BAViewer( + optimizer.belief_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points + ) # , msg_history=optimizer.ftov_msgs_history) if __name__ == "__main__": cfg = { "seed": 1, - "num_cameras": 4, - "num_points": 10, + "num_cameras": 10, + "num_points": 50, "average_track_length": 8, "track_locality": 0.2, "inner_optim": { diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py index 0429ed72c..b383e65c5 100644 --- a/theseus/optimizer/gbp/ba_viewer.py +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -23,7 +23,13 @@ def draw_camera( class BAViewer(trimesh.viewer.SceneViewer): def __init__( - self, belief_history, msg_history=None, cam_to_world=False, flip_z=True + self, + belief_history, + msg_history=None, + cam_to_world=False, + flip_z=True, + gt_cameras=None, + gt_points=None, ): self._it = 0 self.belief_history = belief_history @@ -34,6 +40,17 @@ def __init__( scene = trimesh.Scene() self.scene = scene + + if gt_cameras is not None: + for i, cam in enumerate(gt_cameras): + camera = self.make_cam(cam.pose.data[0]) + self.scene.add_geometry(camera[1], geom_name=f"gt_cam_{i}") + + if gt_points is not None: + pts = torch.cat([pt.data for pt in gt_points]) + pc = trimesh.PointCloud(pts, [0, 255, 0, 200]) + self.scene.add_geometry(pc, geom_name="gt_points") + self.next_iteration() scene.set_camera() super(BAViewer, self).__init__(scene=scene, resolution=(1080, 720)) @@ -86,26 +103,33 @@ def on_key_press(self, symbol, modifiers): self.view["ball"].drag([0, magnitude]) self.scene.camera_transform[...] 
= self.view["ball"].pose + def make_cam(self, pose, color=(0.0, 1.0, 0.0, 0.8)): + T = torch.vstack( + ( + pose, + torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=pose.dtype), + ) + ) + if not self.cam_to_world: + T = np.linalg.inv(T) + if self.flip_z: + T[:3, 2] *= -1.0 + camera = draw_camera( + T, + self.scene.camera.fov, + self.scene.camera.resolution, + color=color, + ) + return camera + def next_iteration(self): with self.lock: points = [] n_cams, n_pts = 0, 0 for belief in self.belief_history[self._it]: if isinstance(belief.mean[0], th.SE3): - T = torch.vstack( - ( - belief.mean[0].data[0], - torch.tensor( - [[0.0, 0.0, 0.0, 1.0]], dtype=belief.mean[0].dtype - ), - ) - ) - if not self.cam_to_world: - T = np.linalg.inv(T) - if self.flip_z: - T[:3, 2] *= -1.0 - camera = draw_camera( - T, self.scene.camera.fov, self.scene.camera.resolution + camera = self.make_cam( + belief.mean[0].data[0], color=(0.0, 0.0, 1.0, 0.8) ) self.scene.delete_geometry(f"cam_{n_cams}") self.scene.add_geometry(camera[1], geom_name=f"cam_{n_cams}") diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index ccf8ff885..f62f3dc4b 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -81,7 +81,10 @@ def __init__( def zero_message(self): new_mean = [] for var in self.mean: - new_mean_i = var.__class__() + if var.__class__ == th.Vector: + new_mean_i = var.__class__(var.dof()) + else: + new_mean_i = var.__class__() new_mean_i.to(dtype=self.dtype, device=self.device) new_mean.append(new_mean_i) new_precision = torch.zeros(self.mean[0].shape[0], self.dof, self.dof).to( @@ -303,7 +306,7 @@ def comp_mess( # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. # mean damping - if damping[v] != 0: # and steps_since_lin > 0: + if damping[v] != 0 and self.steps_since_lin > 0: if ( new_mess_lam.count_nonzero() != 0 and ftov_msgs[v].precision.count_nonzero() != 0 @@ -590,7 +593,7 @@ def _pass_fac_to_var_messages( start += num_optim_vars # print(f"Factor relinearisations: {relins} / {len(self.factors)}") - # print(did_relin) + return relins """ Optimization loop functions @@ -663,7 +666,6 @@ def _optimize_loop( self.belief_history[it_] = [belief.copy() for belief in self.beliefs] # damping - # damping = self.gbp_settings.get_damping(iters_since_relin) damping_arr = torch.full([self.n_edges], damping) # dropout can be implemented through damping @@ -671,7 +673,7 @@ def _optimize_loop( dropout_ixs = torch.rand(self.n_edges) < dropout damping_arr[dropout_ixs] = 1.0 - self._pass_fac_to_var_messages( + relins = self._pass_fac_to_var_messages( vtof_msgs, ftov_msgs, schedule[it_], @@ -692,7 +694,8 @@ def _optimize_loop( self._update_info(info, it_, err, converged_indices) if verbose: print( - f"GBP. Iteration: {it_+1}. " f"Error: {err.mean().item()}" + f"GBP. Iteration: {it_+1}. Error: {err.mean().item():.4f}. 
" + f"Relins: {relins} / {len(self.factors)}" ) converged_indices = self._check_convergence(err, info.last_err) info.status[ diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index 78591ab03..7bbd43d72 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -21,7 +21,7 @@ size = 3 dim = 2 -noise_cov = np.array([0.01, 0.01]) +noise_cov = np.array([0.05, 0.05]) prior_noise_std = 0.2 prior_sigma = np.array([1.3**2, 1.3**2]) @@ -47,10 +47,13 @@ prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") +gt_poses = [] + p = 0 for i in range(size): for j in range(size): init = torch.Tensor([j, i]) + gt_poses.append(init[None, :]) if i == 0 and j == 0: w = anchor_w else: @@ -72,8 +75,8 @@ # Measurement cost functions -meas_std = 0.1 -meas_w = th.ScaleCostWeight(1 / meas_std, name="prior_weight") +meas_std_tensor = torch.nn.Parameter(torch.tensor([0.1])) +meas_w = th.ScaleCostWeight(1 / meas_std_tensor, name="prior_weight") m = 0 for i in range(size): @@ -110,12 +113,26 @@ objective.add(cf_meas) m += 1 -# objective.update(init_dict) -# print("Initial cost:", objective.error_squared_norm()) -# joint = fg.get_joint() -# marg_covs = np.diag(joint.cov())[::2] -# map_soln = fg.MAP().reshape([size * size, 2]) +# outer optimizer +lr = 1e-3 +model_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) + + +linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) +th_layer = th.TheseusLayer(linear_optimizer) +outputs, _ = th_layer.forward(inputs) + + +gt_poses_tensor = torch.cat(gt_poses) +output_poses = torch.cat([x.data for x in poses]) + +loss = torch.norm(gt_poses_tensor - output_poses) +loss.backward() + +# da_dx = torch.autograd.grad(loss, data_x, retain_graph=True)[0].squeeze() +# print("\n--- backward_mode=IMPLICIT") +# print(da_dx.numpy()) max_iterations = 100 optimizer = GaussianBeliefPropagation( @@ -134,7 +151,8 @@ "dropout": 0.0, "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), } + updated_inputs, info = theseus_optim.forward(inputs, optim_arg) -# print("updated_inputs", updated_inputs) -# print("info", info) +print("gbp outputs\n", updated_inputs) +print("linear solver\n", updated_inputs) From 4c9194448f1fae593a869e5c340d13f606c98ab4 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 10 Jun 2022 14:41:08 +0100 Subject: [PATCH 18/64] lin system damping for ftov msgs --- theseus/optimizer/gbp/gbp.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index f62f3dc4b..a5694abde 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -157,6 +157,7 @@ def __init__( self, cf: CostFunction, name: Optional[str] = None, + lin_system_damping: float = 1e-6, ): self._id = next(Factor._ids) if name: @@ -165,6 +166,7 @@ def __init__( self.name = f"{self.__class__.__name__}__{self._id}" self.cf = cf + self.lin_system_damping = lin_system_damping batch_size = cf.optim_var_at(0).shape[0] self._dof = sum([var.dof() for var in cf.optim_vars]) @@ -193,11 +195,13 @@ def linearize( if relin_threshold is None: do_lin = True else: - lp_dists = [ - lp.local(self.cf.optim_var_at(j)).norm() - for j, lp in enumerate(self.lin_point) - ] - do_lin = np.max(lp_dists) > relin_threshold + lp_dists = torch.tensor( + [ + lp.local(self.cf.optim_var_at(j)).norm() + for j, lp in enumerate(self.lin_point) + ] + ) + do_lin 
= bool((torch.max(lp_dists) > relin_threshold).item()) if do_lin: J, error = self.cf.weighted_jacobians_error() @@ -299,7 +303,7 @@ def comp_mess( dim=0, ) - new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo + new_mess_lam = loo - lono @ torch.linalg.inv(lnono) @ lnoo new_mess_eta = eo - lono @ torch.linalg.inv(lnono) @ eno # damping in tangent space at linearisation point as message @@ -323,11 +327,16 @@ def comp_mess( ] * prev_mess_mean[0] new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) - new_mess_mean = torch.matmul( - torch.linalg.pinv(new_mess_lam), new_mess_eta + new_mess_lam = th.DenseSolver._apply_damping( + new_mess_lam[None, ...], + self.lin_system_damping, + ellipsoidal=True, + eps=1e-8, ) - new_mess_mean = new_mess_mean[None, ...] - new_mess_lam = new_mess_lam[None, ...] + new_mess_mean = th.LUDenseSolver._solve_sytem( + new_mess_eta[..., None], new_mess_lam + ) + new_mess = th.retract_gaussian( self.lin_point[v], new_mess_mean, new_mess_lam ) @@ -544,7 +553,7 @@ def _pass_var_to_fac_messages( if lam_a.count_nonzero() == 0: vtof_msgs[j].zero_message() else: - inv_lam_a = torch.linalg.pinv(lam_a) + inv_lam_a = torch.linalg.inv(lam_a) sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( dim=0 ) @@ -611,6 +620,7 @@ def _optimize_loop( damping: float, dropout: float, schedule: torch.Tensor, + lin_system_damping: float, **kwargs, ): if damping > 1.0 or damping < 0.0: @@ -654,7 +664,7 @@ def _optimize_loop( # compute factor potentials for the first time self.factors: List[Factor] = [] for cf in self.cf_ordering: - self.factors.append(Factor(cf)) + self.factors.append(Factor(cf, lin_system_damping=lin_system_damping)) self.belief_history = {} self.ftov_msgs_history = {} @@ -723,6 +733,7 @@ def _optimize_impl( damping: float = 0.0, dropout: float = 0.0, schedule: torch.Tensor = None, + lin_system_damping: float = 1e-6, **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): @@ -748,6 +759,7 @@ def _optimize_impl( damping=damping, dropout=dropout, schedule=schedule, + lin_system_damping=lin_system_damping, **kwargs, ) From d1244e815eea1881c0da03828f02907ebed327e4 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 10 Jun 2022 14:42:03 +0100 Subject: [PATCH 19/64] static dense solver methods --- theseus/optimizer/linear/dense_solver.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/theseus/optimizer/linear/dense_solver.py b/theseus/optimizer/linear/dense_solver.py index bdf9294b9..d2008d9ea 100644 --- a/theseus/optimizer/linear/dense_solver.py +++ b/theseus/optimizer/linear/dense_solver.py @@ -88,6 +88,9 @@ def solve( if self._check_singular: AtA = self.linearization.AtA Atb = self.linearization.Atb + import ipdb + + ipdb.set_trace() with torch.no_grad(): output = torch.zeros(AtA.shape[0], AtA.shape[1]).to(AtA.device) _, _, infos = torch.lu(AtA, get_infos=True) @@ -134,7 +137,8 @@ def __init__( check_singular=check_singular, ) - def _solve_sytem(self, Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: + @staticmethod + def _solve_sytem(Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: return torch.linalg.solve(AtA, Atb).squeeze(2) @@ -153,6 +157,7 @@ def __init__( check_singular=check_singular, ) - def _solve_sytem(self, Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: + @staticmethod + def _solve_sytem(Atb: torch.Tensor, AtA: torch.Tensor) -> torch.Tensor: lower = torch.linalg.cholesky(AtA) return torch.cholesky_solve(Atb, lower).squeeze(2) From 87cac1e02510d9e0c01789dd088e972e78e5c237 Mon Sep 17 00:00:00 2001 
From: joeaortiz Date: Fri, 10 Jun 2022 14:43:38 +0100 Subject: [PATCH 20/64] ba with damping in linear system --- theseus/optimizer/gbp/ba_test.py | 42 +++++++++++-------- .../utils/examples/bundle_adjustment/data.py | 15 ++++++- 2 files changed, 38 insertions(+), 19 deletions(-) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 1a2fc4fe8..9985f35bc 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -74,6 +74,9 @@ def run(cfg: omegaconf.OmegaConf): feat_random=0.0, prob_feat_is_outlier=0.0, outlier_feat_random=70, + cam_pos_rand=0.5, + cam_rot_rand=0.1, + point_rand=5.0, ) # cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( @@ -150,12 +153,7 @@ def run(cfg: omegaconf.OmegaConf): # print("Factors:\n", objective.cost_functions.keys(), "\n") # Create optimizer and theseus layer - # optimizer = th.GaussNewton( - # objective, - # max_iterations=cfg["inner_optim"]["max_iters"], - # step_size=0.1, - # ) - optimizer = GaussianBeliefPropagation( + optimizer = cfg["optimizer_cls"]( objective, max_iterations=cfg["inner_optim"]["max_iters"], ) @@ -166,13 +164,18 @@ def run(cfg: omegaconf.OmegaConf): "track_err_history": True, "verbose": True, "backward_mode": th.BackwardMode.FULL, - "relin_threshold": 0.001, - "damping": 0.0, - "dropout": 0.0, - "schedule": synchronous_schedule( - cfg["inner_optim"]["max_iters"], optimizer.n_edges - ), } + if cfg["optimizer_cls"] == GaussianBeliefPropagation: + gbp_optim_arg = { + "relin_threshold": 0.0000000001, + "damping": 0.0, + "dropout": 0.0, + "schedule": synchronous_schedule( + cfg["inner_optim"]["max_iters"], optimizer.n_edges + ), + "lin_system_damping": 1e-5, + } + optim_arg = {**optim_arg, **gbp_optim_arg} theseus_inputs = {} for cam in ba.cameras: @@ -200,6 +203,9 @@ def run(cfg: omegaconf.OmegaConf): are = average_repojection_error(objective) print("Average reprojection error (pixels): ", are) + with torch.no_grad(): + camera_loss_ref = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS: {camera_loss_ref: .3f}") print_histogram(ba, theseus_inputs, "Final histogram:") BAViewer( @@ -211,18 +217,20 @@ def run(cfg: omegaconf.OmegaConf): cfg = { "seed": 1, - "num_cameras": 10, - "num_points": 50, + "num_cameras": 5, + "num_points": 10, "average_track_length": 8, "track_locality": 0.2, + "optimizer_cls": GaussianBeliefPropagation, + # "optimizer_cls": th.GaussNewton, "inner_optim": { - "max_iters": 50, + "max_iters": 10, "verbose": True, "track_err_history": True, "keep_step_size": True, "regularize": True, - "ratio_known_cameras": 1.0, - "reg_w": 1e-4, + "ratio_known_cameras": 0.3, + "reg_w": 1e-7, }, } diff --git a/theseus/utils/examples/bundle_adjustment/data.py b/theseus/utils/examples/bundle_adjustment/data.py index a05294ed1..8f0cd0fc0 100644 --- a/theseus/utils/examples/bundle_adjustment/data.py +++ b/theseus/utils/examples/bundle_adjustment/data.py @@ -39,6 +39,11 @@ def to_params(self) -> List[float]: float(self.calib_k2[0, 0]), ] + def position(self) -> torch.Tensor: + R = self.pose.data[:, :, :3].squeeze(0) + t = self.pose.data[:, :, 3].squeeze(0) + return -R.T @ t + @staticmethod def from_params(params: List[float], name: str = "Cam") -> "Camera": r = th.SO3.exp_map(torch.tensor(params[:3], dtype=torch.float64).unsqueeze(0)) @@ -278,6 +283,9 @@ def generate_synthetic( feat_random: float = 1.5, prob_feat_is_outlier: float = 0.02, outlier_feat_random: float = 70, + cam_pos_rand: float = 0.2, + cam_rot_rand: float = 0.1, + point_rand: float = 0.2, ): 
# add cameras @@ -290,7 +298,10 @@ def generate_synthetic( ) for i in range(num_cameras) ] - cameras = [cam.perturbed() for cam in gt_cameras] + cameras = [ + cam.perturbed(rot_random=cam_rot_rand, pos_random=cam_pos_rand) + for cam in gt_cameras + ] # add points gt_points = [ @@ -303,7 +314,7 @@ def generate_synthetic( ] points = [ th.Point3( - data=gt_points[i].data + (torch.rand((1, 3)) * 2 - 1) * 0.2, + data=gt_points[i].data + (torch.rand((1, 3)) * 2 - 1) * point_rand, name=gt_points[i].name + "_copy", ) for i in range(num_points) From 34133be0d6e9665455ebb6e765373335abed4749 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 10 Jun 2022 16:12:43 +0100 Subject: [PATCH 21/64] backward modes --- theseus/optimizer/gbp/gbp.py | 66 +++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index a5694abde..e8465d5a8 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -5,6 +5,7 @@ import abc import math +import warnings from dataclasses import dataclass from itertools import count from typing import Dict, List, Optional, Sequence @@ -744,11 +745,7 @@ def _optimize_impl( f"GBP optimizer. Iteration: 0. " f"Error: {info.last_err.mean().item()}" ) - grad = False if backward_mode == BackwardMode.FULL: - grad = True - - with torch.set_grad_enabled(grad): info = self._optimize_loop( start_iter=0, num_iter=self.params.max_iterations, @@ -768,3 +765,64 @@ def _optimize_impl( info.status == NonlinearOptimizerStatus.MAX_ITERATIONS ] = -1 return info + + elif backward_mode in [BackwardMode.IMPLICIT, BackwardMode.TRUNCATED]: + if backward_mode == BackwardMode.IMPLICIT: + backward_num_iterations = 1 + else: + if "backward_num_iterations" not in kwargs: + raise ValueError( + "backward_num_iterations expected but not received" + ) + if kwargs["backward_num_iterations"] > self.params.max_iterations: + warnings.warn( + f"Input backward_num_iterations " + f"(={kwargs['backward_num_iterations']}) > " + f"max_iterations (={self.params.max_iterations}). " + f"Using backward_num_iterations=max_iterations." 
+ ) + backward_num_iterations = min( + kwargs["backward_num_iterations"], self.params.max_iterations + ) + + num_no_grad_iter = self.params.max_iterations - backward_num_iterations + with torch.no_grad(): + self._optimize_loop( + start_iter=0, + num_iter=num_no_grad_iter, + info=info, + verbose=verbose, + truncated_grad_loop=False, + relin_threshold=relin_threshold, + damping=damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + **kwargs, + ) + + grad_loop_info = self._init_info( + track_best_solution, track_err_history, verbose + ) + self._optimize_loop( + start_iter=0, + num_iter=backward_num_iterations, + info=grad_loop_info, + verbose=verbose, + truncated_grad_loop=True, + relin_threshold=relin_threshold, + damping=damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + **kwargs, + ) + + # Adds grad_loop_info results to original info + self._merge_infos( + grad_loop_info, num_no_grad_iter, backward_num_iterations, info + ) + + return info + else: + raise ValueError("Unrecognized backward mode") From 4263ef5cc3a69a3e7025291bc5961342a1eae922 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 10 Jun 2022 16:49:51 +0100 Subject: [PATCH 22/64] msgs are class variables to fix implicit backward mode --- theseus/optimizer/gbp/gbp.py | 72 +++++++++++++++++------------------- 1 file changed, 33 insertions(+), 39 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index e8465d5a8..6f916408d 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -388,6 +388,30 @@ def __init__( var_ixs = [item for sublist in var_ixs_nested for item in sublist] self.var_ix_for_edges = torch.tensor(var_ixs).long() + # initialise messages with zeros + self.vtof_msgs: List[Message] = [] + self.ftov_msgs: List[Message] = [] + for cf in self.cf_ordering: + for var in cf.optim_vars: + # Set mean of initial message to identity of the group + # doesn't matter what it is as long as precision is zero + vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") + ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg.zero_message() + ftov_msg.zero_message() + self.vtof_msgs.append(vtof_msg) + self.ftov_msgs.append(ftov_msg) + + # initialise ManifoldGaussian for belief + self.beliefs: List[th.ManifoldGaussian] = [] + for var in self.ordering: + self.beliefs.append(th.ManifoldGaussian([var])) + + # compute factor potentials for the first time + self.factors: List[Factor] = [] + for cf in self.cf_ordering: + self.factors.append(Factor(cf)) + """ Copied and slightly modified from nonlinear optimizer class """ @@ -522,8 +546,6 @@ def _merge_infos( def _pass_var_to_fac_messages( self, - ftov_msgs, - vtof_msgs, update_belief=True, ): for i, var in enumerate(self.ordering): @@ -531,7 +553,7 @@ def _pass_var_to_fac_messages( # Collect all incoming messages in the tangent space at the current belief taus = [] # message means lams_tp = [] # message lams - for j, msg in enumerate(ftov_msgs): + for j, msg in enumerate(self.ftov_msgs): if self.var_ix_for_edges[j] == i: # print(msg.mean, msg.precision) tau, lam_tp = th.local_gaussian(var, msg, return_mean=True) @@ -545,14 +567,14 @@ def _pass_var_to_fac_messages( # Compute outgoing messages ix = 0 - for j, msg in enumerate(ftov_msgs): + for j, msg in enumerate(self.ftov_msgs): if self.var_ix_for_edges[j] == i: taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) lam_a = 
lams_inc.sum(dim=0) if lam_a.count_nonzero() == 0: - vtof_msgs[j].zero_message() + self.vtof_msgs[j].zero_message() else: inv_lam_a = torch.linalg.inv(lam_a) sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( @@ -560,7 +582,7 @@ def _pass_var_to_fac_messages( ) tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) new_mess = th.retract_gaussian(var, tau_a, lam_a) - vtof_msgs[j].update(new_mess.mean, new_mess.precision) + self.vtof_msgs[j].update(new_mess.mean, new_mess.precision) ix += 1 # update belief mean and variance @@ -575,8 +597,6 @@ def _pass_var_to_fac_messages( def _pass_fac_to_var_messages( self, - vtof_msgs, - ftov_msgs, schedule: torch.Tensor, damping: torch.Tensor, relin_threshold: float, @@ -595,8 +615,8 @@ def _pass_fac_to_var_messages( did_relin += [0] factor.comp_mess( - vtof_msgs[start : start + num_optim_vars], - ftov_msgs[start : start + num_optim_vars], + self.vtof_msgs[start : start + num_optim_vars], + self.ftov_msgs[start : start + num_optim_vars], damping[start : start + num_optim_vars], ) @@ -643,37 +663,15 @@ def _optimize_loop( f"but got {schedule.shape}." ) - # initialise messages with zeros - vtof_msgs: List[Message] = [] - ftov_msgs: List[Message] = [] - for cf in self.cf_ordering: - for var in cf.optim_vars: - # Set mean of initial message to identity of the group - # doesn't matter what it is as long as precision is zero - vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") - ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") - vtof_msg.zero_message() - ftov_msg.zero_message() - vtof_msgs.append(vtof_msg) - ftov_msgs.append(ftov_msg) - - # initialise ManifoldGaussian for belief - self.beliefs: List[th.ManifoldGaussian] = [] - for var in self.ordering: - self.beliefs.append(th.ManifoldGaussian([var])) - - # compute factor potentials for the first time - self.factors: List[Factor] = [] - for cf in self.cf_ordering: - self.factors.append(Factor(cf, lin_system_damping=lin_system_damping)) + for factor in self.factors: + factor.lin_system_damping = lin_system_damping self.belief_history = {} self.ftov_msgs_history = {} converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): - - self.ftov_msgs_history[it_] = [msg.copy() for msg in ftov_msgs] + self.ftov_msgs_history[it_] = [msg.copy() for msg in self.ftov_msgs] self.belief_history[it_] = [belief.copy() for belief in self.beliefs] # damping @@ -685,16 +683,12 @@ def _optimize_loop( damping_arr[dropout_ixs] = 1.0 relins = self._pass_fac_to_var_messages( - vtof_msgs, - ftov_msgs, schedule[it_], damping_arr, relin_threshold, ) self._pass_var_to_fac_messages( - ftov_msgs, - vtof_msgs, update_belief=True, ) From 7b1bf78df8e8ae0e9c37d5c794c34ae729dc6c18 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 10 Jun 2022 16:50:34 +0100 Subject: [PATCH 23/64] test different backward modes for pgo --- theseus/optimizer/gbp/pgo_test.py | 226 +++++++++++++++++------------- 1 file changed, 130 insertions(+), 96 deletions(-) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index 7bbd43d72..a9ecd7b4b 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -21,7 +21,7 @@ size = 3 dim = 2 -noise_cov = np.array([0.05, 0.05]) +noise_cov = np.array([0.01, 0.01]) prior_noise_std = 0.2 prior_sigma = np.array([1.3**2, 1.3**2]) @@ -31,128 +31,162 @@ # create theseus objective ------------------------------------- -objective = th.Objective() -inputs = {} -n_poses = size 
* size +def create_pgo(): -# create variables -poses = [] -for i in range(n_poses): - poses.append(th.Vector(data=torch.rand(1, 2), name=f"x{i}")) + objective = th.Objective() + inputs = {} -# add prior cost constraints with VariableDifference cost -prior_std = 1.3 -anchor_std = 0.01 -prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") -anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") + n_poses = size * size -gt_poses = [] + # create variables + poses = [] + for i in range(n_poses): + poses.append(th.Vector(data=torch.rand(1, 2), name=f"x{i}")) -p = 0 -for i in range(size): - for j in range(size): - init = torch.Tensor([j, i]) - gt_poses.append(init[None, :]) - if i == 0 and j == 0: - w = anchor_w - else: - # noise_init = torch.normal(torch.zeros(2), prior_noise_std) - init = init + torch.FloatTensor(init_noises[p]) - w = prior_w + # add prior cost constraints with VariableDifference cost + prior_std = 1.3 + anchor_std = 0.01 + prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") + anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") - prior_target = th.Vector(data=init, name=f"prior_{p}") - inputs[f"x{p}"] = init[None, :] - inputs[f"prior_{p}"] = init[None, :] + gt_poses = [] - cf_prior = th.eb.VariableDifference( - poses[p], w, prior_target, name=f"prior_cost_{p}" - ) + p = 0 + for i in range(size): + for j in range(size): + init = torch.Tensor([j, i]) + gt_poses.append(init[None, :]) + if i == 0 and j == 0: + w = anchor_w + else: + # noise_init = torch.normal(torch.zeros(2), prior_noise_std) + init = init + torch.FloatTensor(init_noises[p]) + w = prior_w - objective.add(cf_prior) + prior_target = th.Vector(data=init, name=f"prior_{p}") + inputs[f"x{p}"] = init[None, :] + inputs[f"prior_{p}"] = init[None, :] - p += 1 + cf_prior = th.eb.VariableDifference( + poses[p], w, prior_target, name=f"prior_cost_{p}" + ) -# Measurement cost functions + objective.add(cf_prior) -meas_std_tensor = torch.nn.Parameter(torch.tensor([0.1])) -meas_w = th.ScaleCostWeight(1 / meas_std_tensor, name="prior_weight") + p += 1 -m = 0 -for i in range(size): - for j in range(size): - if j < size - 1: - measurement = torch.Tensor([1.0, 0.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = i * size + j + 1 + # Measurement cost functions - meas = th.Vector(data=measurement, name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] + meas_std_tensor = torch.nn.Parameter(torch.tensor([0.1])) + meas_w = th.ScaleCostWeight(1 / meas_std_tensor, name="prior_weight") - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 + m = 0 + for i in range(size): + for j in range(size): + if j < size - 1: + measurement = torch.Tensor([1.0, 0.0]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = i * size + j + 1 - if i < size - 1: - measurement = torch.Tensor([0.0, 1.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = (i + 1) * size + j + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] - meas = th.Vector(data=measurement, name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + 
objective.add(cf_meas) + m += 1 + + if i < size - 1: + measurement = torch.Tensor([0.0, 1.0]) + # measurement += torch.normal(torch.zeros(2), meas_std) + measurement += torch.FloatTensor(meas_noises[m]) + ix0 = i * size + j + ix1 = (i + 1) * size + j + + meas = th.Vector(data=measurement, name=f"meas_{m}") + inputs[f"meas_{m}"] = measurement[None, :] + + cf_meas = th.eb.Between( + poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" + ) + objective.add(cf_meas) + m += 1 + + return objective, gt_poses, meas_std_tensor, inputs + + +def linear_solve_pgo(): + print("\n\nLinear solver...\n") + + objective, gt_poses, meas_std_tensor, inputs = create_pgo() + + # outer optimizer + gt_poses_tensor = torch.cat(gt_poses) + lr = 1e-3 + outer_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) + outer_optimizer.zero_grad() + + linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) + th_layer = th.TheseusLayer(linear_optimizer) + outputs_linsolve, _ = th_layer.forward(inputs, {"verbose": True}) + + out_ls_tensor = torch.cat(list(outputs_linsolve.values())) + loss = torch.norm(gt_poses_tensor - out_ls_tensor) + loss.backward() + + print("loss", loss.item()) + print("grad", meas_std_tensor.grad.item()) + + print("outputs\n", outputs_linsolve) - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas_w, meas, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 +def gbp_solve_pgo(backward_mode, max_iterations=20): + print("\n\nWith GBP...") + print("backward mode:", backward_mode, "\n") -# outer optimizer -lr = 1e-3 -model_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) + objective, gt_poses, meas_std_tensor, inputs = create_pgo() + gt_poses_tensor = torch.cat(gt_poses) + lr = 1e-3 + outer_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) + outer_optimizer.zero_grad() -linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) -th_layer = th.TheseusLayer(linear_optimizer) -outputs, _ = th_layer.forward(inputs) + optimizer = GaussianBeliefPropagation( + objective, + max_iterations=max_iterations, + ) + theseus_optim = th.TheseusLayer(optimizer) + optim_arg = { + "verbose": True, + # "track_best_solution": True, + # "track_err_history": True, + "backward_mode": backward_mode, + "backward_num_iterations": 5, + "relin_threshold": 1e-8, + "damping": 0.0, + "dropout": 0.0, + "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), + } -gt_poses_tensor = torch.cat(gt_poses) -output_poses = torch.cat([x.data for x in poses]) + outputs_gbp, info = theseus_optim.forward(inputs, optim_arg) -loss = torch.norm(gt_poses_tensor - output_poses) -loss.backward() + out_gbp_tensor = torch.cat(list(outputs_gbp.values())) + loss = torch.norm(gt_poses_tensor - out_gbp_tensor) + loss.backward() -# da_dx = torch.autograd.grad(loss, data_x, retain_graph=True)[0].squeeze() -# print("\n--- backward_mode=IMPLICIT") -# print(da_dx.numpy()) + print("loss", loss.item()) + print("grad", meas_std_tensor.grad.item()) -max_iterations = 100 -optimizer = GaussianBeliefPropagation( - objective, - max_iterations=max_iterations, -) -theseus_optim = th.TheseusLayer(optimizer) + print("outputs\n", outputs_gbp) -optim_arg = { - "track_best_solution": True, - "track_err_history": True, - "verbose": True, - "backward_mode": th.BackwardMode.FULL, - "damping": 0.0, - "dropout": 0.0, - "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), -} +linear_solve_pgo() -updated_inputs, info = theseus_optim.forward(inputs, optim_arg) 
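# --- illustrative sketch, not part of the patch ----------------------------
# A compact way to compare the unrolled-GBP backward modes against each other.
# It assumes a hypothetical variant of gbp_solve_pgo (named
# gbp_solve_pgo_with_grad here purely for illustration) that returns
# meas_std_tensor.grad instead of only printing it.
def compare_backward_modes(max_iterations=20):
    grads = {}
    for mode in (th.BackwardMode.FULL, th.BackwardMode.TRUNCATED):
        grads[mode] = gbp_solve_pgo_with_grad(
            backward_mode=mode, max_iterations=max_iterations
        )
    gap = abs(grads[th.BackwardMode.FULL] - grads[th.BackwardMode.TRUNCATED])
    print("FULL vs TRUNCATED grad gap:", gap)
# ----------------------------------------------------------------------------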
+gbp_solve_pgo(backward_mode=th.BackwardMode.FULL, max_iterations=20) -print("gbp outputs\n", updated_inputs) -print("linear solver\n", updated_inputs) +gbp_solve_pgo(backward_mode=th.BackwardMode.TRUNCATED, max_iterations=20) From 6c07c347c3c0a648626a4f1a9e9d65a2cee1ebce Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Sat, 11 Jun 2022 11:38:38 +0100 Subject: [PATCH 24/64] fixed copy_impl for reprojection error fn --- .../bundle_adjustment/reprojection_error.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/theseus/utils/examples/bundle_adjustment/reprojection_error.py b/theseus/utils/examples/bundle_adjustment/reprojection_error.py index cf0983a92..03e2b60b9 100644 --- a/theseus/utils/examples/bundle_adjustment/reprojection_error.py +++ b/theseus/utils/examples/bundle_adjustment/reprojection_error.py @@ -112,13 +112,13 @@ def dim(self) -> int: def to(self, *args, **kwargs): super().to(*args, **kwargs) - def _copy_impl(self): + def _copy_impl(self, new_name: Optional[str] = None): return Reprojection( - self.camera_pose, - self.world_point, - self.log_loss_radius, - self.focal_length, - self.image_feature_point, - weight=self.weight, - name=self.name, + self.camera_pose.copy(), + self.world_point.copy(), + self.log_loss_radius.copy(), + self.focal_length.copy(), + self.image_feature_point.copy(), + weight=self.weight.copy(), + name=new_name, ) From d2bd2aef296d835ce61ab58e3b7ed359976344b3 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Sun, 12 Jun 2022 22:36:12 +0100 Subject: [PATCH 25/64] fix order of args in copy fn --- theseus/optimizer/linear/dense_solver.py | 3 --- theseus/utils/examples/bundle_adjustment/reprojection_error.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/theseus/optimizer/linear/dense_solver.py b/theseus/optimizer/linear/dense_solver.py index d2008d9ea..ab9d74cb7 100644 --- a/theseus/optimizer/linear/dense_solver.py +++ b/theseus/optimizer/linear/dense_solver.py @@ -88,9 +88,6 @@ def solve( if self._check_singular: AtA = self.linearization.AtA Atb = self.linearization.Atb - import ipdb - - ipdb.set_trace() with torch.no_grad(): output = torch.zeros(AtA.shape[0], AtA.shape[1]).to(AtA.device) _, _, infos = torch.lu(AtA, get_infos=True) diff --git a/theseus/utils/examples/bundle_adjustment/reprojection_error.py b/theseus/utils/examples/bundle_adjustment/reprojection_error.py index 03e2b60b9..ed811ba74 100644 --- a/theseus/utils/examples/bundle_adjustment/reprojection_error.py +++ b/theseus/utils/examples/bundle_adjustment/reprojection_error.py @@ -117,8 +117,8 @@ def _copy_impl(self, new_name: Optional[str] = None): self.camera_pose.copy(), self.world_point.copy(), self.log_loss_radius.copy(), - self.focal_length.copy(), self.image_feature_point.copy(), + self.focal_length.copy(), weight=self.weight.copy(), name=new_name, ) From 9f6eb09e50b090be4506b26640755c930c45f084 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Sun, 12 Jun 2022 22:37:43 +0100 Subject: [PATCH 26/64] used vectorization for part of relin, rename VariableDifference --- theseus/optimizer/gbp/ba_test.py | 10 +- theseus/optimizer/gbp/gbp.py | 212 ++++++++++++++++-------------- theseus/optimizer/gbp/pgo_test.py | 2 +- 3 files changed, 119 insertions(+), 105 deletions(-) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 9985f35bc..02d59f43f 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -92,7 +92,6 @@ def run(cfg: omegaconf.OmegaConf): print("Setting up objective") 
objective = th.Objective(dtype=torch.float64) - print("obs") for i, obs in enumerate(ba.observations): # print(i, len(ba.observations)) cam = ba.cameras[obs.camera_index] @@ -109,7 +108,6 @@ def run(cfg: omegaconf.OmegaConf): dtype = objective.dtype # Add regularization - print("reg") if cfg["inner_optim"]["regularize"]: zero_point3 = th.Point3(dtype=dtype, name="zero_point") # identity_se3 = th.SE3(dtype=dtype, name="zero_se3") @@ -126,7 +124,7 @@ def run(cfg: omegaconf.OmegaConf): else: assert False objective.add( - th.eb.VariableDifference( + th.Difference( var, damping_weight, target, name=f"reg_{name}" ) ) @@ -142,7 +140,7 @@ def run(cfg: omegaconf.OmegaConf): continue print("fixing cam", i) objective.add( - th.eb.VariableDifference( + th.Difference( camera_pose_vars[i], camera_weight, ba.gt_cameras[i].pose, @@ -217,8 +215,8 @@ def run(cfg: omegaconf.OmegaConf): cfg = { "seed": 1, - "num_cameras": 5, - "num_points": 10, + "num_cameras": 10, + "num_points": 100, "average_track_length": 8, "track_locality": 0.2, "optimizer_cls": GaussianBeliefPropagation, diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 6f916408d..a022a28bd 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -9,6 +9,7 @@ from dataclasses import dataclass from itertools import count from typing import Dict, List, Optional, Sequence +import time import numpy as np import torch @@ -94,56 +95,56 @@ def zero_message(self): self.update(mean=new_mean, precision=new_precision) -class CostFunctionOrdering: - def __init__(self, objective: Objective, default_order: bool = True): - self.objective = objective - self._cf_order: List[CostFunction] = [] - self._cf_name_to_index: Dict[str, int] = {} - if default_order: - self._compute_default_order(objective) - - def _compute_default_order(self, objective: Objective): - assert not self._cf_order and not self._cf_name_to_index - cur_idx = 0 - for cf_name, cf in objective.cost_functions.items(): - if cf_name in self._cf_name_to_index: - continue - self._cf_order.append(cf) - self._cf_name_to_index[cf_name] = cur_idx - cur_idx += 1 - - def index_of(self, key: str) -> int: - return self._cf_name_to_index[key] - - def __getitem__(self, index) -> CostFunction: - return self._cf_order[index] - - def __iter__(self): - return iter(self._cf_order) - - def append(self, cf: CostFunction): - if cf in self._cf_order: - raise ValueError( - f"Cost Function {cf.name} has already been added to the order." - ) - if cf.name not in self.objective.cost_functions: - raise ValueError( - f"Cost Function {cf.name} is not a cost function for the objective." 
- ) - self._cf_order.append(cf) - self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 - - def remove(self, cf: CostFunction): - self._cf_order.remove(cf) - del self._cf_name_to_index[cf.name] - - def extend(self, cfs: Sequence[CostFunction]): - for cf in cfs: - self.append(cf) - - @property - def complete(self): - return len(self._cf_order) == self.objective.size_variables() +# class CostFunctionOrdering: +# def __init__(self, objective: Objective, default_order: bool = True): +# self.objective = objective +# self._cf_order: List[CostFunction] = [] +# self._cf_name_to_index: Dict[str, int] = {} +# if default_order: +# self._compute_default_order(objective) + +# def _compute_default_order(self, objective: Objective): +# assert not self._cf_order and not self._cf_name_to_index +# cur_idx = 0 +# for cf_name, cf in objective.cost_functions.items(): +# if cf_name in self._cf_name_to_index: +# continue +# self._cf_order.append(cf) +# self._cf_name_to_index[cf_name] = cur_idx +# cur_idx += 1 + +# def index_of(self, key: str) -> int: +# return self._cf_name_to_index[key] + +# def __getitem__(self, index) -> CostFunction: +# return self._cf_order[index] + +# def __iter__(self): +# return iter(self._cf_order) + +# def append(self, cf: CostFunction): +# if cf in self._cf_order: +# raise ValueError( +# f"Cost Function {cf.name} has already been added to the order." +# ) +# if cf.name not in self.objective.cost_functions: +# raise ValueError( +# f"Cost Function {cf.name} is not a cost function for the objective." +# ) +# self._cf_order.append(cf) +# self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 + +# def remove(self, cf: CostFunction): +# self._cf_order.remove(cf) +# del self._cf_name_to_index[cf.name] + +# def extend(self, cfs: Sequence[CostFunction]): +# for cf in cfs: +# self.append(cf) + +# @property +# def complete(self): +# return len(self._cf_order) == self.objective.size_variables() """ @@ -372,46 +373,21 @@ def __init__( # ordering is required to identify which messages to send where self.ordering = VariableOrdering(objective, default_order=True) - self.cf_ordering = CostFunctionOrdering(objective) self.params = GBPOptimizerParams( abs_err_tolerance, rel_err_tolerance, max_iterations ) - self.n_edges = sum([cf.num_optim_vars() for cf in self.cf_ordering]) + self.n_edges = sum([cf.num_optim_vars() for cf in self.objective.cost_functions.values()]) # create array for indexing the messages var_ixs_nested = [ [self.ordering.index_of(var.name) for var in cf.optim_vars] - for cf in self.cf_ordering + for cf in self.objective.cost_functions.values() ] var_ixs = [item for sublist in var_ixs_nested for item in sublist] self.var_ix_for_edges = torch.tensor(var_ixs).long() - # initialise messages with zeros - self.vtof_msgs: List[Message] = [] - self.ftov_msgs: List[Message] = [] - for cf in self.cf_ordering: - for var in cf.optim_vars: - # Set mean of initial message to identity of the group - # doesn't matter what it is as long as precision is zero - vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") - ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") - vtof_msg.zero_message() - ftov_msg.zero_message() - self.vtof_msgs.append(vtof_msg) - self.ftov_msgs.append(ftov_msg) - - # initialise ManifoldGaussian for belief - self.beliefs: List[th.ManifoldGaussian] = [] - for var in self.ordering: - self.beliefs.append(th.ManifoldGaussian([var])) - - # compute factor potentials for the first time - self.factors: List[Factor] = [] - for cf in 
self.cf_ordering: - self.factors.append(Factor(cf)) - """ Copied and slightly modified from nonlinear optimizer class """ @@ -595,24 +571,39 @@ def _pass_var_to_fac_messages( new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _pass_fac_to_var_messages( - self, - schedule: torch.Tensor, - damping: torch.Tensor, - relin_threshold: float, - ): + + def _linearize_factors(self, relin_threshold: float): relins = 0 did_relin = [] - start = 0 - for factor in self.factors: - num_optim_vars = factor.cf.num_optim_vars() + start = time.time() + # compute weighted error and jacobian for all factors + self.objective.update_vectorization() + print('vectorized update time', time.time() - start) + + start = time.time() + for factor in self.factors: factor.linearize(relin_threshold=relin_threshold) if factor.steps_since_lin == 0: relins += 1 did_relin += [1] else: did_relin += [0] + print('compute factor time', time.time() - start) + + # print(f"Factor relinearisations: {relins} / {len(self.factors)}") + return relins + + def _pass_fac_to_var_messages( + self, + schedule: torch.Tensor, + damping: torch.Tensor, + ): + start = 0 + for factor in self.factors: + num_optim_vars = factor.cf.num_optim_vars() + + factor.comp_mess( self.vtof_msgs[start : start + num_optim_vars], @@ -622,8 +613,6 @@ def _pass_fac_to_var_messages( start += num_optim_vars - # print(f"Factor relinearisations: {relins} / {len(self.factors)}") - return relins """ Optimization loop functions @@ -642,6 +631,7 @@ def _optimize_loop( dropout: float, schedule: torch.Tensor, lin_system_damping: float, + clear_messages: bool = True, **kwargs, ): if damping > 1.0 or damping < 0.0: @@ -663,8 +653,31 @@ def _optimize_loop( f"but got {schedule.shape}." 
) - for factor in self.factors: - factor.lin_system_damping = lin_system_damping + if clear_messages: + # initialise messages with zeros + self.vtof_msgs: List[Message] = [] + self.ftov_msgs: List[Message] = [] + for cf in self.objective.cost_functions.values(): + for var in cf.optim_vars: + # Set mean of initial message to identity of the group + # doesn't matter what it is as long as precision is zero + vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") + ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg.zero_message() + ftov_msg.zero_message() + self.vtof_msgs.append(vtof_msg) + self.ftov_msgs.append(ftov_msg) + + # initialise ManifoldGaussian for belief + self.beliefs: List[th.ManifoldGaussian] = [] + for var in self.ordering: + self.beliefs.append(th.ManifoldGaussian([var])) + + # compute factor potentials for the first time + self.factors: List[Factor] = [] + for cost_function in self.objective._get_iterator(): + self.factors.append(Factor(cost_function, lin_system_damping)) + self.belief_history = {} self.ftov_msgs_history = {} @@ -682,15 +695,17 @@ def _optimize_loop( dropout_ixs = torch.rand(self.n_edges) < dropout damping_arr[dropout_ixs] = 1.0 - relins = self._pass_fac_to_var_messages( - schedule[it_], - damping_arr, - relin_threshold, - ) + t1 = time.time() + relins = self._linearize_factors(relin_threshold) + print("relin time", time.time() - t1) - self._pass_var_to_fac_messages( - update_belief=True, - ) + t1 = time.time() + self._pass_fac_to_var_messages(schedule[it_], damping_arr) + # print("ftov time", time.time() - t1) + + t1 = time.time() + self._pass_var_to_fac_messages(update_belief=True) + # print("vtof time", time.time() - t1) # check for convergence if it_ > 0: @@ -809,6 +824,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + clear_messages=False, **kwargs, ) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index a9ecd7b4b..a7fe1899c 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -68,7 +68,7 @@ def create_pgo(): inputs[f"x{p}"] = init[None, :] inputs[f"prior_{p}"] = init[None, :] - cf_prior = th.eb.VariableDifference( + cf_prior = th.Difference( poses[p], w, prior_target, name=f"prior_cost_{p}" ) From 78bb12001f996e83afe450a3b8045c41b86a72e8 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 4 Jul 2022 11:28:32 +0100 Subject: [PATCH 27/64] handles batched problems --- theseus/optimizer/gbp/ba_test.py | 7 +- theseus/optimizer/gbp/gbp.py | 109 +++++++++++++--------- theseus/optimizer/gbp/gbp_baseline.py | 14 +-- theseus/optimizer/gbp/vectorize_test.py | 116 ++++++++++++++++++++++++ 4 files changed, 187 insertions(+), 59 deletions(-) create mode 100644 theseus/optimizer/gbp/vectorize_test.py diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 02d59f43f..62a03d008 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -124,9 +124,7 @@ def run(cfg: omegaconf.OmegaConf): else: assert False objective.add( - th.Difference( - var, damping_weight, target, name=f"reg_{name}" - ) + th.Difference(var, damping_weight, target, name=f"reg_{name}") ) camera_pose_vars: List[th.LieGroup] = [ @@ -157,6 +155,9 @@ def run(cfg: omegaconf.OmegaConf): ) theseus_optim = th.TheseusLayer(optimizer) + # device = "cuda" if torch.cuda.is_available() else "cpu" + # theseus_optim.to(device) + optim_arg = { "track_best_solution": True, "track_err_history": 
True, diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index a022a28bd..e8a617e3f 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -5,11 +5,11 @@ import abc import math +import time import warnings from dataclasses import dataclass from itertools import count from typing import Dict, List, Optional, Sequence -import time import numpy as np import torch @@ -28,7 +28,6 @@ """ TODO - solving inverse problem to compute message mean - - handle batch dim """ @@ -82,14 +81,19 @@ def __init__( # sets mean to the group identity and zero precision matrix def zero_message(self): new_mean = [] + batch_size = self.mean[0].shape[0] for var in self.mean: if var.__class__ == th.Vector: new_mean_i = var.__class__(var.dof()) else: new_mean_i = var.__class__() + repeats = torch.ones(var.ndim, dtype=int) + repeats[0] = batch_size + repeats[0] = batch_size + new_mean_i = new_mean_i.data.repeat(repeats.tolist()) new_mean_i.to(dtype=self.dtype, device=self.device) new_mean.append(new_mean_i) - new_precision = torch.zeros(self.mean[0].shape[0], self.dof, self.dof).to( + new_precision = torch.zeros(batch_size, self.dof, self.dof).to( dtype=self.dtype, device=self.device ) self.update(mean=new_mean, precision=new_precision) @@ -237,8 +241,8 @@ def comp_mess( sdim = 0 for v in range(num_optim_vars): - eta_factor = self.potential_eta.clone()[0] - lam_factor = self.potential_lam.clone()[0] + eta_factor = self.potential_eta.clone() + lam_factor = self.potential_lam.clone() lam_factor_copy = lam_factor.clone() # Take product of factor with incoming messages. @@ -250,63 +254,70 @@ def comp_mess( eta_mess, lam_mess = th.local_gaussian( self.lin_point[i], vtof_msgs[i], return_mean=False ) - eta_factor[start : start + var_dofs] += eta_mess[0] + eta_factor[:, start : start + var_dofs] += eta_mess lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess[0] + :, start : start + var_dofs, start : start + var_dofs + ] += lam_mess start += var_dofs dofs = self.cf.optim_var_at(v).dof() if torch.allclose(lam_factor, lam_factor_copy) and num_optim_vars > 1: - # print(self.cf.name, '---> not updating as incoming message lams are zeros') + print( + self.cf.name, "---> not updating as incoming message lams are zeros" + ) new_mess = Message([self.cf.optim_var_at(v).copy()]) new_mess.zero_message() else: - # print(self.cf.name, '---> sending message') + print(self.cf.name, "---> sending message") # Divide up parameters of distribution - eo = eta_factor[sdim : sdim + dofs] - eno = torch.cat((eta_factor[:sdim], eta_factor[sdim + dofs :])) - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] + eo = eta_factor[:, sdim : sdim + dofs] + eno = torch.cat( + (eta_factor[:, :sdim], eta_factor[:, sdim + dofs :]), dim=1 + ) + + loo = lam_factor[:, sdim : sdim + dofs, sdim : sdim + dofs] lono = torch.cat( ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], + lam_factor[:, sdim : sdim + dofs, :sdim], + lam_factor[:, sdim : sdim + dofs, sdim + dofs :], ), - dim=1, + dim=2, ) lnoo = torch.cat( ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], + lam_factor[:, :sdim, sdim : sdim + dofs], + lam_factor[:, sdim + dofs :, sdim : sdim + dofs], ), - dim=0, + dim=1, ) lnono = torch.cat( ( torch.cat( ( - lam_factor[:sdim, :sdim], - lam_factor[:sdim, sdim + dofs :], + lam_factor[:, :sdim, :sdim], + lam_factor[:, :sdim, sdim + dofs :], ), - dim=1, + dim=2, ), torch.cat( ( - lam_factor[sdim + dofs :, 
:sdim], - lam_factor[sdim + dofs :, sdim + dofs :], + lam_factor[:, sdim + dofs :, :sdim], + lam_factor[:, sdim + dofs :, sdim + dofs :], ), - dim=1, + dim=2, ), ), - dim=0, + dim=1, ) new_mess_lam = loo - lono @ torch.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - lono @ torch.linalg.inv(lnono) @ eno + new_mess_eta = eo - torch.bmm( + torch.bmm(lono, torch.linalg.inv(lnono)), eno.unsqueeze(-1) + ).squeeze(-1) # damping in tangent space at linearisation point as message # is already in this tangent space. Could equally do damping @@ -321,20 +332,23 @@ def comp_mess( self.lin_point[v], ftov_msgs[v], return_mean=True ) - new_mess_mean = torch.matmul( - torch.inverse(new_mess_lam), new_mess_eta - ) + new_mess_mean = torch.bmm( + torch.inverse(new_mess_lam), new_mess_eta.unsqueeze(-1) + ).squeeze(-1) new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ v - ] * prev_mess_mean[0] - new_mess_eta = torch.matmul(new_mess_lam, new_mess_mean) + ] * prev_mess_mean + new_mess_eta = torch.bmm( + new_mess_lam, new_mess_mean.unsqueeze(-1) + ).squeeze(-1) new_mess_lam = th.DenseSolver._apply_damping( - new_mess_lam[None, ...], + new_mess_lam, self.lin_system_damping, ellipsoidal=True, eps=1e-8, ) + new_mess_mean = th.LUDenseSolver._solve_sytem( new_mess_eta[..., None], new_mess_lam ) @@ -378,7 +392,9 @@ def __init__( abs_err_tolerance, rel_err_tolerance, max_iterations ) - self.n_edges = sum([cf.num_optim_vars() for cf in self.objective.cost_functions.values()]) + self.n_edges = sum( + [cf.num_optim_vars() for cf in self.objective.cost_functions.values()] + ) # create array for indexing the messages var_ixs_nested = [ @@ -571,7 +587,6 @@ def _pass_var_to_fac_messages( new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _linearize_factors(self, relin_threshold: float): relins = 0 did_relin = [] @@ -579,7 +594,7 @@ def _linearize_factors(self, relin_threshold: float): start = time.time() # compute weighted error and jacobian for all factors self.objective.update_vectorization() - print('vectorized update time', time.time() - start) + print("vectorized update time", time.time() - start) start = time.time() for factor in self.factors: @@ -589,7 +604,7 @@ def _linearize_factors(self, relin_threshold: float): did_relin += [1] else: did_relin += [0] - print('compute factor time', time.time() - start) + print("compute factor time", time.time() - start) # print(f"Factor relinearisations: {relins} / {len(self.factors)}") return relins @@ -603,8 +618,6 @@ def _pass_fac_to_var_messages( for factor in self.factors: num_optim_vars = factor.cf.num_optim_vars() - - factor.comp_mess( self.vtof_msgs[start : start + num_optim_vars], self.ftov_msgs[start : start + num_optim_vars], @@ -613,7 +626,6 @@ def _pass_fac_to_var_messages( start += num_optim_vars - """ Optimization loop functions """ @@ -661,8 +673,12 @@ def _optimize_loop( for var in cf.optim_vars: # Set mean of initial message to identity of the group # doesn't matter what it is as long as precision is zero - vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") - ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg = Message( + [var.copy()], name=f"msg_{var.name}_to_{cf.name}" + ) + ftov_msg = Message( + [var.copy()], name=f"msg_{cf.name}_to_{var.name}" + ) vtof_msg.zero_message() ftov_msg.zero_message() self.vtof_msgs.append(vtof_msg) @@ -676,8 +692,13 @@ def _optimize_loop( # compute factor potentials for the first time self.factors: List[Factor] 
= [] for cost_function in self.objective._get_iterator(): - self.factors.append(Factor(cost_function, lin_system_damping)) - + self.factors.append( + Factor( + cost_function, + name=cost_function.name, + lin_system_damping=lin_system_damping, + ) + ) self.belief_history = {} self.ftov_msgs_history = {} diff --git a/theseus/optimizer/gbp/gbp_baseline.py b/theseus/optimizer/gbp/gbp_baseline.py index fef31b412..ea81e1a62 100644 --- a/theseus/optimizer/gbp/gbp_baseline.py +++ b/theseus/optimizer/gbp/gbp_baseline.py @@ -755,12 +755,8 @@ def draw(i): # fg.compute_all_messages() - import ipdb - - ipdb.set_trace() - # i = 0 - n_iters = 5 + n_iters = 20 while i <= n_iters: # img = draw(i) # cv2.imshow('img', img) @@ -776,10 +772,4 @@ def draw(i): # for m in f.messages: # print(np.linalg.inv(m.lam) @ m.eta) - print(fg.belief_means()) - - import ipdb - - ipdb.set_trace() - - # time.sleep(0.05) + print(fg.belief_means()) diff --git a/theseus/optimizer/gbp/vectorize_test.py b/theseus/optimizer/gbp/vectorize_test.py new file mode 100644 index 000000000..6f434ef9e --- /dev/null +++ b/theseus/optimizer/gbp/vectorize_test.py @@ -0,0 +1,116 @@ +import torch + +import theseus as th +from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule + +torch.manual_seed(0) + + +def generate_data(num_points=100, a=1, b=0.5, noise_factor=0.01): + # Generate data: 100 points sampled from the quadratic curve listed above + data_x = torch.rand((1, num_points)) + noise = torch.randn((1, num_points)) * noise_factor + data_y = a * data_x.square() + b + noise + return data_x, data_y + + +def generate_learning_data(num_points, num_models): + a, b = 3, 1 + data_batches = [] + for i in range(num_models): + b = b + 2 + data = generate_data(num_points, a, b) + data_batches.append(data) + return data_batches + + +num_models = 10 +data_batches = generate_learning_data(100, num_models) + + +# updated error function reflects change in 'a' +def quad_error_fn2(optim_vars, aux_vars): + [a, b] = optim_vars + x, y = aux_vars + est = a.data * x.data.square() + b.data + err = y.data - est + return err + + +# The theseus_inputs dictionary is also constructed similarly to before, +# but with data matching the new shapes of the variables +def construct_theseus_layer_inputs(): + theseus_inputs = {} + theseus_inputs.update( + { + "x": data_x, + "y": data_y, + "b": torch.ones((num_models, 1)), + "a": a_tensor, + } + ) + return theseus_inputs + + +# convert data_x, data_y into torch.tensors of shape [num_models, 100] +data_x = torch.stack([data_x.squeeze() for data_x, _ in data_batches]) +data_y = torch.stack([data_y.squeeze() for _, data_y in data_batches]) + +# construct one variable each of x, y of shape [num_models, 100] +x = th.Variable(data_x, name="x") +y = th.Variable(data_y, name="y") + +# construct a as before +a = th.Vector(data=torch.rand(num_models, 1), name="a") + +# construct one variable b, now of shape [num_models, 1] +b = th.Vector(data=torch.rand(num_models, 1), name="b") + +# Again, 'b' is the only optim_var, and 'a' is part of aux_vars along with x, y +aux_vars = [x, y] + +# cost function constructed as before +cost_function = th.AutoDiffCostFunction( + [a, b], quad_error_fn2, 100, aux_vars=aux_vars, name="quadratic_cost_fn" +) + +prior_weight = th.ScaleCostWeight(torch.ones(1)) +prior_a = th.Difference(a, prior_weight, th.Vector(1)) +prior_b = th.Difference(b, prior_weight, th.Vector(1)) + +# objective, optimizer and theseus layer constructed as before +objective = th.Objective() 
+objective.add(cost_function) +objective.add(prior_a) +objective.add(prior_b) + +print([cf.name for cf in objective.cost_functions.values()]) + +optimizer = GaussianBeliefPropagation( + objective, + max_iterations=50, # step_size=0.5, +) + +theseus_optim = th.TheseusLayer(optimizer, vectorize=False) + +a_tensor = torch.nn.Parameter(torch.rand(num_models, 1)) + + +optim_arg = { + "track_best_solution": True, + "track_err_history": True, + "verbose": True, + "backward_mode": th.BackwardMode.FULL, + "relin_threshold": 0.0000000001, + "damping": 0.5, + "dropout": 0.0, + "schedule": synchronous_schedule(50, optimizer.n_edges), + "lin_system_damping": 1e-5, +} + + +theseus_inputs = construct_theseus_layer_inputs() +print("inputs\n", theseus_inputs["a"], theseus_inputs["x"].shape) +updated_inputs, _ = theseus_optim.forward(theseus_inputs, optim_arg) + +print(updated_inputs) From c1b244977e2d97ef2c92afc636fd1cbd59fd9c04 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Tue, 5 Jul 2022 13:32:35 +0100 Subject: [PATCH 28/64] vectorized relinearization and ftov msg passing, schedule class --- theseus/core/objective.py | 7 +- theseus/core/vectorizer.py | 20 +- theseus/optimizer/gbp/__init__.py | 2 +- theseus/optimizer/gbp/ba_test.py | 15 +- theseus/optimizer/gbp/gbp.py | 308 +++++++++++++----------- theseus/optimizer/gbp/pgo_test.py | 9 +- theseus/optimizer/gbp/vectorize_test.py | 7 +- 7 files changed, 204 insertions(+), 164 deletions(-) diff --git a/theseus/core/objective.py b/theseus/core/objective.py index b92e93313..532d4e51a 100644 --- a/theseus/core/objective.py +++ b/theseus/core/objective.py @@ -70,6 +70,9 @@ def __init__(self, dtype: Optional[torch.dtype] = None): self._vectorization_to: Optional[Callable] = None + self.vectorized_cost_fns: Optional[List[CostFunction]] = None + self.vectorized_msg_ixs: Optional[List[List[int]]] = None + def _add_function_variables( self, function: TheseusFunction, @@ -479,11 +482,11 @@ def _get_batch_size(batch_sizes: Sequence[int]) -> int: batch_sizes.extend([v.data.shape[0] for v in self.aux_vars.values()]) self._batch_size = _get_batch_size(batch_sizes) - def update_vectorization(self): + def update_vectorization(self, compute_caches=True): if self._vectorization_run is not None: if self._batch_size is None: self.update() - self._vectorization_run() + self._vectorization_run(compute_caches=compute_caches) # iterates over cost functions def __iter__(self): diff --git a/theseus/core/vectorizer.py b/theseus/core/vectorizer.py index 446efd58c..0cdf1ae9a 100644 --- a/theseus/core/vectorizer.py +++ b/theseus/core/vectorizer.py @@ -92,13 +92,19 @@ def __init__(self, objective: Objective): _CostFunctionSchema, List[_CostFunctionWrapper] ] = defaultdict(list) + schema_ixs_dict: Dict[_CostFunctionSchema, List[int]] = defaultdict(list) + # Create wrappers for all cost functions and also get their schemas + msg_ix = 0 for cost_fn in objective.cost_functions.values(): wrapper = _CostFunctionWrapper(cost_fn) self._cost_fn_wrappers.append(wrapper) schema = _get_cost_function_schema(cost_fn) self._schema_dict[schema].append(wrapper) + schema_ixs_dict[schema].append(msg_ix) + msg_ix += cost_fn.num_optim_vars() + # Now create a vectorized cost function for each unique schema self._vectorized_cost_fns: Dict[_CostFunctionSchema, CostFunction] = {} for schema in self._schema_dict: @@ -119,6 +125,8 @@ def __init__(self, objective: Objective): objective._cost_functions_iterable = self._cost_fn_wrappers objective._vectorization_run = self._vectorize objective._vectorization_to = 
self._to + objective.vectorized_cost_fns = list(self._vectorized_cost_fns.values()) + objective.vectorized_msg_ixs = list(schema_ixs_dict.values()) self._objective = objective @@ -282,8 +290,9 @@ def _clear_wrapper_caches(self): cf._cached_error = None cf._cached_jacobians = None - def _vectorize(self): - self._clear_wrapper_caches() + def _vectorize(self, compute_caches=True): + if compute_caches: + self._clear_wrapper_caches() for schema, cost_fn_wrappers in self._schema_dict.items(): var_names = self._var_names[schema] vectorized_cost_fn = self._vectorized_cost_fns[schema] @@ -302,9 +311,10 @@ def _vectorize(self): batch_size, len(cost_fn_wrappers), ) - Vectorize._compute_error_and_replace_wrapper_caches( - vectorized_cost_fn, cost_fn_wrappers, batch_size - ) + if compute_caches: + Vectorize._compute_error_and_replace_wrapper_caches( + vectorized_cost_fn, cost_fn_wrappers, batch_size + ) # Applies to() with given args to all vectorized cost functions in the objective def _to(self, *args, **kwargs): diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py index 5d308bde2..53c57ea55 100644 --- a/theseus/optimizer/gbp/__init__.py +++ b/theseus/optimizer/gbp/__init__.py @@ -4,4 +4,4 @@ # LICENSE file in the root directory of this source tree. from .ba_viewer import BAViewer -from .gbp import GaussianBeliefPropagation, random_schedule, synchronous_schedule +from .gbp import GaussianBeliefPropagation, GBPSchedule diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 62a03d008..46bce313e 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -11,11 +11,7 @@ import theseus as th import theseus.utils.examples as theg -from theseus.optimizer.gbp import ( - BAViewer, - GaussianBeliefPropagation, - synchronous_schedule, -) +from theseus.optimizer.gbp import BAViewer, GaussianBeliefPropagation, GBPSchedule # Smaller values result in error th.SO3.SO3_EPS = 1e-6 @@ -47,7 +43,9 @@ def camera_loss( ) -> torch.Tensor: loss: torch.Tensor = 0 # type:ignore for i in range(len(ba.cameras)): - camera_loss = th.local(camera_pose_vars[i], ba.gt_cameras[i].pose).norm(dim=1) + cam_pose = camera_pose_vars[i].copy() + cam_pose.to(ba.gt_cameras[i].pose.device) + camera_loss = th.local(cam_pose, ba.gt_cameras[i].pose).norm(dim=1).cpu() loss += camera_loss return loss @@ -169,10 +167,9 @@ def run(cfg: omegaconf.OmegaConf): "relin_threshold": 0.0000000001, "damping": 0.0, "dropout": 0.0, - "schedule": synchronous_schedule( - cfg["inner_optim"]["max_iters"], optimizer.n_edges - ), + "schedule": GBPSchedule.SYNCHRONOUS, "lin_system_damping": 1e-5, + "vectorize": True, } optim_arg = {**optim_arg, **gbp_optim_arg} diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index e8a617e3f..7b50d06d0 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -8,6 +8,7 @@ import time import warnings from dataclasses import dataclass +from enum import Enum from itertools import count from typing import Dict, List, Optional, Sequence @@ -51,6 +52,11 @@ def update(self, params_dict): raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") +class GBPSchedule(Enum): + SYNCHRONOUS = 0 + RANDOM = 1 + + def synchronous_schedule(max_iters, n_edges) -> torch.Tensor: return torch.full([max_iters, n_edges], True) @@ -76,7 +82,6 @@ def __init__( dtype=mean[0].dtype, device=mean[0].device ) super(Message, self).__init__(mean, precision=precision, name=name) - assert dof == self.dof # sets mean to the 
group identity and zero precision matrix def zero_message(self): @@ -89,9 +94,8 @@ def zero_message(self): new_mean_i = var.__class__() repeats = torch.ones(var.ndim, dtype=int) repeats[0] = batch_size - repeats[0] = batch_size new_mean_i = new_mean_i.data.repeat(repeats.tolist()) - new_mean_i.to(dtype=self.dtype, device=self.device) + new_mean_i = new_mean_i.to(dtype=self.dtype, device=self.device) new_mean.append(new_mean_i) new_precision = torch.zeros(batch_size, self.dof, self.dof).to( dtype=self.dtype, device=self.device @@ -99,58 +103,6 @@ def zero_message(self): self.update(mean=new_mean, precision=new_precision) -# class CostFunctionOrdering: -# def __init__(self, objective: Objective, default_order: bool = True): -# self.objective = objective -# self._cf_order: List[CostFunction] = [] -# self._cf_name_to_index: Dict[str, int] = {} -# if default_order: -# self._compute_default_order(objective) - -# def _compute_default_order(self, objective: Objective): -# assert not self._cf_order and not self._cf_name_to_index -# cur_idx = 0 -# for cf_name, cf in objective.cost_functions.items(): -# if cf_name in self._cf_name_to_index: -# continue -# self._cf_order.append(cf) -# self._cf_name_to_index[cf_name] = cur_idx -# cur_idx += 1 - -# def index_of(self, key: str) -> int: -# return self._cf_name_to_index[key] - -# def __getitem__(self, index) -> CostFunction: -# return self._cf_order[index] - -# def __iter__(self): -# return iter(self._cf_order) - -# def append(self, cf: CostFunction): -# if cf in self._cf_order: -# raise ValueError( -# f"Cost Function {cf.name} has already been added to the order." -# ) -# if cf.name not in self.objective.cost_functions: -# raise ValueError( -# f"Cost Function {cf.name} is not a cost function for the objective." -# ) -# self._cf_order.append(cf) -# self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 - -# def remove(self, cf: CostFunction): -# self._cf_order.remove(cf) -# del self._cf_name_to_index[cf.name] - -# def extend(self, cfs: Sequence[CostFunction]): -# for cf in cfs: -# self.append(cf) - -# @property -# def complete(self): -# return len(self._cf_order) == self.objective.size_variables() - - """ GBP functions """ @@ -174,20 +126,26 @@ def __init__( self.cf = cf self.lin_system_damping = lin_system_damping - batch_size = cf.optim_var_at(0).shape[0] + # batch_size of the vectorized factor. In general != objective.batch_size. + # They are equal without vectorization or for unique cost function schema. + self.batch_size = cf.optim_var_at(0).shape[0] + + device = cf.optim_var_at(0).device + dtype = cf.optim_var_at(0).dtype self._dof = sum([var.dof() for var in cf.optim_vars]) - self.potential_eta = torch.zeros(batch_size, self.dof).to( - dtype=cf.optim_var_at(0).dtype, device=cf.optim_var_at(0).device + self.potential_eta = torch.zeros(self.batch_size, self.dof).to( + dtype=dtype, device=device ) - self.potential_lam = torch.zeros(batch_size, self.dof, self.dof).to( - dtype=cf.optim_var_at(0).dtype, device=cf.optim_var_at(0).device + self.potential_lam = torch.zeros(self.batch_size, self.dof, self.dof).to( + dtype=dtype, device=device ) self.lin_point = [ var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars ] - self.steps_since_lin = 0 - self.linearize() + self.steps_since_lin = torch.zeros( + self.batch_size, device=device, dtype=torch.int + ) # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. 
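# --- illustrative sketch, not part of the patch ----------------------------
# Background for the linearize() changes below: the information-form potential
# of a linearised factor in GBP is usually the Gauss-Newton surrogate around
# the linearisation point x0. With weighted jacobian J and weighted error r,
#     lam = J^T J,    eta = J^T (J x0 - r).
# A standalone sketch of that form, following the batched shape convention:
import torch

def gauss_newton_potential(J: torch.Tensor, r: torch.Tensor, x0: torch.Tensor):
    # J: [batch, err_dim, dof], r: [batch, err_dim], x0: [batch, dof]
    lam = J.transpose(-2, -1) @ J
    eta = (lam @ x0.unsqueeze(-1) - J.transpose(-2, -1) @ r.unsqueeze(-1)).squeeze(-1)
    return eta, lam
# ----------------------------------------------------------------------------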
@@ -197,19 +155,25 @@ def linearize( lie=True, ): self.steps_since_lin += 1 - do_lin = False + if relin_threshold is None: - do_lin = True + do_lin = torch.full( + [self.batch_size], + True, + device=self.cf.optim_var_at(0).device, + ) else: - lp_dists = torch.tensor( + lp_dists = torch.cat( [ - lp.local(self.cf.optim_var_at(j)).norm() + lp.local(self.cf.optim_var_at(j)).norm(dim=1)[..., None] for j, lp in enumerate(self.lin_point) - ] + ], + dim=1, ) - do_lin = bool((torch.max(lp_dists) > relin_threshold).item()) + max_dists = lp_dists.max(dim=1)[0] + do_lin = max_dists > relin_threshold - if do_lin: + if torch.sum(do_lin) > 0: # if any factor in the batch needs relinearization J, error = self.cf.weighted_jacobians_error() J_stk = torch.cat(J, dim=-1) @@ -221,13 +185,13 @@ def linearize( eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) - self.potential_eta = eta - self.potential_lam = lam + self.potential_eta[do_lin] = eta[do_lin] + self.potential_lam[do_lin] = lam[do_lin] for j, var in enumerate(self.cf.optim_vars): - self.lin_point[j].update(var.data) + self.lin_point[j].update(var.data, batch_ignore_mask=~do_lin) - self.steps_since_lin = 0 + self.steps_since_lin[do_lin] = 0 # Compute all outgoing messages from the factor. def comp_mess( @@ -264,16 +228,13 @@ def comp_mess( dofs = self.cf.optim_var_at(v).dof() if torch.allclose(lam_factor, lam_factor_copy) and num_optim_vars > 1: - print( - self.cf.name, "---> not updating as incoming message lams are zeros" - ) + # print(self.cf.name, "---> not updating, incoming precision is zero") new_mess = Message([self.cf.optim_var_at(v).copy()]) new_mess.zero_message() else: - print(self.cf.name, "---> sending message") + # print(self.cf.name, "---> sending message") # Divide up parameters of distribution - eo = eta_factor[:, sdim : sdim + dofs] eno = torch.cat( (eta_factor[:, :sdim], eta_factor[:, sdim + dofs :]), dim=1 @@ -323,21 +284,24 @@ def comp_mess( # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. 
# mean damping - if damping[v] != 0 and self.steps_since_lin > 0: - if ( - new_mess_lam.count_nonzero() != 0 - and ftov_msgs[v].precision.count_nonzero() != 0 - ): + do_damping = torch.logical_and(damping[v] > 0, self.steps_since_lin > 0) + if do_damping.sum() > 0: + damping_check = torch.logical_and( + new_mess_lam.count_nonzero(1, 2) != 0, + ftov_msgs[v].precision.count_nonzero(1, 2) != 0, + ) + do_damping = torch.logical_and(do_damping, damping_check) + if do_damping.sum() > 0: prev_mess_mean, prev_mess_lam = th.local_gaussian( self.lin_point[v], ftov_msgs[v], return_mean=True ) - new_mess_mean = torch.bmm( torch.inverse(new_mess_lam), new_mess_eta.unsqueeze(-1) ).squeeze(-1) - new_mess_mean = (1 - damping[v]) * new_mess_mean + damping[ - v - ] * prev_mess_mean + damping[v][~do_damping] = 0.0 + new_mess_mean = ( + 1 - damping[v][:, None] + ) * new_mess_mean + damping[v][:, None] * prev_mess_mean new_mess_eta = torch.bmm( new_mess_lam, new_mess_mean.unsqueeze(-1) ).squeeze(-1) @@ -392,10 +356,6 @@ def __init__( abs_err_tolerance, rel_err_tolerance, max_iterations ) - self.n_edges = sum( - [cf.num_optim_vars() for cf in self.objective.cost_functions.values()] - ) - # create array for indexing the messages var_ixs_nested = [ [self.ordering.index_of(var.name) for var in cf.optim_vars] @@ -587,26 +547,12 @@ def _pass_var_to_fac_messages( new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _linearize_factors(self, relin_threshold: float): + def _linearize_factors(self, relin_threshold: float = None): relins = 0 - did_relin = [] - - start = time.time() - # compute weighted error and jacobian for all factors - self.objective.update_vectorization() - print("vectorized update time", time.time() - start) - - start = time.time() for factor in self.factors: factor.linearize(relin_threshold=relin_threshold) - if factor.steps_since_lin == 0: - relins += 1 - did_relin += [1] - else: - did_relin += [0] - print("compute factor time", time.time() - start) + relins += int((factor.steps_since_lin == 0).sum().item()) - # print(f"Factor relinearisations: {relins} / {len(self.factors)}") return relins def _pass_fac_to_var_messages( @@ -614,15 +560,73 @@ def _pass_fac_to_var_messages( schedule: torch.Tensor, damping: torch.Tensor, ): + + # USE THE SCHEDULE!!!!! 
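# --- illustrative sketch, not part of the patch ----------------------------
# What Factor.comp_mess above is doing, in isolation: after multiplying the
# incoming variable-to-factor messages from the other variables into the
# factor potential, the message to one variable is the marginal over that
# variable, obtained with a Schur complement in information form. A minimal
# unbatched sketch:
import torch

def marginalise_to_first_block(eta: torch.Tensor, lam: torch.Tensor, d: int):
    # eta: [dof], lam: [dof, dof]; the first d dims belong to the receiving var
    eo, en = eta[:d], eta[d:]
    loo, lon = lam[:d, :d], lam[:d, d:]
    lno, lnn = lam[d:, :d], lam[d:, d:]
    lnn_inv = torch.linalg.inv(lnn)
    msg_eta = eo - lon @ lnn_inv @ en
    msg_lam = loo - lon @ lnn_inv @ lno
    return msg_eta, msg_lam
# ----------------------------------------------------------------------------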
+ start = 0 - for factor in self.factors: + start_d = 0 + for j, factor in enumerate(self.factors): num_optim_vars = factor.cf.num_optim_vars() - - factor.comp_mess( - self.vtof_msgs[start : start + num_optim_vars], - self.ftov_msgs[start : start + num_optim_vars], - damping[start : start + num_optim_vars], + n_factors = num_optim_vars * factor.batch_size + damping_tsr = damping[start_d : start_d + n_factors].reshape( + num_optim_vars, factor.batch_size ) + start_d += n_factors + + if self.vectorize: + # prepare vectorized messages + ixs = torch.tensor(self.objective.vectorized_msg_ixs[j]) + vtof_msgs: List[Message] = [] + ftov_msgs: List[Message] = [] + for var in factor.cf.optim_vars: + mean_vtof_msgs = var.copy() + mean_ftov_msgs = var.copy() + mean_data_vtof_msgs = torch.cat( + [self.vtof_msgs[i].mean[0].data for i in ixs] + ) + mean_data_ftov_msgs = torch.cat( + [self.ftov_msgs[i].mean[0].data for i in ixs] + ) + mean_vtof_msgs.update(data=mean_data_vtof_msgs) + mean_ftov_msgs.update(data=mean_data_ftov_msgs) + precision_vtof_msgs = torch.cat( + [self.vtof_msgs[i].precision for i in ixs] + ) + precision_ftov_msgs = torch.cat( + [self.ftov_msgs[i].precision for i in ixs] + ) + + vtof_msg = Message( + mean=[mean_vtof_msgs], precision=precision_vtof_msgs + ) + ftov_msg = Message( + mean=[mean_ftov_msgs], precision=precision_ftov_msgs + ) + vtof_msgs.append(vtof_msg) + ftov_msgs.append(ftov_msg) + + ixs += 1 + else: + vtof_msgs = self.vtof_msgs[start : start + num_optim_vars] + ftov_msgs = self.ftov_msgs[start : start + num_optim_vars] + + factor.comp_mess(vtof_msgs, ftov_msgs, damping_tsr) + + if self.vectorize: + # fill in messages using vectorized messages + ixs = torch.tensor(self.objective.vectorized_msg_ixs[j]) + for ftov_msg in ftov_msgs: + start_idx = 0 + for ix in ixs: + v_slice = slice( + start_idx, start_idx + self.objective.batch_size + ) + self.ftov_msgs[ix].update( + mean=[ftov_msg.mean[0][v_slice]], + precision=ftov_msg.precision[v_slice], + ) + start_idx += self.objective.batch_size + ixs += 1 start += num_optim_vars @@ -641,28 +645,16 @@ def _optimize_loop( relin_threshold: float, damping: float, dropout: float, - schedule: torch.Tensor, + schedule: GBPSchedule, lin_system_damping: float, clear_messages: bool = True, **kwargs, ): if damping > 1.0 or damping < 0.0: - raise ValueError(f"Damping must be in between 0 and 1. Got {damping}.") + raise ValueError(f"Damping must be between 0 and 1. Got {damping}.") if dropout > 1.0 or dropout < 0.0: raise ValueError( - f"Dropout probability must be in between 0 and 1. Got {dropout}." - ) - if schedule is None: - schedule = random_schedule(self.params.max_iterations, self.n_edges) - elif schedule.dtype != torch.bool: - raise ValueError( - f"Schedule must be of dtype {torch.bool} but has dtype {schedule.dtype}." - ) - elif schedule.shape != torch.Size([self.params.max_iterations, self.n_edges]): - raise ValueError( - f"Schedule must have shape [max_iterations, num_edges]. " - f"Should be {torch.Size([self.params.max_iterations, self.n_edges])} " - f"but got {schedule.shape}." + f"Dropout probability must be between 0 and 1. Got {dropout}." 
) if clear_messages: @@ -689,9 +681,18 @@ def _optimize_loop( for var in self.ordering: self.beliefs.append(th.ManifoldGaussian([var])) + self.n_individual_factors = ( + len(self.objective.cost_functions) * self.objective.batch_size + ) + if self.vectorize: + self.objective.update_vectorization(compute_caches=False) + cf_iterator = iter(self.objective.vectorized_cost_fns) + else: + cf_iterator = self.objective._get_iterator() + # compute factor potentials for the first time self.factors: List[Factor] = [] - for cost_function in self.objective._get_iterator(): + for cost_function in cf_iterator: self.factors.append( Factor( cost_function, @@ -699,6 +700,17 @@ def _optimize_loop( lin_system_damping=lin_system_damping, ) ) + relins = self._linearize_factors() + + self.n_edges = sum( + [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] + ) + if schedule == GBPSchedule.RANDOM: + ftov_schedule = random_schedule(self.params.max_iterations, self.n_edges) + elif schedule == GBPSchedule.SYNCHRONOUS: + ftov_schedule = synchronous_schedule( + self.params.max_iterations, self.n_edges + ) self.belief_history = {} self.ftov_msgs_history = {} @@ -709,34 +721,50 @@ def _optimize_loop( self.belief_history[it_] = [belief.copy() for belief in self.beliefs] # damping - damping_arr = torch.full([self.n_edges], damping) - + damping_arr = torch.full( + [self.n_edges], + damping, + device=self.ordering[0].device, + dtype=self.ordering[0].dtype, + ) # dropout can be implemented through damping if dropout != 0.0: dropout_ixs = torch.rand(self.n_edges) < dropout damping_arr[dropout_ixs] = 1.0 - t1 = time.time() + t0 = time.time() relins = self._linearize_factors(relin_threshold) - print("relin time", time.time() - t1) + t_relin = time.time() - t0 t1 = time.time() - self._pass_fac_to_var_messages(schedule[it_], damping_arr) - # print("ftov time", time.time() - t1) + self._pass_fac_to_var_messages(ftov_schedule[it_], damping_arr) + t_ftov = time.time() - t1 t1 = time.time() self._pass_var_to_fac_messages(update_belief=True) - # print("vtof time", time.time() - t1) + t_vtof = time.time() - t1 + + t_vec = 0.0 + if self.vectorize: + t1 = time.time() + self.objective.update_vectorization(compute_caches=False) + t_vec = time.time() - t1 + + t_tot = time.time() - t0 + print( + f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," + f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" + ) # check for convergence - if it_ > 0: + if it_ >= 0: with torch.no_grad(): err = self.objective.error_squared_norm() / 2 self._update_info(info, it_, err, converged_indices) if verbose: print( f"GBP. Iteration: {it_+1}. Error: {err.mean().item():.4f}. 
" - f"Relins: {relins} / {len(self.factors)}" + f"Relins: {relins} / {self.n_individual_factors}" ) converged_indices = self._check_convergence(err, info.last_err) info.status[ @@ -763,10 +791,12 @@ def _optimize_impl( relin_threshold: float = 0.1, damping: float = 0.0, dropout: float = 0.0, - schedule: torch.Tensor = None, + schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, lin_system_damping: float = 1e-6, + vectorize: bool = True, **kwargs, ) -> NonlinearOptimizerInfo: + self.vectorize = vectorize and self.objective.vectorized_cost_fns is not None with torch.no_grad(): info = self._init_info(track_best_solution, track_err_history, verbose) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index a7fe1899c..d0d02d9e9 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -7,7 +7,7 @@ import torch import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule +from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule # This example illustrates the Gaussian Belief Propagation (GBP) optimizer # for a 2D pose graph optimization problem. @@ -68,9 +68,7 @@ def create_pgo(): inputs[f"x{p}"] = init[None, :] inputs[f"prior_{p}"] = init[None, :] - cf_prior = th.Difference( - poses[p], w, prior_target, name=f"prior_cost_{p}" - ) + cf_prior = th.Difference(poses[p], w, prior_target, name=f"prior_cost_{p}") objective.add(cf_prior) @@ -170,7 +168,8 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): "relin_threshold": 1e-8, "damping": 0.0, "dropout": 0.0, - "schedule": synchronous_schedule(max_iterations, optimizer.n_edges), + "schedule": GBPSchedule.SYNCHRONOUS, + "vectorize": True, } outputs_gbp, info = theseus_optim.forward(inputs, optim_arg) diff --git a/theseus/optimizer/gbp/vectorize_test.py b/theseus/optimizer/gbp/vectorize_test.py index 6f434ef9e..4457ab047 100644 --- a/theseus/optimizer/gbp/vectorize_test.py +++ b/theseus/optimizer/gbp/vectorize_test.py @@ -1,7 +1,7 @@ import torch import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation, synchronous_schedule +from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule torch.manual_seed(0) @@ -91,7 +91,7 @@ def construct_theseus_layer_inputs(): max_iterations=50, # step_size=0.5, ) -theseus_optim = th.TheseusLayer(optimizer, vectorize=False) +theseus_optim = th.TheseusLayer(optimizer, vectorize=True) a_tensor = torch.nn.Parameter(torch.rand(num_models, 1)) @@ -104,8 +104,9 @@ def construct_theseus_layer_inputs(): "relin_threshold": 0.0000000001, "damping": 0.5, "dropout": 0.0, - "schedule": synchronous_schedule(50, optimizer.n_edges), + "schedule": GBPSchedule.SYNCHRONOUS, "lin_system_damping": 1e-5, + "vectorize": True, } From c7d9b1cc9257c9b9eb8b2c24feb2c4e0ece47005 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Tue, 5 Jul 2022 13:33:07 +0100 Subject: [PATCH 29/64] added missing aux vars to reprojection error cf --- .../examples/bundle_adjustment/reprojection_error.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/theseus/utils/examples/bundle_adjustment/reprojection_error.py b/theseus/utils/examples/bundle_adjustment/reprojection_error.py index ed811ba74..1b176dc3a 100644 --- a/theseus/utils/examples/bundle_adjustment/reprojection_error.py +++ b/theseus/utils/examples/bundle_adjustment/reprojection_error.py @@ -52,7 +52,13 @@ def __init__( self.register_optim_vars(["camera_pose", "world_point"]) self.register_aux_vars( - ["log_loss_radius", 
"focal_length", "image_feature_point"] + [ + "log_loss_radius", + "focal_length", + "image_feature_point", + "calib_k1", + "calib_k2", + ] ) def error(self) -> torch.Tensor: From 954a154f58c328f81dbba85715d46c92c632253c Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Thu, 7 Jul 2022 09:37:19 +0100 Subject: [PATCH 30/64] vectorized vtof msg passing --- theseus/core/objective.py | 11 +- theseus/core/vectorizer.py | 11 +- theseus/optimizer/gbp/ba_test.py | 19 +- theseus/optimizer/gbp/gbp.py | 411 ++++++++++++++++++++----------- 4 files changed, 280 insertions(+), 172 deletions(-) diff --git a/theseus/core/objective.py b/theseus/core/objective.py index f65c5c922..a7afdc061 100644 --- a/theseus/core/objective.py +++ b/theseus/core/objective.py @@ -72,7 +72,8 @@ def __init__(self, dtype: Optional[torch.dtype] = None): self._vectorization_to: Optional[Callable] = None self.vectorized_cost_fns: Optional[List[CostFunction]] = None - self.vectorized_msg_ixs: Optional[List[List[int]]] = None + # nested list of name of each base cost function in the vectorized cfs + self.vectorized_cf_names: Optional[List[List[str]]] = None # If vectorization is on, this gets replaced by a vectorized version self._retract_method = Objective._retract_base @@ -565,7 +566,7 @@ def _enable_vectorization( vectorized_to: Callable, vectorized_retract_fn: Callable, vectorized_cost_fns: List[CostFunction], - vectorized_msg_ixs: List[List[int]], + vectorized_cf_names: List[List[str]], enabler: Any, ): # Hacky way to make Vectorize a "friend" class @@ -578,7 +579,7 @@ def _enable_vectorization( self._vectorization_to = vectorized_to self._retract_method = vectorized_retract_fn self.vectorized_cost_fns = vectorized_cost_fns - self.vectorized_msg_ixs = vectorized_msg_ixs + self.vectorized_cf_names = vectorized_cf_names self._vectorized = True # Making public, since this should be a safe operation @@ -588,7 +589,7 @@ def disable_vectorization(self): self._vectorization_to = None self._retract_method = Objective._retract_base self.vectorized_cost_fns = None - self.vectorized_msg_ixs = None + self.vectorized_cf_names = None self._vectorized = False @property @@ -600,6 +601,6 @@ def vectorized(self): == (self._vectorization_to is None) == (self._retract_method is Objective._retract_base) == (self.vectorized_cost_fns is None) - == (self.vectorized_msg_ixs is None) + == (self.vectorized_cf_names is None) ) return self._vectorized diff --git a/theseus/core/vectorizer.py b/theseus/core/vectorizer.py index 6f682b392..047d4f5b5 100644 --- a/theseus/core/vectorizer.py +++ b/theseus/core/vectorizer.py @@ -94,18 +94,16 @@ def __init__(self, objective: Objective): _CostFunctionSchema, List[_CostFunctionWrapper] ] = defaultdict(list) - schema_ixs_dict: Dict[_CostFunctionSchema, List[int]] = defaultdict(list) + schema_cf_names_dict: Dict[_CostFunctionSchema, List[str]] = defaultdict(list) # Create wrappers for all cost functions and also get their schemas - msg_ix = 0 for cost_fn in objective.cost_functions.values(): wrapper = _CostFunctionWrapper(cost_fn) self._cost_fn_wrappers.append(wrapper) schema = _get_cost_function_schema(cost_fn) self._schema_dict[schema].append(wrapper) - schema_ixs_dict[schema].append(msg_ix) - msg_ix += cost_fn.num_optim_vars() + schema_cf_names_dict[schema].append(cost_fn.name) # Now create a vectorized cost function for each unique schema self._vectorized_cost_fns: Dict[_CostFunctionSchema, CostFunction] = {} @@ -130,7 +128,7 @@ def __init__(self, objective: Objective): self._to, 
self._vectorized_retract_optim_vars, list(self._vectorized_cost_fns.values()), - list(schema_ixs_dict.values()), + list(schema_cf_names_dict.values()), self, ) @@ -297,8 +295,7 @@ def _clear_wrapper_caches(self): cf._cached_jacobians = None def _vectorize(self, compute_caches=True): - if compute_caches: - self._clear_wrapper_caches() + self._clear_wrapper_caches() for schema, cost_fn_wrappers in self._schema_dict.items(): var_names = self._var_names[schema] vectorized_cost_fn = self._vectorized_cost_fns[schema] diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index 70c25fc4c..eb41a2275 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -83,8 +83,8 @@ def run(cfg: omegaconf.OmegaConf): # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") # param that control transition from squared loss to huber - # radius_tensor = torch.tensor([1.0], dtype=torch.float64) - # log_loss_radius = th.Vector(data=radius_tensor, name="log_loss_radius") + radius_tensor = torch.tensor([1.0], dtype=torch.float64) + log_loss_radius = th.Vector(data=radius_tensor, name="log_loss_radius") # Set up objective print("Setting up objective") @@ -103,14 +103,13 @@ def run(cfg: omegaconf.OmegaConf): image_feature_point=obs.image_feature_point, weight=weight, ) - # robust_cost_function = th.RobustCostFunction( - # cost_function, - # th.HuberLoss, - # log_loss_radius, - # name=f"robust_{cost_function.name}", - # ) - # objective.add(robust_cost_function) - objective.add(cost_function) + robust_cost_function = th.RobustCostFunction( + cost_function, + th.HuberLoss, + log_loss_radius, + name=f"robust_{cost_function.name}", + ) + objective.add(robust_cost_function) dtype = objective.dtype # Add regularization diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 1ebe4bfd0..80a3bf3b0 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -10,7 +10,7 @@ from dataclasses import dataclass from enum import Enum from itertools import count -from typing import Dict, List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Tuple, Type import numpy as np import torch @@ -29,6 +29,10 @@ """ TODO - solving inverse problem to compute message mean + - move vectorized indices compute to init + - use schedule + - random schedule not vectorized + - factor inherits CF class """ @@ -114,6 +118,7 @@ class Factor: def __init__( self, cf: CostFunction, + var_ixs: torch.Tensor, name: Optional[str] = None, lin_system_damping: float = 1e-6, ): @@ -124,6 +129,7 @@ def __init__( self.name = f"{self.__class__.__name__}__{self._id}" self.cf = cf + self.var_ixs = var_ixs self.lin_system_damping = lin_system_damping # batch_size of the vectorized factor. In general != objective.batch_size. @@ -147,6 +153,18 @@ def __init__( self.batch_size, device=device, dtype=torch.int ) + self.vtof_msgs: List[Message] = [] + self.ftov_msgs: List[Message] = [] + for var in cf.optim_vars: + # Set mean of initial message to identity of the group + # NB doesn't matter what it is as long as precision is zero + vtof_msg = Message([var.copy()], name=f"msg_{var.name}_to_{cf.name}") + ftov_msg = Message([var.copy()], name=f"msg_{cf.name}_to_{var.name}") + vtof_msg.zero_message() + ftov_msg.zero_message() + self.vtof_msgs.append(vtof_msg) + self.ftov_msgs.append(ftov_msg) + # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. 
def linearize( @@ -196,8 +214,6 @@ def linearize( # Compute all outgoing messages from the factor. def comp_mess( self, - vtof_msgs, - ftov_msgs, damping, ): num_optim_vars = self.cf.num_optim_vars() @@ -216,7 +232,7 @@ def comp_mess( var_dofs = self.cf.optim_var_at(i).dof() if i != v: eta_mess, lam_mess = th.local_gaussian( - self.lin_point[i], vtof_msgs[i], return_mean=False + self.lin_point[i], self.vtof_msgs[i], return_mean=False ) eta_factor[:, start : start + var_dofs] += eta_mess lam_factor[ @@ -288,12 +304,12 @@ def comp_mess( if do_damping.sum() > 0: damping_check = torch.logical_and( new_mess_lam.count_nonzero(1, 2) != 0, - ftov_msgs[v].precision.count_nonzero(1, 2) != 0, + self.ftov_msgs[v].precision.count_nonzero(1, 2) != 0, ) do_damping = torch.logical_and(do_damping, damping_check) if do_damping.sum() > 0: prev_mess_mean, prev_mess_lam = th.local_gaussian( - self.lin_point[v], ftov_msgs[v], return_mean=True + self.lin_point[v], self.ftov_msgs[v], return_mean=True ) new_mess_mean = torch.bmm( torch.inverse(new_mess_lam), new_mess_eta.unsqueeze(-1) @@ -327,12 +343,10 @@ def comp_mess( # update messages for v in range(num_optim_vars): - ftov_msgs[v].update( + self.ftov_msgs[v].update( mean=new_messages[v].mean, precision=new_messages[v].precision ) - return new_messages - @property def dof(self) -> int: return self._dof @@ -357,13 +371,8 @@ def __init__( abs_err_tolerance, rel_err_tolerance, max_iterations ) - # create array for indexing the messages - var_ixs_nested = [ - [self.ordering.index_of(var.name) for var in cf.optim_vars] - for cf in self.objective.cost_functions.values() - ] - var_ixs = [item for sublist in var_ixs_nested for item in sublist] - self.var_ix_for_edges = torch.tensor(var_ixs).long() + self.beliefs: List[th.ManifoldGaussian] = [] + self.factors: List[Factor] = [] """ Copied and slightly modified from nonlinear optimizer class @@ -497,63 +506,211 @@ def _merge_infos( GBP functions """ - def _pass_var_to_fac_messages( + def _pass_var_to_fac_messages_loop( self, update_belief=True, ): for i, var in enumerate(self.ordering): # Collect all incoming messages in the tangent space at the current belief - taus = [] # message means + etas_tp = [] # message etas lams_tp = [] # message lams - for j, msg in enumerate(self.ftov_msgs): - if self.var_ix_for_edges[j] == i: - # print(msg.mean, msg.precision) - tau, lam_tp = th.local_gaussian(var, msg, return_mean=True) - taus.append(tau[None, ...]) - lams_tp.append(lam_tp[None, ...]) - - taus = torch.cat(taus) + for factor in self.factors: + for j, msg in enumerate(factor.ftov_msgs): + if factor.var_ixs[j] == i: + eta_tp, lam_tp = th.local_gaussian(var, msg, return_mean=False) + etas_tp.append(eta_tp[None, ...]) + lams_tp.append(lam_tp[None, ...]) + + etas_tp = torch.cat(etas_tp) lams_tp = torch.cat(lams_tp) lam_tau = lams_tp.sum(dim=0) # Compute outgoing messages - ix = 0 - for j, msg in enumerate(self.ftov_msgs): - if self.var_ix_for_edges[j] == i: - taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) - lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) - - lam_a = lams_inc.sum(dim=0) - if lam_a.count_nonzero() == 0: - self.vtof_msgs[j].zero_message() - else: - inv_lam_a = torch.linalg.inv(lam_a) - sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( - dim=0 - ) - tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) - new_mess = th.retract_gaussian(var, tau_a, lam_a) - self.vtof_msgs[j].update(new_mess.mean, new_mess.precision) - ix += 1 + ix = 0 # index for msg in list of msgs to variables + for factor in 
self.factors: + for j, msg in enumerate(factor.vtof_msgs): + if factor.var_ixs[j] == i: + etas_inc = torch.cat((etas_tp[:ix], etas_tp[ix + 1 :])) + lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) + + lam_a = lams_inc.sum(dim=0) + if lam_a.count_nonzero() == 0: + msg.zero_message() + else: + inv_lam_a = torch.linalg.inv(lam_a) + sum_etas = etas_inc.sum(dim=0) + mean_a = torch.matmul( + inv_lam_a, sum_etas.unsqueeze(-1) + ).squeeze(-1) + new_mess = th.retract_gaussian(var, mean_a, lam_a) + msg.update(new_mess.mean, new_mess.precision) + ix += 1 # update belief mean and variance # if no incoming messages then leave current belief unchanged if update_belief and lam_tau.count_nonzero() != 0: inv_lam_tau = torch.inverse(lam_tau) - sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) - tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) + sum_taus = etas_tp.sum(dim=0) + tau = torch.matmul(inv_lam_tau, sum_taus.unsqueeze(-1)).squeeze(-1) new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) + def _pass_var_to_fac_messages_vectorized( + self, + update_belief=True, + ): + # Each (variable-type, dof) gets mapped to a tuple with: + # - the variable that will hold the vectorized data + # - all the variables of that type that will be vectorized together + # - list of variable indices from ordering + # - tensor that will hold incoming messages [eta, lam] in the belief tangent plane + var_info: Dict[ + Tuple[Type[Manifold], int], + Tuple[Manifold, List[Manifold], List[int], List[torch.Tensor]], + ] = {} + + batch_size = -1 + # Create var info by looping variables in the given order + # All variables of the same type get grouped together + for ix, var in enumerate(self.ordering): + if batch_size == -1: + batch_size = var.shape[0] + else: + assert batch_size == var.shape[0] + + var_type = (var.__class__, var.dof()) + if var_type not in var_info: + var_info[var_type] = (var.copy(), [], [], []) + var_info[var_type][1].append(var) + var_info[var_type][2].append(ix) + + for var_type, (vectorized_var, var_list, _, eta_lam) in var_info.items(): + n_vars, dof = len(var_list), vectorized_var.dof() + + # Get the vectorized tensor that has the current variable data. 
+ # The resulting shape is (N * b, M), b is batch size, N is the number of + # variables in the group, and M is the data shape for this class + vectorized_data = torch.cat([v.data for v in var_list], dim=0) + assert ( + vectorized_data.shape + == (n_vars * batch_size,) + vectorized_data.shape[1:] + ) + vectorized_var.update(vectorized_data) + + eta_tp_acc = torch.zeros(n_vars * batch_size, dof) + lam_tp_acc = torch.zeros(n_vars * batch_size, dof, dof) + eta_tp_acc = eta_tp_acc.to(vectorized_data.device, vectorized_data.dtype) + lam_tp_acc = lam_tp_acc.to(vectorized_data.device, vectorized_data.dtype) + eta_lam.extend([eta_tp_acc, lam_tp_acc]) + + # add ftov messages to eta_tp and lam_tp accumulator tensors + for factor in self.factors: + for i, msg in enumerate(factor.ftov_msgs): + + # transform messages to tangent plane at the current variable value + eta_tp, lam_tp = th.local_gaussian( + factor.cf.optim_var_at(i), msg, return_mean=False + ) + + receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) + + # add messages to correct variable using indices + eta_tp_acc = var_info[receiver_var_type][3][0] + lam_tp_acc = var_info[receiver_var_type][3][1] + + # COULD MOVE THIS TO THE INIT + # get indices of the vectorized variables that receive each message + receiver_var_ixs = factor.var_ixs[:, i] # ixs of the receiver variables + var_type_ixs = torch.tensor( + var_info[receiver_var_type][2] + ) # all ixs for variables of this type + var_type_ixs = var_type_ixs[None, :].repeat(len(receiver_var_ixs), 1) + is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 + indices = is_receiver.nonzero()[:, 1] + + # expand indices for all batch variables when batch size > 1 + if self.objective.batch_size != 1: + indices = indices[:, None].repeat(1, self.objective.batch_size) + shift = ( + torch.arange(self.objective.batch_size)[None, :] + .long() + .to(indices.device) + ) + indices = indices + shift + indices = indices.flatten() + + eta_tp_acc.index_add_(0, indices, eta_tp) + lam_tp_acc.index_add_(0, indices, lam_tp) + + # compute variable to factor messages, now all incoming messages are accumulated + for factor in self.factors: + for i, msg in enumerate(factor.vtof_msgs): + + # transform messages to tangent plane at the current variable value + ftov_msg = factor.ftov_msgs[i] + eta_tp, lam_tp = th.local_gaussian( + factor.cf.optim_var_at(i), ftov_msg, return_mean=False + ) + + receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) + eta_tp_acc = var_info[receiver_var_type][3][0] + lam_tp_acc = var_info[receiver_var_type][3][1] + + # COULD MOVE THIS TO THE INIT + # get indices of the vectorized variables that receive each message + receiver_var_ixs = factor.var_ixs[:, i] # ixs of the receiver variables + var_type_ixs = torch.tensor( + var_info[receiver_var_type][2] + ) # all ixs for variables of this type + var_type_ixs = var_type_ixs[None, :].repeat(len(receiver_var_ixs), 1) + is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 + indices = is_receiver.nonzero()[:, 1] + + # new outgoing message is belief - last incoming mesage (in log space parameters) + sum_etas = eta_tp_acc[indices] - eta_tp + lam_a = lam_tp_acc[indices] - lam_tp + + if lam_a.count_nonzero() == 0: + msg.zero_message() + else: + print(lam_a.shape, lam_a) + inv_lam_a = torch.linalg.inv(lam_a) + mean_a = torch.matmul(inv_lam_a, sum_etas.unsqueeze(-1)).squeeze(-1) + new_mess = th.retract_gaussian( + factor.cf.optim_var_at(i), mean_a, lam_a + ) + msg.update(new_mess.mean, new_mess.precision) + + # compute the 
new belief for the vectorized variables + for (vectorized_var, _, var_ixs, eta_lam) in var_info.values(): + eta_tp_acc = eta_lam[0] + lam_tau = eta_lam[1] + + if update_belief and lam_tau.count_nonzero() != 0: + inv_lam_tau = torch.inverse(lam_tau) + tau = torch.matmul(inv_lam_tau, eta_tp_acc.unsqueeze(-1)).squeeze(-1) + + new_belief = th.retract_gaussian(vectorized_var, tau, lam_tau) + # update non vectorized beliefs with slices + start_idx = 0 + for ix in var_ixs: + belief_mean_slice = new_belief.mean[0][ + start_idx : start_idx + batch_size + ] + belief_precision_slice = new_belief.precision[ + start_idx : start_idx + batch_size + ] + self.beliefs[ix].update([belief_mean_slice], belief_precision_slice) + start_idx += batch_size + def _linearize_factors(self, relin_threshold: float = None): relins = 0 for factor in self.factors: factor.linearize(relin_threshold=relin_threshold) relins += int((factor.steps_since_lin == 0).sum().item()) - return relins def _pass_fac_to_var_messages( @@ -564,72 +721,15 @@ def _pass_fac_to_var_messages( # USE THE SCHEDULE!!!!! - start = 0 start_d = 0 for j, factor in enumerate(self.factors): num_optim_vars = factor.cf.num_optim_vars() n_factors = num_optim_vars * factor.batch_size - damping_tsr = damping[start_d : start_d + n_factors].reshape( - num_optim_vars, factor.batch_size - ) + damping_tsr = damping[start_d : start_d + n_factors] + damping_tsr = damping_tsr.reshape(num_optim_vars, factor.batch_size) start_d += n_factors - if self.objective.vectorized: - # prepare vectorized messages - ixs = torch.tensor(self.objective.vectorized_msg_ixs[j]) - vtof_msgs: List[Message] = [] - ftov_msgs: List[Message] = [] - for var in factor.cf.optim_vars: - mean_vtof_msgs = var.copy() - mean_ftov_msgs = var.copy() - mean_data_vtof_msgs = torch.cat( - [self.vtof_msgs[i].mean[0].data for i in ixs] - ) - mean_data_ftov_msgs = torch.cat( - [self.ftov_msgs[i].mean[0].data for i in ixs] - ) - mean_vtof_msgs.update(data=mean_data_vtof_msgs) - mean_ftov_msgs.update(data=mean_data_ftov_msgs) - precision_vtof_msgs = torch.cat( - [self.vtof_msgs[i].precision for i in ixs] - ) - precision_ftov_msgs = torch.cat( - [self.ftov_msgs[i].precision for i in ixs] - ) - - vtof_msg = Message( - mean=[mean_vtof_msgs], precision=precision_vtof_msgs - ) - ftov_msg = Message( - mean=[mean_ftov_msgs], precision=precision_ftov_msgs - ) - vtof_msgs.append(vtof_msg) - ftov_msgs.append(ftov_msg) - - ixs += 1 - else: - vtof_msgs = self.vtof_msgs[start : start + num_optim_vars] - ftov_msgs = self.ftov_msgs[start : start + num_optim_vars] - - factor.comp_mess(vtof_msgs, ftov_msgs, damping_tsr) - - if self.objective.vectorized: - # fill in messages using vectorized messages - ixs = torch.tensor(self.objective.vectorized_msg_ixs[j]) - for ftov_msg in ftov_msgs: - start_idx = 0 - for ix in ixs: - v_slice = slice( - start_idx, start_idx + self.objective.batch_size - ) - self.ftov_msgs[ix].update( - mean=[ftov_msg.mean[0][v_slice]], - precision=ftov_msg.precision[v_slice], - ) - start_idx += self.objective.batch_size - ixs += 1 - - start += num_optim_vars + factor.comp_mess(damping_tsr) """ Optimization loop functions @@ -658,49 +758,56 @@ def _optimize_loop( f"Dropout probability must be between 0 and 1. Got {dropout}." 
) - if clear_messages: - # initialise messages with zeros - self.vtof_msgs: List[Message] = [] - self.ftov_msgs: List[Message] = [] - for cf in self.objective.cost_functions.values(): - for var in cf.optim_vars: - # Set mean of initial message to identity of the group - # doesn't matter what it is as long as precision is zero - vtof_msg = Message( - [var.copy()], name=f"msg_{var.name}_to_{cf.name}" - ) - ftov_msg = Message( - [var.copy()], name=f"msg_{cf.name}_to_{var.name}" - ) - vtof_msg.zero_message() - ftov_msg.zero_message() - self.vtof_msgs.append(vtof_msg) - self.ftov_msgs.append(ftov_msg) - - # initialise ManifoldGaussian for belief - self.beliefs: List[th.ManifoldGaussian] = [] + # initialise beliefs for var in self.ordering: self.beliefs.append(th.ManifoldGaussian([var])) - self.n_individual_factors = ( - len(self.objective.cost_functions) * self.objective.batch_size - ) - if self.objective.vectorized: - cf_iterator = iter(self.objective.vectorized_cost_fns) - else: - cf_iterator = self.objective._get_iterator() + if clear_messages: + self.n_individual_factors = ( + len(self.objective.cost_functions) * self.objective.batch_size + ) - # compute factor potentials for the first time - self.factors: List[Factor] = [] - for cost_function in cf_iterator: - self.factors.append( - Factor( - cost_function, - name=cost_function.name, - lin_system_damping=lin_system_damping, + if self.objective.vectorized: + cf_iterator = iter(self.objective.vectorized_cost_fns) + self._pass_var_to_fac_messages = ( + self._pass_var_to_fac_messages_vectorized ) - ) - relins = self._linearize_factors() + else: + cf_iterator = self.objective._get_iterator() + self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop + + # compute factor potentials for the first time + for i, cost_function in enumerate(cf_iterator): + + if self.objective.vectorized: + # create array for indexing the messages + base_cf_names = self.objective.vectorized_cf_names[i] + base_cfs = [ + self.objective.get_cost_function(name) for name in base_cf_names + ] + var_ixs = torch.tensor( + [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] + for cf in base_cfs + ] + ).long() + else: + var_ixs = torch.tensor( + [ + self.ordering.index_of(var.name) + for var in cost_function.optim_vars + ] + ).long() + + self.factors.append( + Factor( + cost_function, + name=cost_function.name, + var_ixs=var_ixs, + lin_system_damping=lin_system_damping, + ) + ) + relins = self._linearize_factors() self.n_edges = sum( [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] @@ -717,7 +824,10 @@ def _optimize_loop( converged_indices = torch.zeros_like(info.last_err).bool() for it_ in range(start_iter, start_iter + num_iter): - self.ftov_msgs_history[it_] = [msg.copy() for msg in self.ftov_msgs] + curr_ftov_msgs = [] + for factor in self.factors: + curr_ftov_msgs.extend([msg.copy() for msg in factor.ftov_msgs]) + self.ftov_msgs_history[it_] = curr_ftov_msgs self.belief_history[it_] = [belief.copy() for belief in self.beliefs] # damping @@ -751,10 +861,11 @@ def _optimize_loop( t_vec = time.time() - t1 t_tot = time.time() - t0 - print( - f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," - f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" - ) + if verbose: + print( + f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," + f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" + ) # check for convergence if it_ >= 0: From 8bcd2cdf88ac227156b275de86c998efb8236c50 Mon Sep 17 00:00:00 
2001 From: joeaortiz Date: Thu, 7 Jul 2022 15:06:53 +0100 Subject: [PATCH 31/64] handles vectorized inversion with some singular matrices, only computes vectorization indices once --- theseus/optimizer/gbp/gbp.py | 80 ++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 80a3bf3b0..bc82a70e8 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -29,7 +29,6 @@ """ TODO - solving inverse problem to compute message mean - - move vectorized indices compute to init - use schedule - random schedule not vectorized - factor inherits CF class @@ -165,6 +164,9 @@ def __init__( self.vtof_msgs.append(vtof_msg) self.ftov_msgs.append(ftov_msg) + # for vectorized vtof message passing + self.vectorized_var_ixs: List[torch.Tensor] = [None] * cf.num_optim_vars() + # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. def linearize( @@ -616,34 +618,36 @@ def _pass_var_to_fac_messages_vectorized( ) receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) + # get indices of the vectorized variables that receive each message + if factor.vectorized_var_ixs[i] is None: + receiver_var_ixs = factor.var_ixs[ + :, i + ] # ixs of the receiver variables + var_type_ixs = torch.tensor( + var_info[receiver_var_type][2] + ) # all ixs for variables of this type + var_type_ixs = var_type_ixs[None, :].repeat( + len(receiver_var_ixs), 1 + ) + is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 + indices = is_receiver.nonzero()[:, 1] + # expand indices for all batch variables when batch size > 1 + if self.objective.batch_size != 1: + indices = indices[:, None].repeat(1, self.objective.batch_size) + shift = ( + torch.arange(self.objective.batch_size)[None, :] + .long() + .to(indices.device) + ) + indices = indices + shift + indices = indices.flatten() + factor.vectorized_var_ixs[i] = indices # add messages to correct variable using indices eta_tp_acc = var_info[receiver_var_type][3][0] lam_tp_acc = var_info[receiver_var_type][3][1] - - # COULD MOVE THIS TO THE INIT - # get indices of the vectorized variables that receive each message - receiver_var_ixs = factor.var_ixs[:, i] # ixs of the receiver variables - var_type_ixs = torch.tensor( - var_info[receiver_var_type][2] - ) # all ixs for variables of this type - var_type_ixs = var_type_ixs[None, :].repeat(len(receiver_var_ixs), 1) - is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 - indices = is_receiver.nonzero()[:, 1] - - # expand indices for all batch variables when batch size > 1 - if self.objective.batch_size != 1: - indices = indices[:, None].repeat(1, self.objective.batch_size) - shift = ( - torch.arange(self.objective.batch_size)[None, :] - .long() - .to(indices.device) - ) - indices = indices + shift - indices = indices.flatten() - - eta_tp_acc.index_add_(0, indices, eta_tp) - lam_tp_acc.index_add_(0, indices, lam_tp) + eta_tp_acc.index_add_(0, factor.vectorized_var_ixs[i], eta_tp) + lam_tp_acc.index_add_(0, factor.vectorized_var_ixs[i], lam_tp) # compute variable to factor messages, now all incoming messages are accumulated for factor in self.factors: @@ -655,29 +659,25 @@ def _pass_var_to_fac_messages_vectorized( factor.cf.optim_var_at(i), ftov_msg, return_mean=False ) + # new outgoing message is belief - last incoming mesage (in log space parameters) receiver_var_type = (msg.mean[0].__class__, msg.mean[0].dof()) eta_tp_acc = 
var_info[receiver_var_type][3][0] lam_tp_acc = var_info[receiver_var_type][3][1] - - # COULD MOVE THIS TO THE INIT - # get indices of the vectorized variables that receive each message - receiver_var_ixs = factor.var_ixs[:, i] # ixs of the receiver variables - var_type_ixs = torch.tensor( - var_info[receiver_var_type][2] - ) # all ixs for variables of this type - var_type_ixs = var_type_ixs[None, :].repeat(len(receiver_var_ixs), 1) - is_receiver = (var_type_ixs - receiver_var_ixs[:, None]) == 0 - indices = is_receiver.nonzero()[:, 1] - - # new outgoing message is belief - last incoming mesage (in log space parameters) - sum_etas = eta_tp_acc[indices] - eta_tp - lam_a = lam_tp_acc[indices] - lam_tp + sum_etas = eta_tp_acc[factor.vectorized_var_ixs[i]] - eta_tp + lam_a = lam_tp_acc[factor.vectorized_var_ixs[i]] - lam_tp if lam_a.count_nonzero() == 0: msg.zero_message() else: - print(lam_a.shape, lam_a) + zero_lam = lam_a.count_nonzero(1, 2) == 0 + # add to zero precision matrices so inversion doesn't fail + lam_a[zero_lam] += torch.eye( + lam_a.shape[1], dtype=lam_a.dtype, device=lam_a.device + ) inv_lam_a = torch.linalg.inv(lam_a) + # restore zeros precision matrices + lam_a[zero_lam] = 0.0 + inv_lam_a[zero_lam] = 0.0 mean_a = torch.matmul(inv_lam_a, sum_etas.unsqueeze(-1)).squeeze(-1) new_mess = th.retract_gaussian( factor.cf.optim_var_at(i), mean_a, lam_a From eaeab00402a4cd0b0cd2abd051693231b1d2b6f9 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 8 Jul 2022 13:30:55 +0100 Subject: [PATCH 32/64] removed random message schedule --- theseus/core/objective.py | 4 +- theseus/core/vectorizer.py | 9 ++-- theseus/optimizer/gbp/gbp.py | 91 ++++++++++++++++++------------------ 3 files changed, 51 insertions(+), 53 deletions(-) diff --git a/theseus/core/objective.py b/theseus/core/objective.py index a7afdc061..9006fa8b2 100644 --- a/theseus/core/objective.py +++ b/theseus/core/objective.py @@ -505,11 +505,11 @@ def _vectorization_needs_update(self): needs = True return needs - def update_vectorization_if_needed(self, compute_caches=True): + def update_vectorization_if_needed(self): if self.vectorized and self._vectorization_needs_update(): if self._batch_size is None: self.update() - self._vectorization_run(compute_caches=compute_caches) + self._vectorization_run() self._last_vectorization_has_grad = torch.is_grad_enabled() # iterates over cost functions diff --git a/theseus/core/vectorizer.py b/theseus/core/vectorizer.py index 047d4f5b5..d1efeb53d 100644 --- a/theseus/core/vectorizer.py +++ b/theseus/core/vectorizer.py @@ -294,7 +294,7 @@ def _clear_wrapper_caches(self): cf._cached_error = None cf._cached_jacobians = None - def _vectorize(self, compute_caches=True): + def _vectorize(self): self._clear_wrapper_caches() for schema, cost_fn_wrappers in self._schema_dict.items(): var_names = self._var_names[schema] @@ -314,10 +314,9 @@ def _vectorize(self, compute_caches=True): batch_size, len(cost_fn_wrappers), ) - if compute_caches: - Vectorize._compute_error_and_replace_wrapper_caches( - vectorized_cost_fn, cost_fn_wrappers, batch_size - ) + Vectorize._compute_error_and_replace_wrapper_caches( + vectorized_cost_fn, cost_fn_wrappers, batch_size + ) @staticmethod def _vectorized_retract_optim_vars( diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index bc82a70e8..f965ea72e 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -29,8 +29,6 @@ """ TODO - solving inverse problem to compute message mean - - use schedule - - random schedule not 
vectorized - factor inherits CF class """ @@ -57,18 +55,19 @@ def update(self, params_dict): class GBPSchedule(Enum): SYNCHRONOUS = 0 - RANDOM = 1 def synchronous_schedule(max_iters, n_edges) -> torch.Tensor: return torch.full([max_iters, n_edges], True) -def random_schedule(max_iters, n_edges) -> torch.Tensor: - schedule = torch.full([max_iters, n_edges], False) - ixs = torch.randint(0, n_edges, [max_iters]) - schedule[torch.arange(max_iters), ixs] = True - return schedule +# def random_schedule(max_iters, n_edges) -> torch.Tensor: +# schedule = torch.full([max_iters, n_edges], False) +# # on first step send messages along all edges +# schedule[0] = True +# ixs = torch.randint(0, n_edges, [max_iters]) +# schedule[torch.arange(max_iters), ixs] = True +# return schedule # Initialises message precision to zero @@ -144,7 +143,7 @@ def __init__( self.potential_lam = torch.zeros(self.batch_size, self.dof, self.dof).to( dtype=dtype, device=device ) - self.lin_point = [ + self.lin_point: List[Manifold] = [ var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars ] @@ -217,6 +216,7 @@ def linearize( def comp_mess( self, damping, + schedule, ): num_optim_vars = self.cf.num_optim_vars() new_messages = [] @@ -301,9 +301,9 @@ def comp_mess( # damping in tangent space at linearisation point as message # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. - # mean damping + # Damping is applied to the mean parameters. do_damping = torch.logical_and(damping[v] > 0, self.steps_since_lin > 0) - if do_damping.sum() > 0: + if do_damping.sum() != 0: damping_check = torch.logical_and( new_mess_lam.count_nonzero(1, 2) != 0, self.ftov_msgs[v].precision.count_nonzero(1, 2) != 0, @@ -324,6 +324,16 @@ def comp_mess( new_mess_lam, new_mess_mean.unsqueeze(-1) ).squeeze(-1) + # don't send messages if schedule is False + if not schedule[v].all(): + # if any are False set these to prev message + prev_mess_eta, prev_mess_lam = th.local_gaussian( + self.lin_point[v], self.ftov_msgs[v], return_mean=False + ) + no_update = ~schedule[v] + new_mess_eta[no_update] = prev_mess_eta[no_update] + new_mess_lam[no_update] = prev_mess_lam[no_update] + new_mess_lam = th.DenseSolver._apply_damping( new_mess_lam, self.lin_system_damping, @@ -508,10 +518,7 @@ def _merge_infos( GBP functions """ - def _pass_var_to_fac_messages_loop( - self, - update_belief=True, - ): + def _pass_var_to_fac_messages_loop(self, update_belief=True): for i, var in enumerate(self.ordering): # Collect all incoming messages in the tangent space at the current belief @@ -560,10 +567,7 @@ def _pass_var_to_fac_messages_loop( new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _pass_var_to_fac_messages_vectorized( - self, - update_belief=True, - ): + def _pass_var_to_fac_messages_vectorized(self, update_belief=True): # Each (variable-type, dof) gets mapped to a tuple with: # - the variable that will hold the vectorized data # - all the variables of that type that will be vectorized together @@ -669,15 +673,11 @@ def _pass_var_to_fac_messages_vectorized( if lam_a.count_nonzero() == 0: msg.zero_message() else: - zero_lam = lam_a.count_nonzero(1, 2) == 0 - # add to zero precision matrices so inversion doesn't fail - lam_a[zero_lam] += torch.eye( - lam_a.shape[1], dtype=lam_a.dtype, device=lam_a.device + valid_lam = lam_a.count_nonzero(1, 2) != 0 + inv_lam_a = torch.zeros_like( + lam_a, dtype=lam_a.dtype, device=lam_a.device ) - 
inv_lam_a = torch.linalg.inv(lam_a) - # restore zeros precision matrices - lam_a[zero_lam] = 0.0 - inv_lam_a[zero_lam] = 0.0 + inv_lam_a[valid_lam] = torch.linalg.inv(lam_a[valid_lam]) mean_a = torch.matmul(inv_lam_a, sum_etas.unsqueeze(-1)).squeeze(-1) new_mess = th.retract_gaussian( factor.cf.optim_var_at(i), mean_a, lam_a @@ -713,23 +713,19 @@ def _linearize_factors(self, relin_threshold: float = None): relins += int((factor.steps_since_lin == 0).sum().item()) return relins - def _pass_fac_to_var_messages( - self, - schedule: torch.Tensor, - damping: torch.Tensor, - ): - - # USE THE SCHEDULE!!!!! - + def _pass_fac_to_var_messages(self, schedule: torch.Tensor, damping: torch.Tensor): start_d = 0 for j, factor in enumerate(self.factors): num_optim_vars = factor.cf.num_optim_vars() - n_factors = num_optim_vars * factor.batch_size - damping_tsr = damping[start_d : start_d + n_factors] + n_edges = num_optim_vars * factor.batch_size + damping_tsr = damping[start_d : start_d + n_edges] + schedule_tsr = schedule[start_d : start_d + n_edges] damping_tsr = damping_tsr.reshape(num_optim_vars, factor.batch_size) - start_d += n_factors + schedule_tsr = schedule_tsr.reshape(num_optim_vars, factor.batch_size) + start_d += n_edges - factor.comp_mess(damping_tsr) + if schedule_tsr.sum() != 0: + factor.comp_mess(damping_tsr, schedule_tsr) """ Optimization loop functions @@ -757,6 +753,11 @@ def _optimize_loop( raise ValueError( f"Dropout probability must be between 0 and 1. Got {dropout}." ) + if dropout > 0.9: + print( + "Disabling vectorization due to dropout > 0.9 in GBP message schedule." + ) + self.objective.disable_vectorization() # initialise beliefs for var in self.ordering: @@ -812,9 +813,7 @@ def _optimize_loop( self.n_edges = sum( [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] ) - if schedule == GBPSchedule.RANDOM: - ftov_schedule = random_schedule(self.params.max_iterations, self.n_edges) - elif schedule == GBPSchedule.SYNCHRONOUS: + if schedule == GBPSchedule.SYNCHRONOUS: ftov_schedule = synchronous_schedule( self.params.max_iterations, self.n_edges ) @@ -837,10 +836,10 @@ def _optimize_loop( device=self.ordering[0].device, dtype=self.ordering[0].dtype, ) - # dropout can be implemented through damping - if dropout != 0.0: + # dropout is implemented by changing the schedule + if dropout != 0.0 and it_ != 0: dropout_ixs = torch.rand(self.n_edges) < dropout - damping_arr[dropout_ixs] = 1.0 + ftov_schedule[it_, dropout_ixs] = False t0 = time.time() relins = self._linearize_factors(relin_threshold) @@ -857,7 +856,7 @@ def _optimize_loop( t_vec = 0.0 if self.objective.vectorized: t1 = time.time() - self.objective.update_vectorization_if_needed(compute_caches=False) + self.objective.update_vectorization_if_needed() t_vec = time.time() - t1 t_tot = time.time() - t0 From d8c774afd6cf693ec093f3c0b839d2bfd8b4eab4 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Mon, 18 Jul 2022 08:34:15 +0100 Subject: [PATCH 33/64] damping linear system --- theseus/optimizer/gbp/gbp.py | 69 ++++++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 10 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index fa662c86c..63ec4e5af 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -28,6 +28,7 @@ """ TODO + - batch the factor damping params - solving inverse problem to compute message mean - factor inherits CF class """ @@ -128,7 +129,6 @@ def __init__( self.cf = cf self.var_ixs = var_ixs - self.lin_system_damping = 
lin_system_damping # batch_size of the vectorized factor. In general != objective.batch_size. # They are equal without vectorization or for unique cost function schema. @@ -151,6 +151,14 @@ def __init__( self.batch_size, device=device, dtype=torch.int ) + # self.lm_damping = lin_system_damping + self.lm_damping = torch.full([self.batch_size], lin_system_damping).to( + dtype=dtype, device=device + ) + self.last_err: torch.Tensor = None + self.a = 2 + self.b = 10 + self.vtof_msgs: List[Message] = [] self.ftov_msgs: List[Message] = [] for var in cf.optim_vars: @@ -171,6 +179,7 @@ def __init__( def linearize( self, relin_threshold: float = None, + err_change: float = 0.0, lie=True, ): self.steps_since_lin += 1 @@ -206,6 +215,39 @@ def linearize( eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) + # update lm_damping + err = (self.cf.error() ** 2).sum(dim=1) + if self.last_err is not None: + decreased_ixs = err < self.last_err + self.lm_damping[decreased_ixs] = torch.max( + self.lm_damping[decreased_ixs] / self.a, torch.Tensor([1e-4]) + ) + self.lm_damping[~decreased_ixs] = ( + self.lm_damping[~decreased_ixs] * self.b + ) + self.last_err = err + # damp precision matrix + damped_D = self.lm_damping[:, None, None] * torch.eye( + lam.shape[1], device=lam.device, dtype=lam.dtype + ).unsqueeze(0).repeat(self.batch_size, 1, 1) + lam = lam + damped_D + + # # damping unchanged if err_change == 0 + # if err_change < 0.0: + # self.lm_damping = max(self.lm_damping / self.a, 1e-4) + # elif err_change > 0.0: + # self.lm_damping = self.lm_damping * self.b + # lam = th.DenseSolver._apply_damping( + # lam, + # damping=self.lm_damping, + # ellipsoidal=False, + # eps=1e-8, + # ) + + # if self.name == 'robust_Reprojection__1_copy': + # print('lm damping', self.lm_damping) + # print('err change', err_change) + self.potential_eta[do_lin] = eta[do_lin] self.potential_lam[do_lin] = lam[do_lin] @@ -336,12 +378,12 @@ def comp_mess( new_mess_eta[no_update] = prev_mess_eta[no_update] new_mess_lam[no_update] = prev_mess_lam[no_update] - new_mess_lam = th.DenseSolver._apply_damping( - new_mess_lam, - self.lin_system_damping, - ellipsoidal=True, - eps=1e-8, - ) + # new_mess_lam = th.DenseSolver._apply_damping( + # new_mess_lam, + # self.lin_system_damping, + # ellipsoidal=True, + # eps=1e-8, + # ) new_mess_mean = th.LUDenseSolver._solve_sytem( new_mess_eta[..., None], new_mess_lam @@ -744,10 +786,12 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): self.beliefs[ix].update([belief_mean_slice], belief_precision_slice) start_idx += batch_size - def _linearize_factors(self, relin_threshold: float = None): + def _linearize_factors( + self, relin_threshold: float = None, err_change: float = 0.0 + ): relins = 0 for factor in self.factors: - factor.linearize(relin_threshold=relin_threshold) + factor.linearize(relin_threshold=relin_threshold, err_change=err_change) relins += int((factor.steps_since_lin == 0).sum().item()) return relins @@ -879,7 +923,12 @@ def _optimize_loop( ftov_schedule[it_, dropout_ixs] = False t0 = time.time() - relins = self._linearize_factors(relin_threshold) + err_change = 0.0 + if it_ > 0: + err_change = ( + info.err_history[0, it_] - info.err_history[0, it_ - 1] + ).item() + relins = self._linearize_factors(relin_threshold, err_change) t_relin = time.time() - t0 t1 = time.time() From ea5173faf59ea1ea761306373ae3a21732080992 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 18 Jul 2022 15:07:50 +0100 Subject: [PATCH 34/64] local linear damping, fixes 
for gbp on gpu --- theseus/optimizer/gbp/ba_test.py | 61 ++++++++++--------- theseus/optimizer/gbp/ba_viewer.py | 6 +- theseus/optimizer/gbp/gbp.py | 43 +++---------- .../{vectorize_test.py => vectorize_poc.py} | 0 4 files changed, 46 insertions(+), 64 deletions(-) rename theseus/optimizer/gbp/{vectorize_test.py => vectorize_poc.py} (100%) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index a12679ea7..db891ac86 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -68,19 +68,24 @@ def run(cfg: omegaconf.OmegaConf): num_points=cfg["num_points"], average_track_length=cfg["average_track_length"], track_locality=cfg["track_locality"], - feat_random=0.0, - prob_feat_is_outlier=0.0, + feat_random=1.5, + prob_feat_is_outlier=0.02, outlier_feat_random=70, - cam_pos_rand=0.5, - cam_rot_rand=0.1, - point_rand=5.0, + cam_pos_rand=5.0, + cam_rot_rand=0.9, + point_rand=10.0, ) # cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( - # "/media/joe/3.0TB Hard Disk/bal_data/problem-21-11315-pre.txt") + # # "/home/joe/Downloads/riku/fr1desk.txt", drop_obs=0.0) + # "/mnt/sda/bal/problem-21-11315-pre.txt", drop_obs=0.0) # ba = theg.BundleAdjustmentDataset(cams, points, obs) # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") + print("Cameras:", len(ba.cameras)) + print("Points:", len(ba.points)) + print("Observations:", len(ba.observations), "\n") + # param that control transition from squared loss to huber radius_tensor = torch.tensor([1.0], dtype=torch.float64) log_loss_radius = th.Vector(tensor=radius_tensor, name="log_loss_radius") @@ -134,7 +139,7 @@ def run(cfg: omegaconf.OmegaConf): camera_pose_vars: List[th.LieGroup] = [ objective.optim_vars[c.pose.name] for c in ba.cameras # type: ignore ] - if cfg["inner_optim"]["ratio_known_cameras"] > 0.0: + if cfg["inner_optim"]["ratio_known_cameras"] > 0.0 and ba.gt_cameras is not None: w = 1000.0 camera_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) for i in range(len(ba.cameras)): @@ -150,19 +155,20 @@ def run(cfg: omegaconf.OmegaConf): ) ) - # print("Factors:\n", objective.cost_functions.keys(), "\n") - # Create optimizer and theseus layer vectorize = True optimizer = cfg["optimizer_cls"]( objective, max_iterations=cfg["inner_optim"]["max_iters"], vectorize=vectorize, + # linearization_cls=th.SparseLinearization, + # linear_solver_cls=th.LUCudaSparseSolver, ) theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) # device = "cuda" if torch.cuda.is_available() else "cpu" # theseus_optim.to(device) + # print('Device:', device) optim_arg = { "track_best_solution": True, @@ -171,15 +177,15 @@ def run(cfg: omegaconf.OmegaConf): "verbose": True, "backward_mode": th.BackwardMode.FULL, } - if cfg["optimizer_cls"] == GaussianBeliefPropagation: - gbp_optim_arg = { + if isinstance(optimizer, GaussianBeliefPropagation): + extra_args = { "relin_threshold": 0.0000000001, "damping": 0.0, "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": 1e-5, + "lin_system_damping": 1.0e-4, } - optim_arg = {**optim_arg, **gbp_optim_arg} + optim_arg = {**optim_arg, **extra_args} theseus_inputs = {} for cam in ba.cameras: @@ -187,28 +193,27 @@ def run(cfg: omegaconf.OmegaConf): for pt in ba.points: theseus_inputs[pt.name] = pt.tensor.clone() - with torch.no_grad(): - camera_loss_ref = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS: {camera_loss_ref: .3f}") + if ba.gt_cameras is not None: + with torch.no_grad(): + 
camera_loss_ref = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS: {camera_loss_ref: .3f}") print_histogram(ba, theseus_inputs, "Input histogram:") objective.update(theseus_inputs) print("squred err:", objective.error_squared_norm().item()) - theseus_outputs, info = theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=optim_arg, - ) + with torch.no_grad(): + theseus_outputs, info = theseus_optim.forward( + input_tensors=theseus_inputs, + optimizer_kwargs=optim_arg, + ) - loss = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") + if ba.gt_cameras is not None: + loss = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") are = average_repojection_error(objective) print("Average reprojection error (pixels): ", are) - - with torch.no_grad(): - camera_loss_ref = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS: {camera_loss_ref: .3f}") print_histogram(ba, theseus_inputs, "Final histogram:") # BAViewer( @@ -227,12 +232,12 @@ def run(cfg: omegaconf.OmegaConf): "optimizer_cls": GaussianBeliefPropagation, # "optimizer_cls": th.GaussNewton, "inner_optim": { - "max_iters": 10, + "max_iters": 20, "verbose": True, "track_err_history": True, "keep_step_size": True, "regularize": True, - "ratio_known_cameras": 0.3, + "ratio_known_cameras": 0.1, "reg_w": 1e-7, }, } diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py index 2b35df86f..80534b7af 100644 --- a/theseus/optimizer/gbp/ba_viewer.py +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -43,11 +43,11 @@ def __init__( if gt_cameras is not None: for i, cam in enumerate(gt_cameras): - camera = self.make_cam(cam.pose.tensor[0]) + camera = self.make_cam(cam.pose.tensor[0].cpu()) self.scene.add_geometry(camera[1], geom_name=f"gt_cam_{i}") if gt_points is not None: - pts = torch.cat([pt.tensor for pt in gt_points]) + pts = torch.cat([pt.tensor.cpu() for pt in gt_points]) pc = trimesh.PointCloud(pts, [0, 255, 0, 200]) self.scene.add_geometry(pc, geom_name="gt_points") @@ -127,7 +127,7 @@ def next_iteration(self): points = [] n_cams, n_pts = 0, 0 for state in self.state_history.values(): - state = state[..., self._it] + state = state[..., self._it].cpu() if state.ndim == 3: camera = self.make_cam(state[0], color=(0.0, 0.0, 1.0, 0.8)) self.scene.delete_geometry(f"cam_{n_cams}") diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 63ec4e5af..d3d4bf797 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -28,7 +28,6 @@ """ TODO - - batch the factor damping params - solving inverse problem to compute message mean - factor inherits CF class """ @@ -151,10 +150,11 @@ def __init__( self.batch_size, device=device, dtype=torch.int ) - # self.lm_damping = lin_system_damping self.lm_damping = torch.full([self.batch_size], lin_system_damping).to( dtype=dtype, device=device ) + self.min_damping = torch.Tensor([1e-4]).to(dtype=dtype, device=device) + self.max_damping = torch.Tensor([1e2]).to(dtype=dtype, device=device) self.last_err: torch.Tensor = None self.a = 2 self.b = 10 @@ -179,7 +179,6 @@ def __init__( def linearize( self, relin_threshold: float = None, - err_change: float = 0.0, lie=True, ): self.steps_since_lin += 1 @@ -215,15 +214,15 @@ def linearize( eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) - # update lm_damping + # update damping parameter err = (self.cf.error() ** 
2).sum(dim=1) if self.last_err is not None: decreased_ixs = err < self.last_err self.lm_damping[decreased_ixs] = torch.max( - self.lm_damping[decreased_ixs] / self.a, torch.Tensor([1e-4]) + self.lm_damping[decreased_ixs] / self.a, self.min_damping ) - self.lm_damping[~decreased_ixs] = ( - self.lm_damping[~decreased_ixs] * self.b + self.lm_damping[~decreased_ixs] = torch.min( + self.lm_damping[~decreased_ixs] * self.b, self.max_damping ) self.last_err = err # damp precision matrix @@ -232,22 +231,6 @@ def linearize( ).unsqueeze(0).repeat(self.batch_size, 1, 1) lam = lam + damped_D - # # damping unchanged if err_change == 0 - # if err_change < 0.0: - # self.lm_damping = max(self.lm_damping / self.a, 1e-4) - # elif err_change > 0.0: - # self.lm_damping = self.lm_damping * self.b - # lam = th.DenseSolver._apply_damping( - # lam, - # damping=self.lm_damping, - # ellipsoidal=False, - # eps=1e-8, - # ) - - # if self.name == 'robust_Reprojection__1_copy': - # print('lm damping', self.lm_damping) - # print('err change', err_change) - self.potential_eta[do_lin] = eta[do_lin] self.potential_lam[do_lin] = lam[do_lin] @@ -725,6 +708,7 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): ) indices = indices + shift indices = indices.flatten() + indices = indices.to(factor.cf.optim_var_at(0).device) factor.vectorized_var_ixs[i] = indices # add messages to correct variable using indices @@ -786,12 +770,10 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): self.beliefs[ix].update([belief_mean_slice], belief_precision_slice) start_idx += batch_size - def _linearize_factors( - self, relin_threshold: float = None, err_change: float = 0.0 - ): + def _linearize_factors(self, relin_threshold: float = None): relins = 0 for factor in self.factors: - factor.linearize(relin_threshold=relin_threshold, err_change=err_change) + factor.linearize(relin_threshold=relin_threshold) relins += int((factor.steps_since_lin == 0).sum().item()) return relins @@ -923,12 +905,7 @@ def _optimize_loop( ftov_schedule[it_, dropout_ixs] = False t0 = time.time() - err_change = 0.0 - if it_ > 0: - err_change = ( - info.err_history[0, it_] - info.err_history[0, it_ - 1] - ).item() - relins = self._linearize_factors(relin_threshold, err_change) + relins = self._linearize_factors(relin_threshold) t_relin = time.time() - t0 t1 = time.time() diff --git a/theseus/optimizer/gbp/vectorize_test.py b/theseus/optimizer/gbp/vectorize_poc.py similarity index 100% rename from theseus/optimizer/gbp/vectorize_test.py rename to theseus/optimizer/gbp/vectorize_poc.py From c2507e30ae459b128827bfeb08c248eeea90ea1a Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 18 Jul 2022 15:08:28 +0100 Subject: [PATCH 35/64] handle loading different format bal file and drop observations --- .../utils/examples/bundle_adjustment/data.py | 45 ++++++++++++------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/theseus/utils/examples/bundle_adjustment/data.py b/theseus/utils/examples/bundle_adjustment/data.py index 2c2dcfcbd..903171694 100644 --- a/theseus/utils/examples/bundle_adjustment/data.py +++ b/theseus/utils/examples/bundle_adjustment/data.py @@ -169,7 +169,7 @@ def __init__( self.gt_points = gt_points @staticmethod - def load_bal_dataset(path: str): + def load_bal_dataset(path: str, drop_obs=0.0): observations = [] cameras = [] points = [] @@ -177,26 +177,41 @@ def load_bal_dataset(path: str): num_cameras, num_points, num_observations = [ int(x) for x in out.readline().rstrip().split() ] + + fields = 
out.readline().rstrip().split() + intrinsics = None + if len(fields) == 4: + intrinsics = [ + (float(fields[0]) + float(fields[1])) / 2.0, + float(fields[2]), + float(fields[3]), + ] + for i in range(num_observations): - fields = out.readline().rstrip().split() - feat = th.Point2( - tensor=torch.tensor( - [float(fields[2]), float(fields[3])], dtype=torch.float64 - ).unsqueeze(0), - name=f"Feat{i}", - ) - observations.append( - Observation( - camera_index=int(fields[0]), - point_index=int(fields[1]), - image_feature_point=feat, + if i > 0 or intrinsics is not None: + fields = out.readline().rstrip().split() + if np.random.rand() > drop_obs: + feat = th.Point2( + tensor=torch.tensor( + [float(fields[2]), float(fields[3])], dtype=torch.float64 + ).unsqueeze(0), + name=f"Feat{i}", + ) + observations.append( + Observation( + camera_index=int(fields[0]), + point_index=int(fields[1]), + image_feature_point=feat, + ) ) - ) for i in range(num_cameras): params = [] - for _ in range(9): + n_params = 6 if intrinsics is not None else 9 + for _ in range(n_params): params.append(float(out.readline().rstrip())) + if intrinsics is not None: + params.extend(intrinsics) cameras.append(Camera.from_params(params, name=f"Cam{i}")) for i in range(num_points): From 400091b995a3c8f33abde602dcbc21a99ffc5ab1 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 20 Jul 2022 09:13:25 +0100 Subject: [PATCH 36/64] gbp check unary factor, fix bug in ba viewer --- theseus/optimizer/gbp/ba_test.py | 144 +++++++++++------- theseus/optimizer/gbp/ba_viewer.py | 10 +- theseus/optimizer/gbp/gbp.py | 18 ++- .../{jax_torch_test.py => jax_torch_poc.py} | 0 .../utils/examples/bundle_adjustment/data.py | 2 +- 5 files changed, 114 insertions(+), 60 deletions(-) rename theseus/optimizer/gbp/{jax_torch_test.py => jax_torch_poc.py} (100%) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/ba_test.py index db891ac86..9610d145a 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/ba_test.py @@ -7,8 +7,11 @@ import numpy as np import omegaconf +import time import torch +# import os + import theseus as th import theseus.utils.examples as theg from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule @@ -16,6 +19,13 @@ # from theseus.optimizer.gbp import BAViewer +OPTIMIZER_CLASS = { + "gbp": GaussianBeliefPropagation, + "gauss_newton": th.GaussNewton, + "levenberg_marquardt": th.LevenbergMarquardt, +} + + def print_histogram( ba: theg.BundleAdjustmentDataset, var_dict: Dict[str, torch.Tensor], msg: str ): @@ -63,24 +73,24 @@ def average_repojection_error(objective) -> float: def run(cfg: omegaconf.OmegaConf): # create (or load) dataset - ba = theg.BundleAdjustmentDataset.generate_synthetic( - num_cameras=cfg["num_cameras"], - num_points=cfg["num_points"], - average_track_length=cfg["average_track_length"], - track_locality=cfg["track_locality"], - feat_random=1.5, - prob_feat_is_outlier=0.02, - outlier_feat_random=70, - cam_pos_rand=5.0, - cam_rot_rand=0.9, - point_rand=10.0, - ) - - # cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( - # # "/home/joe/Downloads/riku/fr1desk.txt", drop_obs=0.0) - # "/mnt/sda/bal/problem-21-11315-pre.txt", drop_obs=0.0) - # ba = theg.BundleAdjustmentDataset(cams, points, obs) - # ba.save_to_file(results_path / "ba.txt", gt_path=results_path / "ba_gt.txt") + if cfg["bal_file"] is None: + ba = theg.BundleAdjustmentDataset.generate_synthetic( + num_cameras=cfg["synthetic"]["num_cameras"], + num_points=cfg["synthetic"]["num_points"], + 
average_track_length=cfg["synthetic"]["average_track_length"], + track_locality=cfg["synthetic"]["track_locality"], + feat_random=1.5, + prob_feat_is_outlier=0.02, + outlier_feat_random=70, + cam_pos_rand=5.0, + cam_rot_rand=0.9, + point_rand=10.0, + ) + else: + cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( + cfg["bal_file"], drop_obs=0.0 + ) + ba = theg.BundleAdjustmentDataset(cams, points, obs) print("Cameras:", len(ba.cameras)) print("Points:", len(ba.points)) @@ -92,6 +102,7 @@ def run(cfg: omegaconf.OmegaConf): # Set up objective print("Setting up objective") + t0 = time.time() objective = th.Objective(dtype=torch.float64) weight = th.ScaleCostWeight(torch.tensor(1.0).to(dtype=ba.cameras[0].pose.dtype)) @@ -117,33 +128,36 @@ def run(cfg: omegaconf.OmegaConf): dtype = objective.dtype # Add regularization - if cfg["inner_optim"]["regularize"]: - zero_point3 = th.Point3(dtype=dtype, name="zero_point") + if cfg["optim"]["regularize"]: + # zero_point3 = th.Point3(dtype=dtype, name="zero_point") # identity_se3 = th.SE3(dtype=dtype, name="zero_se3") - w = np.sqrt(cfg["inner_optim"]["reg_w"]) + w = np.sqrt(cfg["optim"]["reg_w"]) damping_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) for name, var in objective.optim_vars.items(): target: th.Manifold if isinstance(var, th.SE3): target = var.copy(new_name="target_" + var.name) # target = identity_se3 - elif isinstance(var, th.Point3): - # target = var.copy(new_name="target_" + var.name) - target = zero_point3 - else: - assert False - objective.add( - th.Difference(var, target, damping_weight, name=f"reg_{name}") - ) + objective.add( + th.Difference(var, target, damping_weight, name=f"reg_{name}") + ) + # elif isinstance(var, th.Point3): + # target = var.copy(new_name="target_" + var.name) + # # target = zero_point3 + # else: + # assert False + # objective.add( + # th.Difference(var, target, damping_weight, name=f"reg_{name}") + # ) camera_pose_vars: List[th.LieGroup] = [ objective.optim_vars[c.pose.name] for c in ba.cameras # type: ignore ] - if cfg["inner_optim"]["ratio_known_cameras"] > 0.0 and ba.gt_cameras is not None: + if cfg["optim"]["ratio_known_cameras"] > 0.0 and ba.gt_cameras is not None: w = 1000.0 camera_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) for i in range(len(ba.cameras)): - if np.random.rand() > cfg["inner_optim"]["ratio_known_cameras"]: + if np.random.rand() > cfg["optim"]["ratio_known_cameras"]: continue print("fixing cam", i) objective.add( @@ -154,26 +168,28 @@ def run(cfg: omegaconf.OmegaConf): name=f"camera_diff_{i}", ) ) + print("done in:", time.time() - t0) # Create optimizer and theseus layer vectorize = True - optimizer = cfg["optimizer_cls"]( + optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( objective, - max_iterations=cfg["inner_optim"]["max_iters"], + max_iterations=cfg["optim"]["max_iters"], vectorize=vectorize, # linearization_cls=th.SparseLinearization, # linear_solver_cls=th.LUCudaSparseSolver, ) theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) - # device = "cuda" if torch.cuda.is_available() else "cpu" - # theseus_optim.to(device) - # print('Device:', device) + if cfg["device"] == "cuda": + cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" + theseus_optim.to(cfg["device"]) + print("Device:", cfg["device"]) optim_arg = { - "track_best_solution": True, + "track_best_solution": False, "track_err_history": True, - "track_state_history": True, + "track_state_history": cfg["optim"]["track_state_history"], "verbose": True, 
"backward_mode": th.BackwardMode.FULL, } @@ -183,9 +199,9 @@ def run(cfg: omegaconf.OmegaConf): "damping": 0.0, "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": 1.0e-4, + "lin_system_damping": 1.0e-0, } - optim_arg = {**optim_arg, **extra_args} + optim_arg = {**optim_arg, **extra_args} theseus_inputs = {} for cam in ba.cameras: @@ -197,6 +213,8 @@ def run(cfg: omegaconf.OmegaConf): with torch.no_grad(): camera_loss_ref = camera_loss(ba, camera_pose_vars).item() print(f"CAMERA LOSS: {camera_loss_ref: .3f}") + are = average_repojection_error(objective) + print("Average reprojection error (pixels): ", are) print_histogram(ba, theseus_inputs, "Input histogram:") objective.update(theseus_inputs) @@ -214,28 +232,44 @@ def run(cfg: omegaconf.OmegaConf): are = average_repojection_error(objective) print("Average reprojection error (pixels): ", are) - print_histogram(ba, theseus_inputs, "Final histogram:") + print_histogram(ba, theseus_outputs, "Final histogram:") + + # if cfg["optim"]["track_state_history"]: + # BAViewer( + # info.state_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points + # ) # , msg_history=optimizer.ftov_msgs_history) - # BAViewer( - # info.state_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points - # ) # , msg_history=optimizer.ftov_msgs_history) + # if cfg["bal_file"] is not None: + # save_dir = os.path.join(os.getcwd(), "outputs") + # if not os.path.exists(save_dir): + # os.mkdir(save_dir) + # err_history = info.err_history[0].cpu().numpy() + # save_file = os.path.join( + # save_dir, + # f"{cfg['optim']['optimizer_cls']}_{cfg['bal_file'].split('/')[-1]}", + # ) + # np.savetxt(save_file, err_history) if __name__ == "__main__": cfg = { "seed": 1, - "num_cameras": 10, - "num_points": 100, - "average_track_length": 8, - "track_locality": 0.2, - "optimizer_cls": GaussianBeliefPropagation, - # "optimizer_cls": th.GaussNewton, - "inner_optim": { - "max_iters": 20, - "verbose": True, - "track_err_history": True, - "keep_step_size": True, + "device": "cpu", + # "bal_file": None, + "bal_file": "/media/joe/data/bal/trafalgar/problem-21-11315-pre.txt", + "synthetic": { + "num_cameras": 10, + "num_points": 100, + "average_track_length": 8, + "track_locality": 0.2, + }, + "optim": { + "max_iters": 500, + "optimizer_cls": "gbp", + # "optimizer_cls": "gauss_newton", + # "optimizer_cls": "levenberg_marquardt", + "track_state_history": False, "regularize": True, "ratio_known_cameras": 0.1, "reg_w": 1e-7, diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py index 80534b7af..04686c1a2 100644 --- a/theseus/optimizer/gbp/ba_viewer.py +++ b/theseus/optimizer/gbp/ba_viewer.py @@ -36,7 +36,14 @@ def __init__( self.flip_z = flip_z self.lock = threading.Lock() - self.num_iters = list(state_history.values())[0].shape[-1] + self.num_iters = (~list(state_history.values())[0].isinf()[0, 0, 0]).sum() + + pts = [] + for k, state in state_history.items(): + if "Pt" in k: + pts.append(state[:, :, 0]) + extents = torch.cat(pts).max(dim=0)[0] - torch.cat(pts).min(dim=0)[0] + self.marker_height = extents.max().item() / 50 scene = trimesh.Scene() self.scene = scene @@ -119,6 +126,7 @@ def make_cam(self, pose, color=(0.0, 1.0, 0.0, 0.8)): self.scene.camera.fov, self.scene.camera.resolution, color=color, + marker_height=self.marker_height, ) return camera diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index d3d4bf797..15eba5495 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -754,7 
+754,11 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): lam_tau = eta_lam[1] if update_belief and lam_tau.count_nonzero() != 0: - inv_lam_tau = torch.inverse(lam_tau) + valid_lam = lam_tau.count_nonzero(1, 2) != 0 + inv_lam_tau = torch.zeros_like( + lam_tau, dtype=lam_tau.dtype, device=lam_tau.device + ) + inv_lam_tau[valid_lam] = torch.linalg.inv(lam_tau[valid_lam]) tau = torch.matmul(inv_lam_tau, eta_tp_acc.unsqueeze(-1)).squeeze(-1) new_belief = th.retract_gaussian(vectorized_var, tau, lam_tau) @@ -841,6 +845,7 @@ def _optimize_loop( self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop # compute factor potentials for the first time + unary_factor = False for i, cost_function in enumerate(cf_iterator): if self.objective.vectorized: @@ -871,6 +876,13 @@ def _optimize_loop( lin_system_damping=lin_system_damping, ) ) + if cost_function.num_optim_vars() == 1: + unary_factor = True + if unary_factor is False: + raise Exception( + "We require at least one unary cost function to act as a prior." + "This is because Gaussian Belief Propagation is performing Bayesian inference." + ) relins = self._linearize_factors() self.n_edges = sum( @@ -962,11 +974,11 @@ def _optimize_impl( track_state_history: bool = False, verbose: bool = False, backward_mode: BackwardMode = BackwardMode.FULL, - relin_threshold: float = 0.1, + relin_threshold: float = 1e-8, damping: float = 0.0, dropout: float = 0.0, schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, - lin_system_damping: float = 1e-6, + lin_system_damping: float = 1e-4, **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): diff --git a/theseus/optimizer/gbp/jax_torch_test.py b/theseus/optimizer/gbp/jax_torch_poc.py similarity index 100% rename from theseus/optimizer/gbp/jax_torch_test.py rename to theseus/optimizer/gbp/jax_torch_poc.py diff --git a/theseus/utils/examples/bundle_adjustment/data.py b/theseus/utils/examples/bundle_adjustment/data.py index 903171694..42502f200 100644 --- a/theseus/utils/examples/bundle_adjustment/data.py +++ b/theseus/utils/examples/bundle_adjustment/data.py @@ -180,7 +180,7 @@ def load_bal_dataset(path: str, drop_obs=0.0): fields = out.readline().rstrip().split() intrinsics = None - if len(fields) == 4: + if "." 
in fields[0]: intrinsics = [ (float(fields[0]) + float(fields[1])) / 2.0, float(fields[2]), From 9db961227094ef20955c7412caaac472aa831740 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 20 Jul 2022 09:27:37 +0100 Subject: [PATCH 37/64] ba error plot --- theseus/optimizer/gbp/plot_ba_err.py | 31 ++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 theseus/optimizer/gbp/plot_ba_err.py diff --git a/theseus/optimizer/gbp/plot_ba_err.py b/theseus/optimizer/gbp/plot_ba_err.py new file mode 100644 index 000000000..543555e14 --- /dev/null +++ b/theseus/optimizer/gbp/plot_ba_err.py @@ -0,0 +1,31 @@ +import numpy as np +import matplotlib.pylab as plt +import os + + +root_dir = "/home/joe/projects/theseus/theseus/optimizer/gbp/outputs" +err_files1 = [ + "gbp_problem-21-11315-pre.txt", + "levenberg_marquardt_problem-21-11315-pre.txt", +] +err_files2 = [ + "gbp_problem-50-20431-pre.txt", + "levenberg_marquardt_problem-50-20431-pre.txt", +] + + +err_files = err_files1 + +for err_files in [err_files1, err_files2]: + + gbp_err = np.loadtxt(os.path.join(root_dir, err_files[0])) + lm_err = np.loadtxt(os.path.join(root_dir, err_files[1])) + + plt.plot(gbp_err, label="GBP") + plt.plot(lm_err, label="Levenberg Marquardt") + plt.xscale("log") + plt.title(err_files[0][4:]) + plt.xlabel("Iterations") + plt.ylabel("Total Energy") + plt.legend() + plt.show() From 12f0b77ac9fe8843a9c96c491f7e063c8352476f Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 22 Jul 2022 18:53:14 +0100 Subject: [PATCH 38/64] fixed are calculation --- .../gbp/{ba_test.py => bundle_adjustment.py} | 73 +++++++++++++++---- 1 file changed, 58 insertions(+), 15 deletions(-) rename theseus/optimizer/gbp/{ba_test.py => bundle_adjustment.py} (77%) diff --git a/theseus/optimizer/gbp/ba_test.py b/theseus/optimizer/gbp/bundle_adjustment.py similarity index 77% rename from theseus/optimizer/gbp/ba_test.py rename to theseus/optimizer/gbp/bundle_adjustment.py index 9610d145a..2d50c1735 100644 --- a/theseus/optimizer/gbp/ba_test.py +++ b/theseus/optimizer/gbp/bundle_adjustment.py @@ -13,6 +13,7 @@ # import os import theseus as th +from theseus.core import Vectorize import theseus.utils.examples as theg from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule @@ -59,12 +60,18 @@ def camera_loss( return loss -def average_repojection_error(objective) -> float: - +# Assumes the weight of the cost functions are 1 +def average_repojection_error(objective, values_dict=None) -> float: + if values_dict is not None: + objective.update(values_dict) + if objective._vectorized is False: + Vectorize(objective) reproj_norms = [] - for k in objective.cost_functions.keys(): - if "Reprojection" in k: - err = objective.cost_functions[k].error().norm(dim=1) + for cost_function in objective._get_iterator(): + if "Reprojection" in cost_function.name: + # should equal error as weight is 1 + # need to call weighted_error as error is not cached + err = cost_function.weighted_error().norm(dim=1) reproj_norms.append(err) are = torch.tensor(reproj_norms).mean().item() @@ -96,6 +103,8 @@ def run(cfg: omegaconf.OmegaConf): print("Points:", len(ba.points)) print("Observations:", len(ba.observations), "\n") + print("Optimizer:", cfg["optim"]["optimizer_cls"], "\n") + # param that control transition from squared loss to huber radius_tensor = torch.tensor([1.0], dtype=torch.float64) log_loss_radius = th.Vector(tensor=radius_tensor, name="log_loss_radius") @@ -103,7 +112,9 @@ def run(cfg: omegaconf.OmegaConf): # Set up objective 
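+    # Note: two objectives are built below. The robust, regularized objective is the one
+    # that gets optimized, while dummy_objective holds only the raw reprojection costs
+    # (weight 1) so that average_repojection_error reports plain pixel errors, unaffected
+    # by the robust loss or the regularization terms.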
print("Setting up objective") t0 = time.time() - objective = th.Objective(dtype=torch.float64) + dtype = torch.float64 + objective = th.Objective(dtype=dtype) + dummy_objective = th.Objective(dtype=dtype) # for computing are weight = th.ScaleCostWeight(torch.tensor(1.0).to(dtype=ba.cameras[0].pose.dtype)) for i, obs in enumerate(ba.observations): @@ -125,7 +136,7 @@ def run(cfg: omegaconf.OmegaConf): name=f"robust_{cost_function.name}", ) objective.add(robust_cost_function) - dtype = objective.dtype + dummy_objective.add(cost_function) # Add regularization if cfg["optim"]["regularize"]: @@ -184,6 +195,7 @@ def run(cfg: omegaconf.OmegaConf): if cfg["device"] == "cuda": cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" theseus_optim.to(cfg["device"]) + dummy_objective.to(cfg["device"]) print("Device:", cfg["device"]) optim_arg = { @@ -199,7 +211,7 @@ def run(cfg: omegaconf.OmegaConf): "damping": 0.0, "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": 1.0e-0, + "lin_system_damping": 1.0e-1, } optim_arg = {**optim_arg, **extra_args} @@ -213,12 +225,12 @@ def run(cfg: omegaconf.OmegaConf): with torch.no_grad(): camera_loss_ref = camera_loss(ba, camera_pose_vars).item() print(f"CAMERA LOSS: {camera_loss_ref: .3f}") - are = average_repojection_error(objective) + are = average_repojection_error(dummy_objective, values_dict=theseus_inputs) print("Average reprojection error (pixels): ", are) print_histogram(ba, theseus_inputs, "Input histogram:") objective.update(theseus_inputs) - print("squred err:", objective.error_squared_norm().item()) + print("squared err:", objective.error_squared_norm().item() / 2) with torch.no_grad(): theseus_outputs, info = theseus_optim.forward( @@ -230,11 +242,11 @@ def run(cfg: omegaconf.OmegaConf): loss = camera_loss(ba, camera_pose_vars).item() print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") - are = average_repojection_error(objective) + are = average_repojection_error(dummy_objective, values_dict=theseus_outputs) print("Average reprojection error (pixels): ", are) print_histogram(ba, theseus_outputs, "Final histogram:") - # if cfg["optim"]["track_state_history"]: + # if info.state_history is not None: # BAViewer( # info.state_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points # ) # , msg_history=optimizer.ftov_msgs_history) @@ -246,10 +258,41 @@ def run(cfg: omegaconf.OmegaConf): # err_history = info.err_history[0].cpu().numpy() # save_file = os.path.join( # save_dir, - # f"{cfg['optim']['optimizer_cls']}_{cfg['bal_file'].split('/')[-1]}", + # f"{cfg['optim']['optimizer_cls']}_err_{cfg['bal_file'].split('/')[-1]}", # ) # np.savetxt(save_file, err_history) + # # get average reprojection error for each iteration + # if info.state_history is not None: + # ares = [] + # iters = ( + # info.converged_iter + # if info.converged_iter != -1 + # else cfg["optim"]["max_iters"] + # ) + # for i in range(iters): + # t0 = time.time() + # values_dict = {} + # for name, state in info.state_history.items(): + # values_dict[name] = ( + # state[..., i].to(dtype=torch.float64).to(dummy_objective.device) + # ) + # are = average_repojection_error(dummy_objective, values_dict=values_dict) + # ares.append(are) + # print(i, "-- ARE:", are, " -- time", time.time() - t0) + # are = average_repojection_error(dummy_objective, values_dict=theseus_outputs) + # ares.append(are) + + # if cfg["bal_file"] is not None: + # save_dir = os.path.join(os.getcwd(), "outputs") + # if not os.path.exists(save_dir): + # os.mkdir(save_dir) + # 
save_file = os.path.join( + # save_dir, + # f"{cfg['optim']['optimizer_cls']}_are_{cfg['bal_file'].split('/')[-1]}", + # ) + # np.savetxt(save_file, np.array(ares)) + if __name__ == "__main__": @@ -265,11 +308,11 @@ def run(cfg: omegaconf.OmegaConf): "track_locality": 0.2, }, "optim": { - "max_iters": 500, + "max_iters": 300, "optimizer_cls": "gbp", # "optimizer_cls": "gauss_newton", # "optimizer_cls": "levenberg_marquardt", - "track_state_history": False, + "track_state_history": True, "regularize": True, "ratio_known_cameras": 0.1, "reg_w": 1e-7, From 255a5b505dd02957b61dbc5835669199b142aae8 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Tue, 2 Aug 2022 15:02:01 +0100 Subject: [PATCH 39/64] tensor for linear system damping, rename message damping --- theseus/optimizer/gbp/gbp.py | 81 +++++++++++++++++++++--------------- 1 file changed, 47 insertions(+), 34 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 15eba5495..d52575638 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -117,8 +117,8 @@ def __init__( self, cf: CostFunction, var_ixs: torch.Tensor, + lin_system_damping: torch.Tensor, name: Optional[str] = None, - lin_system_damping: float = 1e-6, ): self._id = next(Factor._ids) if name: @@ -150,9 +150,7 @@ def __init__( self.batch_size, device=device, dtype=torch.int ) - self.lm_damping = torch.full([self.batch_size], lin_system_damping).to( - dtype=dtype, device=device - ) + self.lm_damping = lin_system_damping.repeat(self.batch_size) self.min_damping = torch.Tensor([1e-4]).to(dtype=dtype, device=device) self.max_damping = torch.Tensor([1e2]).to(dtype=dtype, device=device) self.last_err: torch.Tensor = None @@ -242,7 +240,7 @@ def linearize( # Compute all outgoing messages from the factor. def comp_mess( self, - damping, + msg_damping, schedule, ): num_optim_vars = self.cf.num_optim_vars() @@ -325,11 +323,12 @@ def comp_mess( torch.bmm(lono, torch.linalg.inv(lnono)), eno.unsqueeze(-1) ).squeeze(-1) - # damping in tangent space at linearisation point as message + # message damping in tangent space at linearisation point as message # is already in this tangent space. Could equally do damping # in the tangent space of the new or old message mean. # Damping is applied to the mean parameters. 
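+                    # In short, with d = msg_damping[v] per edge:
+                    #     mu_new  = new_mess_lam^{-1} @ new_mess_eta      (undamped mean)
+                    #     mu_damp = (1 - d) * mu_new + d * prev_mess_mean
+                    #     new_mess_eta = new_mess_lam @ mu_damp           (precision unchanged)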
- do_damping = torch.logical_and(damping[v] > 0, self.steps_since_lin > 0) + # do_damping = torch.logical_and(msg_damping[v] > 0, self.steps_since_lin >= 0) + do_damping = msg_damping[v] if do_damping.sum() != 0: damping_check = torch.logical_and( new_mess_lam.count_nonzero(1, 2) != 0, @@ -343,10 +342,10 @@ def comp_mess( new_mess_mean = torch.bmm( torch.inverse(new_mess_lam), new_mess_eta.unsqueeze(-1) ).squeeze(-1) - damping[v][~do_damping] = 0.0 + msg_damping[v][~do_damping] = 0.0 new_mess_mean = ( - 1 - damping[v][:, None] - ) * new_mess_mean + damping[v][:, None] * prev_mess_mean + 1 - msg_damping[v][:, None] + ) * new_mess_mean + msg_damping[v][:, None] * prev_mess_mean new_mess_eta = torch.bmm( new_mess_lam, new_mess_mean.unsqueeze(-1) ).squeeze(-1) @@ -781,12 +780,14 @@ def _linearize_factors(self, relin_threshold: float = None): relins += int((factor.steps_since_lin == 0).sum().item()) return relins - def _pass_fac_to_var_messages(self, schedule: torch.Tensor, damping: torch.Tensor): + def _pass_fac_to_var_messages( + self, schedule: torch.Tensor, ftov_msg_damping: torch.Tensor + ): start_d = 0 for j, factor in enumerate(self.factors): num_optim_vars = factor.cf.num_optim_vars() n_edges = num_optim_vars * factor.batch_size - damping_tsr = damping[start_d : start_d + n_edges] + damping_tsr = ftov_msg_damping[start_d : start_d + n_edges] schedule_tsr = schedule[start_d : start_d + n_edges] damping_tsr = damping_tsr.reshape(num_optim_vars, factor.batch_size) schedule_tsr = schedule_tsr.reshape(num_optim_vars, factor.batch_size) @@ -807,25 +808,13 @@ def _optimize_loop( verbose: bool, truncated_grad_loop: bool, relin_threshold: float, - damping: float, + ftov_msg_damping: float, dropout: float, schedule: GBPSchedule, - lin_system_damping: float, + lin_system_damping: torch.Tensor, clear_messages: bool = True, **kwargs, ): - if damping > 1.0 or damping < 0.0: - raise ValueError(f"Damping must be between 0 and 1. Got {damping}.") - if dropout > 1.0 or dropout < 0.0: - raise ValueError( - f"Dropout probability must be between 0 and 1. Got {dropout}." - ) - if dropout > 0.9: - print( - "Disabling vectorization due to dropout > 0.9 in GBP message schedule." - ) - self.objective.disable_vectorization() - # initialise beliefs for var in self.ordering: self.beliefs.append(th.ManifoldGaussian([var])) @@ -905,9 +894,9 @@ def _optimize_loop( self.ftov_msgs_history[it_] = curr_ftov_msgs # damping - damping_arr = torch.full( + ftov_damping_arr = torch.full( [self.n_edges], - damping, + ftov_msg_damping, device=self.ordering[0].device, dtype=self.ordering[0].dtype, ) @@ -921,7 +910,7 @@ def _optimize_loop( t_relin = time.time() - t0 t1 = time.time() - self._pass_fac_to_var_messages(ftov_schedule[it_], damping_arr) + self._pass_fac_to_var_messages(ftov_schedule[it_], ftov_damping_arr) t_ftov = time.time() - t1 t1 = time.time() @@ -975,10 +964,10 @@ def _optimize_impl( verbose: bool = False, backward_mode: BackwardMode = BackwardMode.FULL, relin_threshold: float = 1e-8, - damping: float = 0.0, + ftov_msg_damping: float = 0.0, dropout: float = 0.0, schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, - lin_system_damping: float = 1e-4, + lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): @@ -986,6 +975,30 @@ def _optimize_impl( track_best_solution, track_err_history, track_state_history ) + if ftov_msg_damping > 1.0 or ftov_msg_damping < 0.0: + raise ValueError( + f"Damping must be between 0 and 1. Got {ftov_msg_damping}." 
+ ) + if dropout > 1.0 or dropout < 0.0: + raise ValueError( + f"Dropout probability must be between 0 and 1. Got {dropout}." + ) + if dropout > 0.9: + print( + "Disabling vectorization due to dropout > 0.9 in GBP message schedule." + ) + self.objective.disable_vectorization() + + if not isinstance(lin_system_damping, torch.Tensor): + raise TypeError("lin_system_damping should be an instance of torch.Tensor.") + expected_shape = torch.Size([1]) + if lin_system_damping.shape != expected_shape: + raise ValueError( + f"lin_system_damping should have shape {expected_shape}. " + f"Got shape {lin_system_damping.shape}." + ) + lin_system_damping.to(self.objective.device, self.objective.dtype) + if verbose: print( f"GBP optimizer. Iteration: 0. " f"Error: {info.last_err.mean().item()}" @@ -998,7 +1011,7 @@ def _optimize_impl( verbose=verbose, truncated_grad_loop=False, relin_threshold=relin_threshold, - damping=damping, + ftov_msg_damping=ftov_msg_damping, dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, @@ -1039,7 +1052,7 @@ def _optimize_impl( verbose=verbose, truncated_grad_loop=False, relin_threshold=relin_threshold, - damping=damping, + ftov_msg_damping=ftov_msg_damping, dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, @@ -1055,7 +1068,7 @@ def _optimize_impl( verbose=verbose, truncated_grad_loop=True, relin_threshold=relin_threshold, - damping=damping, + ftov_msg_damping=ftov_msg_damping, dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, From 634ce74e926fa62d975575211ce4413d226f5eda Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Thu, 4 Aug 2022 16:02:20 +0100 Subject: [PATCH 40/64] fixes bug where beleifs and factors are created twice --- theseus/optimizer/gbp/gbp.py | 154 +++++++++++++++++++---------------- 1 file changed, 82 insertions(+), 72 deletions(-) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index d52575638..6b63389bf 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -223,6 +223,7 @@ def linearize( self.lm_damping[~decreased_ixs] * self.b, self.max_damping ) self.last_err = err + # damp precision matrix damped_D = self.lm_damping[:, None, None] * torch.eye( lam.shape[1], device=lam.device, dtype=lam.dtype @@ -278,27 +279,29 @@ def comp_mess( else: # print(self.cf.name, "---> sending message") # Divide up parameters of distribution - eo = eta_factor[:, sdim : sdim + dofs] - eno = torch.cat( + # *_out = parameters for receiver variable (outgoing message vars) + # *_notout = parameters for other variables (not outgoing message vars) + eta_out = eta_factor[:, sdim : sdim + dofs] + eta_notout = torch.cat( (eta_factor[:, :sdim], eta_factor[:, sdim + dofs :]), dim=1 ) - loo = lam_factor[:, sdim : sdim + dofs, sdim : sdim + dofs] - lono = torch.cat( + lam_out_out = lam_factor[:, sdim : sdim + dofs, sdim : sdim + dofs] + lam_out_notout = torch.cat( ( lam_factor[:, sdim : sdim + dofs, :sdim], lam_factor[:, sdim : sdim + dofs, sdim + dofs :], ), dim=2, ) - lnoo = torch.cat( + lam_notout_out = torch.cat( ( lam_factor[:, :sdim, sdim : sdim + dofs], lam_factor[:, sdim + dofs :, sdim : sdim + dofs], ), dim=1, ) - lnono = torch.cat( + lam_notout_notout = torch.cat( ( torch.cat( ( @@ -318,9 +321,15 @@ def comp_mess( dim=1, ) - new_mess_lam = loo - lono @ torch.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - torch.bmm( - torch.bmm(lono, torch.linalg.inv(lnono)), eno.unsqueeze(-1) + new_mess_lam = ( + lam_out_out + - lam_out_notout + @ 
torch.linalg.inv(lam_notout_notout) + @ lam_notout_out + ) + new_mess_eta = eta_out - torch.bmm( + torch.bmm(lam_out_notout, torch.linalg.inv(lam_notout_notout)), + eta_notout.unsqueeze(-1), ).squeeze(-1) # message damping in tangent space at linearisation point as message @@ -409,9 +418,6 @@ def __init__( abs_err_tolerance, rel_err_tolerance, max_iterations ) - self.beliefs: List[th.ManifoldGaussian] = [] - self.factors: List[Factor] = [] - """ Copied and slightly modified from nonlinear optimizer class """ @@ -796,6 +802,67 @@ def _pass_fac_to_var_messages( if schedule_tsr.sum() != 0: factor.comp_mess(damping_tsr, schedule_tsr) + def _create_factors_beliefs(self, lin_system_damping): + self.factors: List[Factor] = [] + self.beliefs: List[th.ManifoldGaussian] = [] + for var in self.ordering: + self.beliefs.append(th.ManifoldGaussian([var])) + + if self.objective.vectorized: + cf_iterator = iter(self.objective.vectorized_cost_fns) + self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_vectorized + else: + cf_iterator = self.objective._get_iterator() + self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop + + # compute factor potentials for the first time + unary_factor = False + for i, cost_function in enumerate(cf_iterator): + + if self.objective.vectorized: + # create array for indexing the messages + base_cf_names = self.objective.vectorized_cf_names[i] + base_cfs = [ + self.objective.get_cost_function(name) for name in base_cf_names + ] + var_ixs = torch.tensor( + [ + [self.ordering.index_of(var.name) for var in cf.optim_vars] + for cf in base_cfs + ] + ).long() + else: + var_ixs = torch.tensor( + [ + self.ordering.index_of(var.name) + for var in cost_function.optim_vars + ] + ).long() + + self.factors.append( + Factor( + cost_function, + name=cost_function.name, + var_ixs=var_ixs, + lin_system_damping=lin_system_damping, + ) + ) + if cost_function.num_optim_vars() == 1: + unary_factor = True + if unary_factor is False: + raise Exception( + "We require at least one unary cost function to act as a prior." + "This is because Gaussian Belief Propagation is performing Bayesian inference." + ) + self._linearize_factors() + + self.n_individual_factors = ( + len(self.objective.cost_functions) * self.objective.batch_size + ) + self.n_edges = sum( + [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] + ) + """ Optimization loop functions """ @@ -815,68 +882,11 @@ def _optimize_loop( clear_messages: bool = True, **kwargs, ): - # initialise beliefs - for var in self.ordering: - self.beliefs.append(th.ManifoldGaussian([var])) - + # we only create the factors and beliefs right before runnig GBP as they are + # not automatically updated when objective.update is called. 
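+        # (_create_factors_beliefs also raises if the objective has no unary cost function:
+        #  GBP needs at least one unary factor to act as a prior, since it is performing
+        #  Bayesian inference over the factor graph.)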
if clear_messages: - self.n_individual_factors = ( - len(self.objective.cost_functions) * self.objective.batch_size - ) - - if self.objective.vectorized: - cf_iterator = iter(self.objective.vectorized_cost_fns) - self._pass_var_to_fac_messages = ( - self._pass_var_to_fac_messages_vectorized - ) - else: - cf_iterator = self.objective._get_iterator() - self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop - - # compute factor potentials for the first time - unary_factor = False - for i, cost_function in enumerate(cf_iterator): - - if self.objective.vectorized: - # create array for indexing the messages - base_cf_names = self.objective.vectorized_cf_names[i] - base_cfs = [ - self.objective.get_cost_function(name) for name in base_cf_names - ] - var_ixs = torch.tensor( - [ - [self.ordering.index_of(var.name) for var in cf.optim_vars] - for cf in base_cfs - ] - ).long() - else: - var_ixs = torch.tensor( - [ - self.ordering.index_of(var.name) - for var in cost_function.optim_vars - ] - ).long() - - self.factors.append( - Factor( - cost_function, - name=cost_function.name, - var_ixs=var_ixs, - lin_system_damping=lin_system_damping, - ) - ) - if cost_function.num_optim_vars() == 1: - unary_factor = True - if unary_factor is False: - raise Exception( - "We require at least one unary cost function to act as a prior." - "This is because Gaussian Belief Propagation is performing Bayesian inference." - ) - relins = self._linearize_factors() + self._create_factors_beliefs(lin_system_damping) - self.n_edges = sum( - [factor.cf.num_optim_vars() * factor.batch_size for factor in self.factors] - ) if schedule == GBPSchedule.SYNCHRONOUS: ftov_schedule = synchronous_schedule( self.params.max_iterations, self.n_edges From 39be1f19eb47eff3c98cc55c250c1bcf6a9debe2 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 26 Aug 2022 11:58:39 +0100 Subject: [PATCH 41/64] nesterov, no grad for lm damping, ba batch experiments --- theseus/optimizer/gbp/bundle_adjustment.py | 175 +++++++++++++++++++-- theseus/optimizer/gbp/gbp.py | 41 +++-- 2 files changed, 193 insertions(+), 23 deletions(-) diff --git a/theseus/optimizer/gbp/bundle_adjustment.py b/theseus/optimizer/gbp/bundle_adjustment.py index 2d50c1735..5623b255c 100644 --- a/theseus/optimizer/gbp/bundle_adjustment.py +++ b/theseus/optimizer/gbp/bundle_adjustment.py @@ -10,7 +10,9 @@ import time import torch -# import os +import os +import json +from datetime import datetime import theseus as th from theseus.core import Vectorize @@ -26,6 +28,24 @@ "levenberg_marquardt": th.LevenbergMarquardt, } +OUTER_OPTIMIZER_CLASS = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, +} + + +def save_res_loss_rad(save_dir, cfg, sweep_radii, sweep_losses, radius_vals, losses): + with open(f"{save_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) + + # sweep values + np.savetxt(f"{save_dir}/sweep_radius.txt", sweep_radii) + np.savetxt(f"{save_dir}/sweep_loss.txt", sweep_losses) + + # optim trajectory + np.savetxt(f"{save_dir}/optim_radius.txt", radius_vals) + np.savetxt(f"{save_dir}/optim_loss.txt", losses) + def print_histogram( ba: theg.BundleAdjustmentDataset, var_dict: Dict[str, torch.Tensor], msg: str @@ -78,7 +98,7 @@ def average_repojection_error(objective, values_dict=None) -> float: return are -def run(cfg: omegaconf.OmegaConf): +def load_problem(cfg: omegaconf.OmegaConf): # create (or load) dataset if cfg["bal_file"] is None: ba = theg.BundleAdjustmentDataset.generate_synthetic( @@ -103,6 +123,12 @@ def run(cfg: omegaconf.OmegaConf): 
print("Points:", len(ba.points)) print("Observations:", len(ba.observations), "\n") + return ba + + +def setup_layer(cfg: omegaconf.OmegaConf): + ba = load_problem(cfg) + print("Optimizer:", cfg["optim"]["optimizer_cls"], "\n") # param that control transition from squared loss to huber @@ -198,6 +224,10 @@ def run(cfg: omegaconf.OmegaConf): dummy_objective.to(cfg["device"]) print("Device:", cfg["device"]) + # create damping parameter + lin_system_damping = torch.nn.Parameter(torch.tensor([1.0e-2], dtype=torch.float64)) + lin_system_damping.to(device=cfg["device"]) + optim_arg = { "track_best_solution": False, "track_err_history": True, @@ -207,11 +237,11 @@ def run(cfg: omegaconf.OmegaConf): } if isinstance(optimizer, GaussianBeliefPropagation): extra_args = { - "relin_threshold": 0.0000000001, - "damping": 0.0, + "relin_threshold": 1e-8, + "ftov_msg_damping": 0.0, "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": 1.0e-1, + "lin_system_damping": lin_system_damping, } optim_arg = {**optim_arg, **extra_args} @@ -221,6 +251,26 @@ def run(cfg: omegaconf.OmegaConf): for pt in ba.points: theseus_inputs[pt.name] = pt.tensor.clone() + return ( + theseus_optim, + theseus_inputs, + optim_arg, + ba, + dummy_objective, + camera_pose_vars, + lin_system_damping, + ) + + +def run_inner( + theseus_optim, + theseus_inputs, + optim_arg, + ba, + dummy_objective, + camera_pose_vars, + lin_system_damping, +): if ba.gt_cameras is not None: with torch.no_grad(): camera_loss_ref = camera_loss(ba, camera_pose_vars).item() @@ -229,9 +279,6 @@ def run(cfg: omegaconf.OmegaConf): print("Average reprojection error (pixels): ", are) print_histogram(ba, theseus_inputs, "Input histogram:") - objective.update(theseus_inputs) - print("squared err:", objective.error_squared_norm().item() / 2) - with torch.no_grad(): theseus_outputs, info = theseus_optim.forward( input_tensors=theseus_inputs, @@ -294,13 +341,109 @@ def run(cfg: omegaconf.OmegaConf): # np.savetxt(save_file, np.array(ares)) +def run_outer(cfg: omegaconf.OmegaConf): + + ( + theseus_optim, + theseus_inputs, + optim_arg, + ba, + dummy_objective, + camera_pose_vars, + lin_system_damping, + ) = setup_layer(cfg) + + loss_radius_tensor = torch.nn.Parameter(torch.tensor([3.0], dtype=torch.float64)) + model_optimizer = OUTER_OPTIMIZER_CLASS[cfg["outer"]["optimizer"]]( + [loss_radius_tensor], lr=cfg["outer"]["lr"] + ) + # model_optimizer = torch.optim.Adam([lin_system_damping], lr=cfg["outer"]["lr"]) + + theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() + + with torch.no_grad(): + camera_loss_ref = camera_loss(ba, camera_pose_vars).item() + print(f"CAMERA LOSS (no learning): {camera_loss_ref: .3f}") + print_histogram(ba, theseus_inputs, "Input histogram:") + + import matplotlib.pylab as plt + + sweep_radii = torch.linspace(0.01, 5.0, 20) + sweep_losses = [] + with torch.set_grad_enabled(False): + for r in sweep_radii: + theseus_inputs["log_loss_radius"][0] = r + + print(theseus_inputs["log_loss_radius"]) + + theseus_outputs, info = theseus_optim.forward( + input_tensors=theseus_inputs, + optimizer_kwargs=optim_arg, + ) + cam_loss = camera_loss(ba, camera_pose_vars) + loss = (cam_loss - camera_loss_ref) / camera_loss_ref + sweep_losses.append(torch.sum(loss.detach()).item()) + + plt.plot(sweep_radii, sweep_losses) + plt.xlabel("Log loss radius") + plt.ylabel("(Camera loss - reference loss) / reference loss") + + losses = [] + radius_vals = [] + theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() 
+ + for epoch in range(cfg["outer"]["num_epochs"]): + print(f" ******************* EPOCH {epoch} ******************* ") + start_time = time.time_ns() + model_optimizer.zero_grad() + theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() + + theseus_outputs, info = theseus_optim.forward( + input_tensors=theseus_inputs, + optimizer_kwargs=optim_arg, + ) + + cam_loss = camera_loss(ba, camera_pose_vars) + loss = (cam_loss - camera_loss_ref) / camera_loss_ref + loss.backward() + radius_vals.append(loss_radius_tensor.data.item()) + print(loss_radius_tensor.grad) + model_optimizer.step() + loss_value = torch.sum(loss.detach()).item() + losses.append(loss_value) + end_time = time.time_ns() + + # print_histogram(ba, theseus_outputs, "Output histogram:") + print(f"camera loss {cam_loss} and ref loss {camera_loss_ref}") + print( + f"Epoch: {epoch} Loss: {loss_value} " + # f"Lin system damping {lin_system_damping}" + f"Kernel Radius: exp({loss_radius_tensor.data.item()})=" + f"{torch.exp(loss_radius_tensor.data).item()}" + ) + print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") + + print("Loss values:", losses) + + now = datetime.now() + time_str = now.strftime("%m-%d-%y_%H-%M-%S") + save_dir = os.getcwd() + "/outputs/loss_radius_exp/" + time_str + os.mkdir(save_dir) + + save_res_loss_rad(save_dir, cfg, sweep_radii, sweep_losses, radius_vals, losses) + + plt.scatter(radius_vals, losses, c=range(len(losses)), cmap=plt.get_cmap("viridis")) + plt.title(cfg["optim"]["optimizer_cls"] + " - " + time_str) + plt.show() + + if __name__ == "__main__": cfg = { "seed": 1, "device": "cpu", - # "bal_file": None, - "bal_file": "/media/joe/data/bal/trafalgar/problem-21-11315-pre.txt", + "bal_file": None, + # "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", "synthetic": { "num_cameras": 10, "num_points": 100, @@ -308,7 +451,7 @@ def run(cfg: omegaconf.OmegaConf): "track_locality": 0.2, }, "optim": { - "max_iters": 300, + "max_iters": 200, "optimizer_cls": "gbp", # "optimizer_cls": "gauss_newton", # "optimizer_cls": "levenberg_marquardt", @@ -317,10 +460,18 @@ def run(cfg: omegaconf.OmegaConf): "ratio_known_cameras": 0.1, "reg_w": 1e-7, }, + "outer": { + "num_epochs": 15, + "lr": 1e2, # 5.0e-1, + "optimizer": "sgd", + }, } torch.manual_seed(cfg["seed"]) np.random.seed(cfg["seed"]) random.seed(cfg["seed"]) - run(cfg) + # args = setup_layer(cfg) + # run_inner(*args) + + run_outer(cfg) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 6b63389bf..604f4fe19 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -38,6 +38,12 @@ """ +def next_nesterov_params(lam) -> Tuple[float, float]: + new_lambda = (1 + np.sqrt(4 * lam * lam + 1)) / 2.0 + new_gamma = (1 - lam) / new_lambda + return new_lambda, new_gamma + + # Same of NonlinearOptimizerParams but without step size @dataclass class GBPOptimizerParams: @@ -212,17 +218,18 @@ def linearize( eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) - # update damping parameter - err = (self.cf.error() ** 2).sum(dim=1) - if self.last_err is not None: - decreased_ixs = err < self.last_err - self.lm_damping[decreased_ixs] = torch.max( - self.lm_damping[decreased_ixs] / self.a, self.min_damping - ) - self.lm_damping[~decreased_ixs] = torch.min( - self.lm_damping[~decreased_ixs] * self.b, self.max_damping - ) - self.last_err = err + # update damping parameter. 
This is non-differentiable + with torch.no_grad(): + err = (self.cf.error() ** 2).sum(dim=1) + if self.last_err is not None: + decreased_ixs = err < self.last_err + self.lm_damping[decreased_ixs] = torch.max( + self.lm_damping[decreased_ixs] / self.a, self.min_damping + ) + self.lm_damping[~decreased_ixs] = torch.min( + self.lm_damping[~decreased_ixs] * self.b, self.max_damping + ) + self.last_err = err # damp precision matrix damped_D = self.lm_damping[:, None, None] * torch.eye( @@ -879,6 +886,7 @@ def _optimize_loop( dropout: float, schedule: GBPSchedule, lin_system_damping: torch.Tensor, + nesterov: bool, clear_messages: bool = True, **kwargs, ): @@ -892,6 +900,9 @@ def _optimize_loop( self.params.max_iterations, self.n_edges ) + if nesterov: + nest_lambda, nest_gamma = next_nesterov_params(0.0) + self.ftov_msgs_history = {} converged_indices = torch.zeros_like(info.last_err).bool() @@ -924,6 +935,10 @@ def _optimize_loop( t_ftov = time.time() - t1 t1 = time.time() + if nesterov: + nest_lambda, nest_gamma = next_nesterov_params(nest_lambda) + print("nesterov lambda", nest_lambda) + print("nesterov gamma", nest_gamma) self._pass_var_to_fac_messages(update_belief=True) t_vtof = time.time() - t1 @@ -978,6 +993,7 @@ def _optimize_impl( dropout: float = 0.0, schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), + nesterov: bool = False, **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): @@ -1025,6 +1041,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + nesterov=nesterov, **kwargs, ) @@ -1066,6 +1083,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + nesterov=nesterov, **kwargs, ) @@ -1082,6 +1100,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + nesterov=nesterov, clear_messages=False, **kwargs, ) From 25e87de99748a9238de87be94af1e20abccf87ed Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Tue, 6 Sep 2022 18:11:04 +0100 Subject: [PATCH 42/64] nesterov acceleration, two modes --- theseus/optimizer/gbp/bundle_adjustment.py | 62 +++++++++++---- theseus/optimizer/gbp/gbp.py | 87 ++++++++++++++++++++-- theseus/optimizer/gbp/plot_ba_err.py | 75 ++++++++++++++----- 3 files changed, 184 insertions(+), 40 deletions(-) diff --git a/theseus/optimizer/gbp/bundle_adjustment.py b/theseus/optimizer/gbp/bundle_adjustment.py index 5623b255c..e9a4cb666 100644 --- a/theseus/optimizer/gbp/bundle_adjustment.py +++ b/theseus/optimizer/gbp/bundle_adjustment.py @@ -33,6 +33,10 @@ "adam": torch.optim.Adam, } +GBP_SCHEDULE = { + "synchronous": GBPSchedule.SYNCHRONOUS, +} + def save_res_loss_rad(save_dir, cfg, sweep_radii, sweep_losses, radius_vals, losses): with open(f"{save_dir}/config.txt", "w") as f: @@ -208,7 +212,7 @@ def setup_layer(cfg: omegaconf.OmegaConf): print("done in:", time.time() - t0) # Create optimizer and theseus layer - vectorize = True + vectorize = cfg["optim"]["vectorize"] optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( objective, max_iterations=cfg["optim"]["max_iters"], @@ -225,7 +229,11 @@ def setup_layer(cfg: omegaconf.OmegaConf): print("Device:", cfg["device"]) # create damping parameter - lin_system_damping = torch.nn.Parameter(torch.tensor([1.0e-2], dtype=torch.float64)) + lin_system_damping = torch.nn.Parameter( + torch.tensor( + [cfg["optim"]["gbp_settings"]["lin_system_damping"]], dtype=torch.float64 + ) + ) lin_system_damping.to(device=cfg["device"]) optim_arg 
= { @@ -236,14 +244,10 @@ def setup_layer(cfg: omegaconf.OmegaConf): "backward_mode": th.BackwardMode.FULL, } if isinstance(optimizer, GaussianBeliefPropagation): - extra_args = { - "relin_threshold": 1e-8, - "ftov_msg_damping": 0.0, - "dropout": 0.0, - "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": lin_system_damping, - } - optim_arg = {**optim_arg, **extra_args} + gbp_args = cfg["optim"]["gbp_settings"].copy() + gbp_args["lin_system_damping"] = lin_system_damping + gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] + optim_arg = {**optim_arg, **gbp_args} theseus_inputs = {} for cam in ba.cameras: @@ -298,6 +302,23 @@ def run_inner( # info.state_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points # ) # , msg_history=optimizer.ftov_msgs_history) + """ + Save for nesterov experiments + """ + save_dir = os.getcwd() + "/outputs/nesterov/bal/" + if cfg["optim"]["gbp_settings"]["nesterov"]: + save_dir += "1/" + else: + save_dir += "0/" + os.mkdir(save_dir) + with open(f"{save_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) + np.savetxt(save_dir + "/error_history.txt", info.err_history[0].cpu().numpy()) + + """ + Save for bal sequences + """ + # if cfg["bal_file"] is not None: # save_dir = os.path.join(os.getcwd(), "outputs") # if not os.path.exists(save_dir): @@ -442,8 +463,8 @@ def run_outer(cfg: omegaconf.OmegaConf): cfg = { "seed": 1, "device": "cpu", - "bal_file": None, - # "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", + # "bal_file": None, + "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", "synthetic": { "num_cameras": 10, "num_points": 100, @@ -451,7 +472,8 @@ def run_outer(cfg: omegaconf.OmegaConf): "track_locality": 0.2, }, "optim": { - "max_iters": 200, + "max_iters": 300, + "vectorize": True, "optimizer_cls": "gbp", # "optimizer_cls": "gauss_newton", # "optimizer_cls": "levenberg_marquardt", @@ -459,6 +481,14 @@ def run_outer(cfg: omegaconf.OmegaConf): "regularize": True, "ratio_known_cameras": 0.1, "reg_w": 1e-7, + "gbp_settings": { + "relin_threshold": 1e-8, + "ftov_msg_damping": 0.0, + "dropout": 0.0, + "schedule": "synchronous", + "lin_system_damping": 1.0e-2, + "nesterov": True, + }, }, "outer": { "num_epochs": 15, @@ -471,7 +501,7 @@ def run_outer(cfg: omegaconf.OmegaConf): np.random.seed(cfg["seed"]) random.seed(cfg["seed"]) - # args = setup_layer(cfg) - # run_inner(*args) + args = setup_layer(cfg) + run_inner(*args) - run_outer(cfg) + # run_outer(cfg) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 604f4fe19..308d771ff 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -38,12 +38,41 @@ """ +# as in https://blogs.princeton.edu/imabandit/2013/04/01/acceleratedgradientdescent/ def next_nesterov_params(lam) -> Tuple[float, float]: new_lambda = (1 + np.sqrt(4 * lam * lam + 1)) / 2.0 - new_gamma = (1 - lam) / new_lambda + new_gamma = (lam - 1) / new_lambda return new_lambda, new_gamma +def apply_nesterov( + y_curr: th.Manifold, + y_last: th.Manifold, + nesterov_gamma: float, + normalize_method: bool = True, +) -> th.Manifold: + if normalize_method: + # apply to tensors and then project back to closest group element + nesterov_mean_tensor = ( + 1 + nesterov_gamma + ) * y_curr.tensor - nesterov_gamma * y_last.tensor + nesterov_mean_tensor = y_curr.__class__.normalize(nesterov_mean_tensor) + nesterov_mean = y_curr.__class__(tensor=nesterov_mean_tensor) + + else: + # apply nesterov damping in tanget plane. 
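+            # i.e. v = (1 + gamma) * log(y_curr) - gamma * log(y_last) in the tangent
+            # space at the identity, then retract v back onto the manifold.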
+ # Cannot use new_belief or nesterov_y as the tangent plance, because tangent vector is 0. + # Use identity as tangent plane, may not be best choice as could be far from identity. + tp = y_curr.__class__(dtype=y_curr.dtype) + tp.to(y_curr.device) + tp_mean = (1 + nesterov_gamma) * tp.local(y_curr) - nesterov_gamma * tp.local( + y_last + ) + nesterov_mean = tp.retract(tp_mean) + + return nesterov_mean + + # Same of NonlinearOptimizerParams but without step size @dataclass class GBPOptimizerParams: @@ -593,7 +622,13 @@ def _merge_infos( GBP functions """ - def _pass_var_to_fac_messages_loop(self, update_belief=True): + def _pass_var_to_fac_messages_loop(self, update_belief=True, nesterov_gamma=None): + if nesterov_gamma is not None: + if nesterov_gamma == 0: # only on the first call + self.nesterov_ys = [ + belief.mean[0].copy(new_name="nesterov_y_" + belief.mean[0].name) + for belief in self.beliefs + ] for i, var in enumerate(self.ordering): # Collect all incoming messages in the tangent space at the current belief @@ -640,9 +675,27 @@ def _pass_var_to_fac_messages_loop(self, update_belief=True): tau = torch.matmul(inv_lam_tau, sum_taus.unsqueeze(-1)).squeeze(-1) new_belief = th.retract_gaussian(var, tau, lam_tau) + + # nesterov acceleration + if nesterov_gamma is not None: + nesterov_mean = apply_nesterov( + new_belief.mean[0], + self.nesterov_ys[i], + nesterov_gamma, + normalize_method=False, + ) + # belief mean as calculated by GBP step is the new nesterov y value at this step + self.nesterov_ys[i] = new_belief.mean[0].copy() + # use nesterov mean for new belief + new_belief.update( + mean=[nesterov_mean], precision=new_belief.precision + ) + self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _pass_var_to_fac_messages_vectorized(self, update_belief=True): + def _pass_var_to_fac_messages_vectorized( + self, update_belief=True, nesterov_gamma=None + ): # Each (variable-type, dof) gets mapped to a tuple with: # - the variable that will hold the vectorized data # - all the variables of that type that will be vectorized together @@ -687,6 +740,10 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): lam_tp_acc = lam_tp_acc.to(vectorized_data.device, vectorized_data.dtype) eta_lam.extend([eta_tp_acc, lam_tp_acc]) + if nesterov_gamma is not None: + if nesterov_gamma == 0: # only on the first call + self.nesterov_ys = [info[0].copy() for info in var_info.values()] + # add ftov messages to eta_tp and lam_tp accumulator tensors for factor in self.factors: for i, msg in enumerate(factor.ftov_msgs): @@ -761,6 +818,7 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): msg.update(new_mess.mean, new_mess.precision) # compute the new belief for the vectorized variables + i = 0 for (vectorized_var, _, var_ixs, eta_lam) in var_info.values(): eta_tp_acc = eta_lam[0] lam_tau = eta_lam[1] @@ -774,6 +832,23 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): tau = torch.matmul(inv_lam_tau, eta_tp_acc.unsqueeze(-1)).squeeze(-1) new_belief = th.retract_gaussian(vectorized_var, tau, lam_tau) + + # nesterov acceleration + if nesterov_gamma is not None: + nesterov_mean = apply_nesterov( + new_belief.mean[0], + self.nesterov_ys[i], + nesterov_gamma, + normalize_method=False, + ) + # belief mean as calculated by GBP step is the new nesterov y value at this step + self.nesterov_ys[i] = new_belief.mean[0].copy() + # use nesterov mean for new belief + new_belief.update( + mean=[nesterov_mean], precision=new_belief.precision + ) + i += 1 + # 
update non vectorized beliefs with slices start_idx = 0 for ix in var_ixs: @@ -935,11 +1010,13 @@ def _optimize_loop( t_ftov = time.time() - t1 t1 = time.time() + nest_gamma = None if nesterov: nest_lambda, nest_gamma = next_nesterov_params(nest_lambda) - print("nesterov lambda", nest_lambda) print("nesterov gamma", nest_gamma) - self._pass_var_to_fac_messages(update_belief=True) + self._pass_var_to_fac_messages( + update_belief=True, nesterov_gamma=nest_gamma + ) t_vtof = time.time() - t1 t_vec = 0.0 diff --git a/theseus/optimizer/gbp/plot_ba_err.py b/theseus/optimizer/gbp/plot_ba_err.py index 543555e14..1d5342977 100644 --- a/theseus/optimizer/gbp/plot_ba_err.py +++ b/theseus/optimizer/gbp/plot_ba_err.py @@ -3,29 +3,66 @@ import os -root_dir = "/home/joe/projects/theseus/theseus/optimizer/gbp/outputs" -err_files1 = [ - "gbp_problem-21-11315-pre.txt", - "levenberg_marquardt_problem-21-11315-pre.txt", -] -err_files2 = [ - "gbp_problem-50-20431-pre.txt", - "levenberg_marquardt_problem-50-20431-pre.txt", -] +""" +Nesterov experiments +""" -err_files = err_files1 +def nesterov_plots(): -for err_files in [err_files1, err_files2]: + root_dir = ( + "/home/joe/projects/mpSLAM/theseus/theseus/optimizer/gbp/outputs/nesterov/" + ) + exp_dir = root_dir + "bal/" - gbp_err = np.loadtxt(os.path.join(root_dir, err_files[0])) - lm_err = np.loadtxt(os.path.join(root_dir, err_files[1])) + err_normal = np.loadtxt(exp_dir + "0/error_history.txt") + err_nesterov_normalize = np.loadtxt(exp_dir + "normalize/error_history.txt") + err_nesterov_tp = np.loadtxt(exp_dir + "tangent_plane/error_history.txt") - plt.plot(gbp_err, label="GBP") - plt.plot(lm_err, label="Levenberg Marquardt") - plt.xscale("log") - plt.title(err_files[0][4:]) - plt.xlabel("Iterations") - plt.ylabel("Total Energy") + plt.plot(err_normal, label="Normal GBP") + plt.plot(err_nesterov_normalize, label="Nesterov acceleration - normalize") + plt.plot(err_nesterov_tp, label="Nesterov acceleration - lie algebra") plt.legend() + plt.yscale("log") + plt.xlabel("Iterations") + plt.ylabel("Error") plt.show() + + +""" +Comparing GBP to Levenberg Marquardt +""" + + +def gbp_vs_lm(): + + root_dir = "/home/joe/projects/theseus/theseus/optimizer/gbp/outputs" + err_files1 = [ + "gbp_problem-21-11315-pre.txt", + "levenberg_marquardt_problem-21-11315-pre.txt", + ] + err_files2 = [ + "gbp_problem-50-20431-pre.txt", + "levenberg_marquardt_problem-50-20431-pre.txt", + ] + + err_files = err_files1 + + for err_files in [err_files1, err_files2]: + + gbp_err = np.loadtxt(os.path.join(root_dir, err_files[0])) + lm_err = np.loadtxt(os.path.join(root_dir, err_files[1])) + + plt.plot(gbp_err, label="GBP") + plt.plot(lm_err, label="Levenberg Marquardt") + plt.xscale("log") + plt.title(err_files[0][4:]) + plt.xlabel("Iterations") + plt.ylabel("Total Energy") + plt.legend() + plt.show() + + +if __name__ == "__main__": + + nesterov_plots() From 4e1a4b1ec1dff6536111c2e576cc4d45877599df Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Thu, 8 Sep 2022 12:45:31 +0100 Subject: [PATCH 43/64] swarm exp --- theseus/optimizer/gbp/swarm.py | 317 +++++++++++++++++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 theseus/optimizer/gbp/swarm.py diff --git a/theseus/optimizer/gbp/swarm.py b/theseus/optimizer/gbp/swarm.py new file mode 100644 index 000000000..c8f8d878d --- /dev/null +++ b/theseus/optimizer/gbp/swarm.py @@ -0,0 +1,317 @@ +import numpy as np +import random +import omegaconf +import torch +from typing import Optional, Tuple, List + +import pygame + +import 
theseus as th +from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule + + +OPTIMIZER_CLASS = { + "gbp": GaussianBeliefPropagation, + "gauss_newton": th.GaussNewton, + "levenberg_marquardt": th.LevenbergMarquardt, +} + +OUTER_OPTIMIZER_CLASS = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, +} + +GBP_SCHEDULE = { + "synchronous": GBPSchedule.SYNCHRONOUS, +} + + +class SwarmViewer: + def __init__( + self, + state_history, + agent_radius, + collision_radius, + show_edges=True, + ): + self.state_history = state_history + self.t = 0 + self.num_iters = (~list(state_history.values())[0].isinf()[0, 0]).sum() + + self.agent_cols = None + self.scale = 100 + self.show_edges = show_edges + self.agent_r_pix = agent_radius * self.scale + self.collision_radius = collision_radius + self.range = np.array([[-3, -3], [3, 3]]) + self.h = (self.range[1, 1] - self.range[0, 1]) * self.scale + self.w = (self.range[1, 0] - self.range[0, 0]) * self.scale + + pygame.init() + pygame.display.set_caption("Swarm") + self.myfont = pygame.font.SysFont("Jokerman", 40) + self.screen = pygame.display.set_mode([self.h, self.w]) + + self.draw_next() + + running = True + while running: + + # Did the user click the window close button? + for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + if event.type == pygame.KEYDOWN: + if event.key == pygame.K_SPACE: + self.draw_next() + + def draw_next(self): + if self.agent_cols is None: + self.agent_cols = [ + tuple(np.random.choice(range(256), size=3)) + for i in range(len(self.state_history)) + ] + + if self.t < self.num_iters: + self.screen.fill((255, 255, 255)) + + # draw agents + for i, state in enumerate(self.state_history.values()): + pos = state[0, :, self.t].cpu().numpy() + centre = self.pos_to_canvas(pos) + pygame.draw.circle( + self.screen, self.agent_cols[i], centre, self.agent_r_pix + ) + + # draw edges between agents + if self.show_edges: + for i, state1 in enumerate(self.state_history.values()): + pos1 = state1[0, :, self.t].cpu().numpy() + j = 0 + for state2 in self.state_history.values(): + if j <= i: + j += 1 + continue + pos2 = state2[0, :, self.t].cpu().numpy() + dist = np.linalg.norm(pos1 - pos2) + if dist < self.collision_radius: + start = self.pos_to_canvas(pos1) + end = self.pos_to_canvas(pos2) + pygame.draw.line(self.screen, (0, 0, 0), start, end) + + # draw text + ssshow = self.myfont.render( + f"t = {self.t} / {self.num_iters - 1}", True, (0, 0, 0) + ) + self.screen.blit(ssshow, (10, 10)) # choose location of text + + pygame.display.flip() + + self.t += 1 + + def pos_to_canvas(self, pos): + return ( + (pos - self.range[0]) + / (self.range[1] - self.range[0]) + * np.array([self.h, self.w]) + ) + + +def error_fn(optim_vars, aux_vars): + var1, var2 = optim_vars + radius = aux_vars[0] + return torch.relu( + 1 - torch.norm(var1.tensor - var2.tensor, dim=1, keepdim=True) / radius.tensor + ) + + +class TwoAgentsCollision(th.CostFunction): + def __init__( + self, + weight: th.CostWeight, + var1: th.Point2, + var2: th.Point2, + radius: th.Vector, + name: Optional[str] = None, + ): + super().__init__(weight, name=name) + self.var1 = var1 + self.var2 = var2 + self.radius = radius + # to improve readability, we have skipped the data checks from code block above + self.register_optim_vars(["var1", "var2"]) + self.register_aux_vars(["radius"]) + + # no error when distance exceeds radius + def error(self) -> torch.Tensor: + dist = torch.norm(self.var1.tensor - self.var2.tensor, dim=1, keepdim=True) + return 
torch.relu(1 - dist / self.radius.tensor) + + def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: + dist = torch.norm(self.var1.tensor - self.var2.tensor, dim=1, keepdim=True) + denom = dist * self.radius.tensor + jac = (self.var1.tensor - self.var2.tensor) / denom + jac = jac[:, None, :] + jac[dist > self.radius.tensor] = 0.0 + return [ + -jac, + jac, + ], self.error() + + def dim(self) -> int: + return 1 + + def _copy_impl(self, new_name: Optional[str] = None) -> "TwoAgentsCollision": + return TwoAgentsCollision( + self.weight.copy(), + self.var1.copy(), + self.var2.copy(), + self.radius.copy(), + name=new_name, + ) + + +def setup_problem(cfg: omegaconf.OmegaConf): + dtype = torch.float32 + n_agents = cfg["setup"]["num_agents"] + + # create variables, one per agent + positions = [] + for i in range(n_agents): + init = torch.normal(torch.zeros(2), cfg["setup"]["init_std"]) + position = th.Point2(tensor=init, name=f"agent_{i}") + positions.append(position) + + objective = th.Objective(dtype=dtype) + + # prior factor drawing each robot to the origin + origin = th.Point2() + origin_weight = th.ScaleCostWeight( + torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) + ) + for i in range(n_agents): + origin_cf = th.Difference( + positions[i], + origin, + origin_weight, + name=f"origin_pull_{i}", + ) + objective.add(origin_cf) + + # create collision factors, fully connected + radius = th.Vector(tensor=torch.tensor([cfg["setup"]["collision_radius"]])) + collision_weight = th.ScaleCostWeight( + torch.tensor([cfg["setup"]["collision_weight"]], dtype=dtype) + ) + for i in range(n_agents): + for j in range(i + 1, n_agents): + collision_cf = TwoAgentsCollision( + weight=collision_weight, + var1=positions[i], + var2=positions[j], + radius=radius, + name=f"collision_{i}_{j}", + ) + objective.add(collision_cf) + + return objective + + +def main(cfg: omegaconf.OmegaConf): + + objective = setup_problem(cfg) + + # setup optimizer and theseus layer + vectorize = cfg["optim"]["vectorize"] + optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( + objective, + max_iterations=cfg["optim"]["max_iters"], + vectorize=vectorize, + # linearization_cls=th.SparseLinearization, + # linear_solver_cls=th.LUCudaSparseSolver, + ) + theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) + + if cfg["device"] == "cuda": + cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" + theseus_optim.to(cfg["device"]) + + optim_arg = { + "track_best_solution": False, + "track_err_history": True, + "track_state_history": True, + "verbose": True, + "backward_mode": th.BackwardMode.FULL, + } + if isinstance(optimizer, GaussianBeliefPropagation): + gbp_args = cfg["optim"]["gbp_settings"].copy() + lin_system_damping = torch.nn.Parameter( + torch.tensor( + [cfg["optim"]["gbp_settings"]["lin_system_damping"]], + dtype=torch.float32, + ) + ) + lin_system_damping.to(device=cfg["device"]) + gbp_args["lin_system_damping"] = lin_system_damping + gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] + optim_arg = {**optim_arg, **gbp_args} + + # theseus inputs + theseus_inputs = {} + for agent in objective.optim_vars.values(): + theseus_inputs[agent.name] = agent.tensor.clone() + + # print("initial states\n", theseus_inputs) + + with torch.no_grad(): + theseus_outputs, info = theseus_optim.forward( + input_tensors=theseus_inputs, + optimizer_kwargs=optim_arg, + ) + + # print("final states\n", theseus_outputs) + + # visualisation + # SwarmViewer( + # info.state_history, + # cfg["setup"]["agent_radius"], + 
# cfg["setup"]["collision_radius"], + # show_edges=False, + # ) + + +if __name__ == "__main__": + + cfg = { + "seed": 0, + "device": "cpu", + "setup": { + "num_agents": 100, + "init_std": 1.0, + "agent_radius": 0.1, + "collision_radius": 1.0, + "origin_weight": 0.3, + "collision_weight": 1.0, + }, + "optim": { + "max_iters": 20, + "vectorize": True, + "optimizer_cls": "gbp", + # "optimizer_cls": "gauss_newton", + # "optimizer_cls": "levenberg_marquardt", + "gbp_settings": { + "relin_threshold": 1e-8, + "ftov_msg_damping": 0.0, + "dropout": 0.0, + "schedule": "synchronous", + "lin_system_damping": 1.0e-2, + "nesterov": False, + }, + }, + } + + torch.manual_seed(cfg["seed"]) + np.random.seed(cfg["seed"]) + random.seed(cfg["seed"]) + + main(cfg) From 37f7ed55c0be40fa47aaf23ccaca26d39755c8c8 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 9 Sep 2022 11:16:54 +0100 Subject: [PATCH 44/64] learning target for agents --- theseus/optimizer/gbp/__init__.py | 1 + theseus/optimizer/gbp/swarm.py | 254 +++++++++++++------------- theseus/optimizer/gbp/swarm_viewer.py | 179 ++++++++++++++++++ 3 files changed, 312 insertions(+), 122 deletions(-) create mode 100644 theseus/optimizer/gbp/swarm_viewer.py diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py index 53c57ea55..4e58cbe99 100644 --- a/theseus/optimizer/gbp/__init__.py +++ b/theseus/optimizer/gbp/__init__.py @@ -4,4 +4,5 @@ # LICENSE file in the root directory of this source tree. from .ba_viewer import BAViewer +from .swarm_viewer import SwarmViewer from .gbp import GaussianBeliefPropagation, GBPSchedule diff --git a/theseus/optimizer/gbp/swarm.py b/theseus/optimizer/gbp/swarm.py index c8f8d878d..81842513e 100644 --- a/theseus/optimizer/gbp/swarm.py +++ b/theseus/optimizer/gbp/swarm.py @@ -1,14 +1,16 @@ import numpy as np import random import omegaconf -import torch +import time from typing import Optional, Tuple, List -import pygame +import torch import theseus as th from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule +# from theseus.optimizer.gbp import SwarmViewer + OPTIMIZER_CLASS = { "gbp": GaussianBeliefPropagation, @@ -26,105 +28,34 @@ } -class SwarmViewer: +def fc_block(in_f, out_f): + return torch.nn.Sequential(torch.nn.Linear(in_f, out_f), torch.nn.ReLU()) + + +class TargetMLP(torch.nn.Module): def __init__( self, - state_history, - agent_radius, - collision_radius, - show_edges=True, + input_dim=1, + output_dim=2, + hidden_dim=8, + hidden_layers=0, ): - self.state_history = state_history - self.t = 0 - self.num_iters = (~list(state_history.values())[0].isinf()[0, 0]).sum() - - self.agent_cols = None - self.scale = 100 - self.show_edges = show_edges - self.agent_r_pix = agent_radius * self.scale - self.collision_radius = collision_radius - self.range = np.array([[-3, -3], [3, 3]]) - self.h = (self.range[1, 1] - self.range[0, 1]) * self.scale - self.w = (self.range[1, 0] - self.range[0, 0]) * self.scale - - pygame.init() - pygame.display.set_caption("Swarm") - self.myfont = pygame.font.SysFont("Jokerman", 40) - self.screen = pygame.display.set_mode([self.h, self.w]) - - self.draw_next() - - running = True - while running: - - # Did the user click the window close button? 
- for event in pygame.event.get(): - if event.type == pygame.QUIT: - running = False - if event.type == pygame.KEYDOWN: - if event.key == pygame.K_SPACE: - self.draw_next() - - def draw_next(self): - if self.agent_cols is None: - self.agent_cols = [ - tuple(np.random.choice(range(256), size=3)) - for i in range(len(self.state_history)) - ] - - if self.t < self.num_iters: - self.screen.fill((255, 255, 255)) - - # draw agents - for i, state in enumerate(self.state_history.values()): - pos = state[0, :, self.t].cpu().numpy() - centre = self.pos_to_canvas(pos) - pygame.draw.circle( - self.screen, self.agent_cols[i], centre, self.agent_r_pix - ) - - # draw edges between agents - if self.show_edges: - for i, state1 in enumerate(self.state_history.values()): - pos1 = state1[0, :, self.t].cpu().numpy() - j = 0 - for state2 in self.state_history.values(): - if j <= i: - j += 1 - continue - pos2 = state2[0, :, self.t].cpu().numpy() - dist = np.linalg.norm(pos1 - pos2) - if dist < self.collision_radius: - start = self.pos_to_canvas(pos1) - end = self.pos_to_canvas(pos2) - pygame.draw.line(self.screen, (0, 0, 0), start, end) - - # draw text - ssshow = self.myfont.render( - f"t = {self.t} / {self.num_iters - 1}", True, (0, 0, 0) - ) - self.screen.blit(ssshow, (10, 10)) # choose location of text - - pygame.display.flip() - - self.t += 1 - - def pos_to_canvas(self, pos): - return ( - (pos - self.range[0]) - / (self.range[1] - self.range[0]) - * np.array([self.h, self.w]) - ) - - -def error_fn(optim_vars, aux_vars): - var1, var2 = optim_vars - radius = aux_vars[0] - return torch.relu( - 1 - torch.norm(var1.tensor - var2.tensor, dim=1, keepdim=True) / radius.tensor - ) - - + super(TargetMLP, self).__init__() + # input is agent index + self.relu = torch.nn.ReLU() + self.in_layer = torch.nn.Linear(1, hidden_dim) + hidden = [fc_block(hidden_dim, hidden_dim) for _ in range(hidden_layers)] + self.mid = torch.nn.Sequential(*hidden) + self.out_layer = torch.nn.Linear(hidden_dim, output_dim) + + def forward(self, x): + x = self.relu(self.in_layer(x)) + x = self.mid(x) + out = self.out_layer(x) + return out + + +# custom factor for two agents collision class TwoAgentsCollision(th.CostFunction): def __init__( self, @@ -138,7 +69,7 @@ def __init__( self.var1 = var1 self.var2 = var2 self.radius = radius - # to improve readability, we have skipped the data checks from code block above + # skips data checks self.register_optim_vars(["var1", "var2"]) self.register_aux_vars(["radius"]) @@ -153,10 +84,7 @@ def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: jac = (self.var1.tensor - self.var2.tensor) / denom jac = jac[:, None, :] jac[dist > self.radius.tensor] = 0.0 - return [ - -jac, - jac, - ], self.error() + return [-jac, jac], self.error() def dim(self) -> int: return 1 @@ -171,6 +99,13 @@ def _copy_impl(self, new_name: Optional[str] = None) -> "TwoAgentsCollision": ) +# all agents should be in square of side length 1 centered at the origin +def square_loss_fn(outputs, side_len): + positions = torch.cat(list(outputs.values())) + loss = torch.relu(torch.abs(positions) - side_len / 2) + return loss.sum() + + def setup_problem(cfg: omegaconf.OmegaConf): dtype = torch.float32 n_agents = cfg["setup"]["num_agents"] @@ -185,21 +120,23 @@ def setup_problem(cfg: omegaconf.OmegaConf): objective = th.Objective(dtype=dtype) # prior factor drawing each robot to the origin - origin = th.Point2() - origin_weight = th.ScaleCostWeight( - torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) - ) - for i in 
range(n_agents): - origin_cf = th.Difference( - positions[i], - origin, - origin_weight, - name=f"origin_pull_{i}", - ) - objective.add(origin_cf) + # origin = th.Point2(name="origin") + # origin_weight = th.ScaleCostWeight( + # torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) + # ) + # for i in range(n_agents): + # origin_cf = th.Difference( + # positions[i], + # origin, + # origin_weight, + # name=f"origin_pull_{i}", + # ) + # objective.add(origin_cf) # create collision factors, fully connected - radius = th.Vector(tensor=torch.tensor([cfg["setup"]["collision_radius"]])) + radius = th.Vector( + tensor=torch.tensor([cfg["setup"]["collision_radius"]]), name="radius" + ) collision_weight = th.ScaleCostWeight( torch.tensor([cfg["setup"]["collision_weight"]], dtype=dtype) ) @@ -214,6 +151,23 @@ def setup_problem(cfg: omegaconf.OmegaConf): ) objective.add(collision_cf) + # learned factors, encouraging a square formation + target_weight = th.ScaleCostWeight( + torch.tensor([cfg["setup"]["origin_weight"] * 10], dtype=dtype) + ) + for i in range(n_agents): + target = th.Point2( + tensor=torch.normal(torch.zeros(2), cfg["setup"]["init_std"]), + name=f"target_{i}", + ) + target_cf = th.Difference( + positions[i], + target, + target_weight, + name=f"formation_target_{i}", + ) + objective.add(target_cf) + return objective @@ -261,22 +215,70 @@ def main(cfg: omegaconf.OmegaConf): for agent in objective.optim_vars.values(): theseus_inputs[agent.name] = agent.tensor.clone() - # print("initial states\n", theseus_inputs) + # setup outer optimizer + targets = {} + for name, aux_var in objective.aux_vars.items(): + if "target" in name: + targets[name] = torch.nn.Parameter(aux_var.tensor.clone()) + outer_optimizer = OUTER_OPTIMIZER_CLASS[cfg["outer_optim"]["optimizer"]]( + targets.values(), lr=cfg["outer_optim"]["lr"] + ) + + losses = [] + targets_history = {} + for k, target in targets.items(): + targets_history[k] = target.detach().clone().cpu().unsqueeze(-1) + + for epoch in range(cfg["outer_optim"]["num_epochs"]): + print(f" ******************* EPOCH {epoch} ******************* ") + start_time = time.time_ns() + outer_optimizer.zero_grad() + + for k, target in targets.items(): + theseus_inputs[k] = target.clone() - with torch.no_grad(): theseus_outputs, info = theseus_optim.forward( input_tensors=theseus_inputs, optimizer_kwargs=optim_arg, ) - # print("final states\n", theseus_outputs) + if epoch < cfg["outer"]["num_epochs"] - 1: + loss = square_loss_fn( + theseus_outputs, cfg["outer_optim"]["square_side_len"] + ) + loss.backward() + outer_optimizer.step() + losses.append(loss.detach().item()) + end_time = time.time_ns() + + for k, target in targets.items(): + targets_history[k] = torch.cat( + (targets_history[k], target.detach().clone().cpu().unsqueeze(-1)), + dim=-1, + ) + + print(f"Loss {losses[-1]}") + print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") + + print("Loss values:", losses) # visualisation - # SwarmViewer( - # info.state_history, + # viewer = SwarmViewer( # cfg["setup"]["agent_radius"], # cfg["setup"]["collision_radius"], + # ) + + # viewer.vis_outer_targets_optim( + # targets_history, + # square_side=cfg["outer_optim"]["square_side_len"], + # video_file=cfg["outer_optim_video_file"], + # ) + + # viewer.vis_inner_optim( + # info.state_history, + # targets=targets, # make sure targets are from correct innner optim # show_edges=False, + # video_file=cfg["out_video_file"], # ) @@ -285,8 +287,10 @@ def main(cfg: omegaconf.OmegaConf): cfg = { "seed": 0, "device": 
"cpu", + "out_video_file": "outputs/swarm/inner.gif", + "outer_optim_video_file": "outputs/swarm/outer_targets.gif", "setup": { - "num_agents": 100, + "num_agents": 80, "init_std": 1.0, "agent_radius": 0.1, "collision_radius": 1.0, @@ -308,6 +312,12 @@ def main(cfg: omegaconf.OmegaConf): "nesterov": False, }, }, + "outer_optim": { + "num_epochs": 25, + "lr": 4e-1, + "optimizer": "sgd", + "square_side_len": 2.0, + }, } torch.manual_seed(cfg["seed"]) diff --git a/theseus/optimizer/gbp/swarm_viewer.py b/theseus/optimizer/gbp/swarm_viewer.py new file mode 100644 index 000000000..ff73e6b89 --- /dev/null +++ b/theseus/optimizer/gbp/swarm_viewer.py @@ -0,0 +1,179 @@ +import numpy as np +import shutil +import os +import pygame + + +class SwarmViewer: + def __init__( + self, + agent_radius, + collision_radius, + ): + self.agent_cols = None + self.scale = 100 + self.agent_r_pix = agent_radius * self.scale + self.collision_radius = collision_radius + self.square_side = None + + self.range = np.array([[-3, -3], [3, 3]]) + self.h = (self.range[1, 1] - self.range[0, 1]) * self.scale + self.w = (self.range[1, 0] - self.range[0, 0]) * self.scale + + pygame.init() + pygame.display.set_caption("Swarm") + self.myfont = pygame.font.SysFont("Jokerman", 40) + self.screen = pygame.display.set_mode([self.w, self.h]) + + def vis_inner_optim( + self, + state_history, + targets=None, + show_edges=True, + video_file=None, + ): + self.state_history = state_history + self.t = 0 + self.num_iters = (~list(state_history.values())[0].isinf()[0, 0]).sum() + + self.video_file = video_file + if self.video_file is not None: + self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" + self.save_ix = 0 + os.mkdir(self.tmp_dir) + + self.targets = targets + self.show_edges = show_edges + self.width = 0 + + self.run() + + def vis_outer_targets_optim( + self, + targets_history, + square_side=None, + video_file=None, + ): + self.state_history = targets_history + self.t = 0 + self.num_iters = list(targets_history.values())[0].shape[-1] + + self.video_file = video_file + if self.video_file is not None: + self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" + self.save_ix = 0 + os.mkdir(self.tmp_dir) + + self.targets = None + self.show_edges = False + self.width = 3 + self.square_side = square_side + + self.run() + + def run(self): + self.draw_next() + + running = True + while running: + + # Did the user click the window close button? 
+ for event in pygame.event.get(): + if event.type == pygame.QUIT: + running = False + if event.type == pygame.KEYDOWN: + if event.key == pygame.K_SPACE: + self.draw_next() + + def draw_next(self): + if self.agent_cols is None: + self.agent_cols = [ + tuple(np.random.choice(range(256), size=3)) + for i in range(len(self.state_history)) + ] + + if self.t < self.num_iters: + self.screen.fill((255, 255, 255)) + + # draw agents + for i, state in enumerate(self.state_history.values()): + pos = state[0, :, self.t].cpu().numpy() + centre = self.pos_to_canvas(pos) + pygame.draw.circle( + self.screen, + self.agent_cols[i], + centre, + self.agent_r_pix, + self.width, + ) + + # draw edges between agents + if self.show_edges: + for i, state1 in enumerate(self.state_history.values()): + pos1 = state1[0, :, self.t].cpu().numpy() + j = 0 + for state2 in self.state_history.values(): + if j <= i: + j += 1 + continue + pos2 = state2[0, :, self.t].cpu().numpy() + dist = np.linalg.norm(pos1 - pos2) + if dist < self.collision_radius: + start = self.pos_to_canvas(pos1) + end = self.pos_to_canvas(pos2) + pygame.draw.line(self.screen, (0, 0, 0), start, end) + + # draw targets + if self.targets is not None: + for i, state in enumerate(self.targets.values()): + centre = self.pos_to_canvas(state[0].detach().cpu().numpy()) + pygame.draw.circle( + self.screen, self.agent_cols[i], centre, self.agent_r_pix, 3 + ) + + # draw square + if self.square_side is not None: + side = self.square_side * self.scale + left = (self.w - side) / 2 + top = (self.h - side) / 2 + pygame.draw.rect(self.screen, (0, 100, 255), (left, top, side, side), 3) + + # draw text + ssshow = self.myfont.render( + f"t = {self.t} / {self.num_iters - 1}", True, (0, 0, 0) + ) + self.screen.blit(ssshow, (10, 10)) # choose location of text + + pygame.display.flip() + + if self.video_file: + self.save_image() + + self.t += 1 + + elif self.t == self.num_iters and self.video_file: + if self.video_file[-3:] == "mp4": + os.system( + f"ffmpeg -r 4 -i {self.tmp_dir}/%06d.png -vcodec mpeg4 -y {self.video_file}" + ) + elif self.video_file[-3:] == "gif": + os.system( + f"ffmpeg -i {self.tmp_dir}/%06d.png -vf palettegen {self.tmp_dir}/palette.png" + ) + os.system( + f"ffmpeg -r 4 -i {self.tmp_dir}/%06d.png -i {self.tmp_dir}/palette.png" + " -lavfi paletteuse {self.video_file}" + ) + else: + raise ValueError("video file must be either mp4 or gif.") + shutil.rmtree(self.tmp_dir) + self.t += 1 + + def pos_to_canvas(self, pos): + x = (pos - self.range[0]) / (self.range[1] - self.range[0]) + return x * np.array([self.h, self.w]) + + def save_image(self): + fname = self.tmp_dir + f"/{self.save_ix:06d}.png" + pygame.image.save(self.screen, fname) + self.save_ix += 1 From a6b7274b14d48d6989ec8138d80bc144003fbd3f Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 16 Sep 2022 16:44:06 +0100 Subject: [PATCH 45/64] target character and joint mlp + gbp --- theseus/optimizer/gbp/swarm.py | 372 ++++++++++++++++---------- theseus/optimizer/gbp/swarm_viewer.py | 107 +++++--- 2 files changed, 310 insertions(+), 169 deletions(-) diff --git a/theseus/optimizer/gbp/swarm.py b/theseus/optimizer/gbp/swarm.py index 81842513e..2d83a3ab4 100644 --- a/theseus/optimizer/gbp/swarm.py +++ b/theseus/optimizer/gbp/swarm.py @@ -2,14 +2,15 @@ import random import omegaconf import time +from PIL import Image, ImageDraw, ImageFont, ImageFilter from typing import Optional, Tuple, List import torch +import torch.nn as nn import theseus as th from theseus.optimizer.gbp import GaussianBeliefPropagation, 
GBPSchedule - -# from theseus.optimizer.gbp import SwarmViewer +from theseus.optimizer.gbp import SwarmViewer OPTIMIZER_CLASS = { @@ -28,30 +29,81 @@ } +# create image from a character and font +def gen_char_img( + char, dilate=True, fontname="LiberationSerif-Bold.ttf", size=(200, 200) +): + img = Image.new("L", size, "white") + draw = ImageDraw.Draw(img) + fontsize = int(size[0] * 0.5) + font = ImageFont.truetype(fontname, fontsize) + char_displaysize = font.getsize(char) + offset = tuple((si - sc) // 2 for si, sc in zip(size, char_displaysize)) + draw.text((offset[0], offset[1] * 3 // 4), char, font=font, fill="#000") + + if dilate: + img = img.filter(ImageFilter.MinFilter(3)) + + return img + + +# all agents should be inside object (negative SDF values) +def target_char_loss(outputs, sdf): + positions = torch.cat(list(outputs.values())) + dists = sdf.signed_distance(positions)[0] + loss = torch.relu(dists) + return loss.sum() + + +def gen_target_sdf(cfg): + # setup target shape for outer loop loss fn + area_limits = np.array(cfg["setup"]["area_limits"]) + cell_size = 0.05 + img_size = tuple(np.rint((area_limits[1] - area_limits[0]) / cell_size).astype(int)) + img = gen_char_img( + cfg["outer_optim"]["target_char"], + dilate=True, + fontname="DejaVuSans-Bold.ttf", + size=img_size, + ) + occ_map = torch.Tensor(np.array(img) < 255) + occ_map = torch.flip( + occ_map, [0] + ) # flip vertically so y axis is upwards wrt character + sdf = th.eb.SignedDistanceField2D( + th.Variable(torch.Tensor(area_limits[0][None, :])), + th.Variable(torch.Tensor([cell_size])), + occupancy_map=th.Variable(occ_map[None, :]), + ) + return sdf + + def fc_block(in_f, out_f): - return torch.nn.Sequential(torch.nn.Linear(in_f, out_f), torch.nn.ReLU()) + return nn.Sequential(nn.Linear(in_f, out_f), nn.ReLU()) -class TargetMLP(torch.nn.Module): +class SimpleMLP(nn.Module): def __init__( self, - input_dim=1, + input_dim=2, output_dim=2, hidden_dim=8, hidden_layers=0, + scale_output=1.0, ): - super(TargetMLP, self).__init__() + super(SimpleMLP, self).__init__() # input is agent index - self.relu = torch.nn.ReLU() - self.in_layer = torch.nn.Linear(1, hidden_dim) + self.scale_output = scale_output + self.relu = nn.ReLU() + self.in_layer = nn.Linear(input_dim, hidden_dim) hidden = [fc_block(hidden_dim, hidden_dim) for _ in range(hidden_layers)] - self.mid = torch.nn.Sequential(*hidden) - self.out_layer = torch.nn.Linear(hidden_dim, output_dim) + self.mid = nn.Sequential(*hidden) + self.out_layer = nn.Linear(hidden_dim, output_dim) def forward(self, x): - x = self.relu(self.in_layer(x)) - x = self.mid(x) - out = self.out_layer(x) + y = self.relu(self.in_layer(x)) + y = self.mid(y) + out = self.out_layer(y) * self.scale_output return out @@ -99,14 +151,7 @@ def _copy_impl(self, new_name: Optional[str] = None) -> "TwoAgentsCollision": ) -# all agents should be in square of side length 1 centered at the origin -def square_loss_fn(outputs, side_len): - positions = torch.cat(list(outputs.values())) - loss = torch.relu(torch.abs(positions) - side_len / 2) - return loss.sum() - - -def setup_problem(cfg: omegaconf.OmegaConf): +def setup_problem(cfg: omegaconf.OmegaConf, gnn_err_fn): dtype = torch.float32 n_agents = cfg["setup"]["num_agents"] @@ -120,18 +165,18 @@ def setup_problem(cfg: omegaconf.OmegaConf): objective = th.Objective(dtype=dtype) # prior factor drawing each robot to the origin - # origin = th.Point2(name="origin") - # origin_weight = th.ScaleCostWeight( - # torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) - 
# ) - # for i in range(n_agents): - # origin_cf = th.Difference( - # positions[i], - # origin, - # origin_weight, - # name=f"origin_pull_{i}", - # ) - # objective.add(origin_cf) + origin = th.Point2(name="origin") + origin_weight = th.ScaleCostWeight( + torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) + ) + for i in range(n_agents): + origin_cf = th.Difference( + positions[i], + origin, + origin_weight, + name=f"origin_pull_{i}", + ) + objective.add(origin_cf) # create collision factors, fully connected radius = th.Vector( @@ -152,150 +197,205 @@ def setup_problem(cfg: omegaconf.OmegaConf): objective.add(collision_cf) # learned factors, encouraging a square formation + # target_weight = th.ScaleCostWeight( + # torch.tensor([cfg["setup"]["origin_weight"] * 10], dtype=dtype) + # ) + # for i in range(n_agents): + # target = th.Point2( + # tensor=torch.normal(torch.zeros(2), cfg["setup"]["init_std"]), + # name=f"target_{i}", + # ) + # target_cf = th.Difference( + # positions[i], + # target, + # target_weight, + # name=f"formation_target_{i}", + # ) + # objective.add(target_cf) + + # GNN factor - MLP that takes in all current belief means and outputs all targets target_weight = th.ScaleCostWeight( torch.tensor([cfg["setup"]["origin_weight"] * 10], dtype=dtype) ) - for i in range(n_agents): - target = th.Point2( - tensor=torch.normal(torch.zeros(2), cfg["setup"]["init_std"]), - name=f"target_{i}", - ) - target_cf = th.Difference( - positions[i], - target, - target_weight, - name=f"formation_target_{i}", - ) - objective.add(target_cf) + gnn_cf = th.AutoDiffCostFunction( + optim_vars=positions, + err_fn=gnn_err_fn, + dim=n_agents, + cost_weight=target_weight, + ) + objective.add(gnn_cf) return objective -def main(cfg: omegaconf.OmegaConf): +class SwarmGBPAndGNN(nn.Module): + def __init__(self, cfg): + super().__init__() - objective = setup_problem(cfg) + n_agents = cfg["setup"]["num_agents"] + self.gnn = SimpleMLP( + input_dim=2 * n_agents, + output_dim=2 * n_agents, + hidden_dim=64, + hidden_layers=2, + scale_output=1.0, + ) - # setup optimizer and theseus layer - vectorize = cfg["optim"]["vectorize"] - optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( - objective, - max_iterations=cfg["optim"]["max_iters"], - vectorize=vectorize, - # linearization_cls=th.SparseLinearization, - # linear_solver_cls=th.LUCudaSparseSolver, - ) - theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) - - if cfg["device"] == "cuda": - cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" - theseus_optim.to(cfg["device"]) - - optim_arg = { - "track_best_solution": False, - "track_err_history": True, - "track_state_history": True, - "verbose": True, - "backward_mode": th.BackwardMode.FULL, - } - if isinstance(optimizer, GaussianBeliefPropagation): - gbp_args = cfg["optim"]["gbp_settings"].copy() - lin_system_damping = torch.nn.Parameter( - torch.tensor( + # setup objective, optimizer and theseus layer + objective = setup_problem(cfg, self._gnn_err_fn) + vectorize = cfg["optim"]["vectorize"] + optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( + objective, + max_iterations=cfg["optim"]["max_iters"], + vectorize=vectorize, + ) + self.layer = th.TheseusLayer(optimizer, vectorize=vectorize) + + # put on device + if cfg["device"] == "cuda": + cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" + self.gnn.to(cfg["device"]) + self.layer.to(cfg["device"]) + + # optimizer arguments + optim_arg = { + "track_best_solution": False, + "track_err_history": True, + 
"track_state_history": True, + "verbose": True, + "backward_mode": th.BackwardMode.FULL, + } + if isinstance(optimizer, GaussianBeliefPropagation): + gbp_args = cfg["optim"]["gbp_settings"].copy() + lin_system_damping = torch.tensor( [cfg["optim"]["gbp_settings"]["lin_system_damping"]], dtype=torch.float32, ) + lin_system_damping.to(device=cfg["device"]) + gbp_args["lin_system_damping"] = lin_system_damping + gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] + optim_arg = {**optim_arg, **gbp_args} + self.optim_arg = optim_arg + + # fixed inputs to theseus layer + self.inputs = {} + for agent in objective.optim_vars.values(): + self.inputs[agent.name] = agent.tensor.clone() + + # network outputs offset for target from agent position + # cost is zero when offset is zero, i.e. agent is at the target + def _gnn_err_fn(self, optim_vars: List[th.Manifold], aux_vars): + assert len(aux_vars) == 0 + positions = optim_vars + + batch_size = positions[0].shape[0] + flattened_pos = torch.cat( + [pos.tensor.unsqueeze(1) for pos in positions], dim=1 + ).flatten(1, 2) + + offsets = self.gnn(flattened_pos) + offsets = offsets.reshape(batch_size, len(positions), 2) + err = offsets.norm(dim=-1) + + return err + + def forward(self, track_history=False): + + optim_arg = self.optim_arg.copy() + optim_arg["track_state_history"] = track_history + + outputs, info = self.layer.forward( + input_tensors=self.inputs, + optimizer_kwargs=optim_arg, ) - lin_system_damping.to(device=cfg["device"]) - gbp_args["lin_system_damping"] = lin_system_damping - gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] - optim_arg = {**optim_arg, **gbp_args} - - # theseus inputs - theseus_inputs = {} - for agent in objective.optim_vars.values(): - theseus_inputs[agent.name] = agent.tensor.clone() - - # setup outer optimizer - targets = {} - for name, aux_var in objective.aux_vars.items(): - if "target" in name: - targets[name] = torch.nn.Parameter(aux_var.tensor.clone()) + + history = None + if track_history: + + history = info.state_history + + # recover target history + agent_histories = torch.cat( + [state_hist.unsqueeze(1) for state_hist in history.values()], dim=1 + ) + history["agent_0"] + + batch_size = agent_histories.shape[0] + ts = agent_histories.shape[-1] + agent_histories = agent_histories.permute( + 0, 3, 1, 2 + ) # time dim is second dim + agent_histories = agent_histories.flatten(-2, -1) + target_hist = self.gnn(agent_histories) + target_hist = target_hist.reshape(batch_size, ts, -1, 2) + target_hist = target_hist.permute(0, 2, 3, 1) # time back to last dim + + for i in range(target_hist.shape[1]): + history[f"target_{i}"] = target_hist[:, i] + history[f"agent_{i}"] + + return outputs, history + + +def main(cfg: omegaconf.OmegaConf): + + sdf = gen_target_sdf(cfg) + + model = SwarmGBPAndGNN(cfg) + outer_optimizer = OUTER_OPTIMIZER_CLASS[cfg["outer_optim"]["optimizer"]]( - targets.values(), lr=cfg["outer_optim"]["lr"] + model.gnn.parameters(), lr=cfg["outer_optim"]["lr"] ) - losses = [] - targets_history = {} - for k, target in targets.items(): - targets_history[k] = target.detach().clone().cpu().unsqueeze(-1) + viewer = SwarmViewer(cfg["setup"]["collision_radius"], cfg["setup"]["area_limits"]) + losses = [] for epoch in range(cfg["outer_optim"]["num_epochs"]): print(f" ******************* EPOCH {epoch} ******************* ") start_time = time.time_ns() outer_optimizer.zero_grad() - for k, target in targets.items(): - theseus_inputs[k] = target.clone() + track_history = epoch % 5 == 0 - theseus_outputs, info = 
theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=optim_arg, - ) + outputs, history = model.forward(track_history=track_history) - if epoch < cfg["outer"]["num_epochs"] - 1: - loss = square_loss_fn( - theseus_outputs, cfg["outer_optim"]["square_side_len"] - ) - loss.backward() - outer_optimizer.step() - losses.append(loss.detach().item()) - end_time = time.time_ns() + loss = target_char_loss(outputs, sdf) - for k, target in targets.items(): - targets_history[k] = torch.cat( - (targets_history[k], target.detach().clone().cpu().unsqueeze(-1)), - dim=-1, - ) + loss.backward() + outer_optimizer.step() + losses.append(loss.detach().item()) + end_time = time.time_ns() - print(f"Loss {losses[-1]}") - print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") + print(f"Loss {losses[-1]}") + print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") - print("Loss values:", losses) + if track_history: + viewer.vis_inner_optim(history, target_sdf=sdf, show_edges=False) - # visualisation - # viewer = SwarmViewer( - # cfg["setup"]["agent_radius"], - # cfg["setup"]["collision_radius"], - # ) + print("Loss values:", losses) + # outputs visualisations # viewer.vis_outer_targets_optim( # targets_history, - # square_side=cfg["outer_optim"]["square_side_len"], + # target_sdf=sdf, # video_file=cfg["outer_optim_video_file"], # ) - # viewer.vis_inner_optim( - # info.state_history, - # targets=targets, # make sure targets are from correct innner optim - # show_edges=False, - # video_file=cfg["out_video_file"], - # ) - if __name__ == "__main__": cfg = { "seed": 0, "device": "cpu", - "out_video_file": "outputs/swarm/inner.gif", - "outer_optim_video_file": "outputs/swarm/outer_targets.gif", + "out_video_file": "outputs/swarm/inner_mlp.gif", + "outer_optim_video_file": "outputs/swarm/outer_targets_mlp.gif", "setup": { - "num_agents": 80, + "num_agents": 100, "init_std": 1.0, "agent_radius": 0.1, "collision_radius": 1.0, - "origin_weight": 0.3, + "origin_weight": 0.1, "collision_weight": 1.0, + "area_limits": [[-3, -3], [3, 3]], }, "optim": { "max_iters": 20, @@ -313,10 +413,10 @@ def main(cfg: omegaconf.OmegaConf): }, }, "outer_optim": { - "num_epochs": 25, - "lr": 4e-1, - "optimizer": "sgd", - "square_side_len": 2.0, + "num_epochs": 50, + "lr": 2e-2, + "optimizer": "adam", + "target_char": "A", }, } diff --git a/theseus/optimizer/gbp/swarm_viewer.py b/theseus/optimizer/gbp/swarm_viewer.py index ff73e6b89..ee526913e 100644 --- a/theseus/optimizer/gbp/swarm_viewer.py +++ b/theseus/optimizer/gbp/swarm_viewer.py @@ -1,57 +1,84 @@ import numpy as np import shutil import os +import torch + import pygame class SwarmViewer: def __init__( self, - agent_radius, collision_radius, + area_limits, ): self.agent_cols = None self.scale = 100 - self.agent_r_pix = agent_radius * self.scale + self.agent_r_pix = collision_radius / 20 * self.scale self.collision_radius = collision_radius - self.square_side = None + self.target_sdf = None - self.range = np.array([[-3, -3], [3, 3]]) + self.range = np.array(area_limits) self.h = (self.range[1, 1] - self.range[0, 1]) * self.scale self.w = (self.range[1, 0] - self.range[0, 0]) * self.scale + self.video_file = None + pygame.init() pygame.display.set_caption("Swarm") self.myfont = pygame.font.SysFont("Jokerman", 40) self.screen = pygame.display.set_mode([self.w, self.h]) - def vis_inner_optim( + def vis_target_step( self, - state_history, - targets=None, - show_edges=True, - video_file=None, + targets_history, + target_sdf, ): - self.state_history = 
state_history - self.t = 0 - self.num_iters = (~list(state_history.values())[0].isinf()[0, 0]).sum() + self.state_history = targets_history + self.t = (~list(targets_history.values())[0].isinf()[0, 0]).sum().item() - 1 + self.num_iters = self.t + 1 + self.targets = None + self.show_edges = False + self.width = 3 + self.target_sdf = target_sdf + + self.draw_next() + + def prepare_video(self, video_file): self.video_file = video_file if self.video_file is not None: self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" self.save_ix = 0 + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) os.mkdir(self.tmp_dir) - self.targets = targets + def vis_inner_optim( + self, + history, + target_sdf=None, + show_edges=True, + video_file=None, + ): + self.prepare_video(video_file) + + self.state_history = {k: v for k, v in history.items() if "agent" in k} + self.target_history = {k: v for k, v in history.items() if "target" in k} + + self.t = 0 + self.num_iters = (~list(self.state_history.values())[0].isinf()[0, 0]).sum() + self.show_edges = show_edges self.width = 0 + self.target_sdf = target_sdf self.run() def vis_outer_targets_optim( self, targets_history, - square_side=None, + target_sdf=None, video_file=None, ): self.state_history = targets_history @@ -62,12 +89,14 @@ def vis_outer_targets_optim( if self.video_file is not None: self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" self.save_ix = 0 + if os.path.exists(self.tmp_dir): + shutil.rmtree(self.tmp_dir) os.mkdir(self.tmp_dir) self.targets = None self.show_edges = False self.width = 3 - self.square_side = square_side + self.target_sdf = target_sdf self.run() @@ -95,9 +124,25 @@ def draw_next(self): if self.t < self.num_iters: self.screen.fill((255, 255, 255)) + # draw target shape as background + if self.target_sdf is not None: + sdf = self.target_sdf.sdf_data.tensor[0].transpose(0, 1) + sdf = torch.flip( + sdf, [1] + ) # flip vertically so y is increasing going up + repeats = self.screen.get_width() // sdf.shape[0] + sdf = torch.repeat_interleave(sdf, repeats, dim=0) + sdf = torch.repeat_interleave(sdf, repeats, dim=1) + sdf = sdf.detach().cpu().numpy() + bg_img = np.zeros([*sdf.shape, 3]) + bg_img[sdf > 0] = 255 + bg_img[sdf <= 0] = [144, 238, 144] + bg = pygame.surfarray.make_surface(bg_img) + self.screen.blit(bg, (0, 0)) + # draw agents for i, state in enumerate(self.state_history.values()): - pos = state[0, :, self.t].cpu().numpy() + pos = state[0, :, self.t].detach().cpu().numpy() centre = self.pos_to_canvas(pos) pygame.draw.circle( self.screen, @@ -110,33 +155,28 @@ def draw_next(self): # draw edges between agents if self.show_edges: for i, state1 in enumerate(self.state_history.values()): - pos1 = state1[0, :, self.t].cpu().numpy() + pos1 = state1[0, :, self.t].detach().cpu().numpy() j = 0 for state2 in self.state_history.values(): if j <= i: j += 1 continue - pos2 = state2[0, :, self.t].cpu().numpy() + pos2 = state2[0, :, self.t].detach().cpu().numpy() dist = np.linalg.norm(pos1 - pos2) if dist < self.collision_radius: start = self.pos_to_canvas(pos1) end = self.pos_to_canvas(pos2) pygame.draw.line(self.screen, (0, 0, 0), start, end) - # draw targets - if self.targets is not None: - for i, state in enumerate(self.targets.values()): - centre = self.pos_to_canvas(state[0].detach().cpu().numpy()) - pygame.draw.circle( - self.screen, self.agent_cols[i], centre, self.agent_r_pix, 3 - ) - - # draw square - if self.square_side is not None: - side = self.square_side * self.scale - left = (self.w - side) / 
2 - top = (self.h - side) / 2 - pygame.draw.rect(self.screen, (0, 100, 255), (left, top, side, side), 3) + # draw agents + for i, state in enumerate(self.target_history.values()): + pos = state[0, :, self.t].detach().cpu().numpy() + centre = self.pos_to_canvas(pos) + pygame.draw.circle( + self.screen, self.agent_cols[i], centre, self.agent_r_pix, 3 + ) + + # draw line between agent and target # draw text ssshow = self.myfont.render( @@ -162,7 +202,7 @@ def draw_next(self): ) os.system( f"ffmpeg -r 4 -i {self.tmp_dir}/%06d.png -i {self.tmp_dir}/palette.png" - " -lavfi paletteuse {self.video_file}" + f" -lavfi paletteuse {self.video_file}" ) else: raise ValueError("video file must be either mp4 or gif.") @@ -171,6 +211,7 @@ def draw_next(self): def pos_to_canvas(self, pos): x = (pos - self.range[0]) / (self.range[1] - self.range[0]) + x[1] = 1 - x[1] return x * np.array([self.h, self.w]) def save_image(self): From 89f0c78ff4b61b9f1912c0cf886f6d8ec4220bf5 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Tue, 20 Sep 2022 12:34:20 +0100 Subject: [PATCH 46/64] fixed jacobians for gnn factor --- theseus/optimizer/gbp/swarm.py | 133 +++++++++++++++++--------- theseus/optimizer/gbp/swarm_viewer.py | 6 ++ 2 files changed, 94 insertions(+), 45 deletions(-) diff --git a/theseus/optimizer/gbp/swarm.py b/theseus/optimizer/gbp/swarm.py index 2d83a3ab4..cbf1e5bd0 100644 --- a/theseus/optimizer/gbp/swarm.py +++ b/theseus/optimizer/gbp/swarm.py @@ -3,7 +3,7 @@ import omegaconf import time from PIL import Image, ImageDraw, ImageFont, ImageFilter -from typing import Optional, Tuple, List +from typing import Optional, Tuple, List, Callable import torch import torch.nn as nn @@ -51,15 +51,17 @@ def gen_char_img( def target_char_loss(outputs, sdf): positions = torch.cat(list(outputs.values())) dists = sdf.signed_distance(positions)[0] + if torch.sum(dists == 0).item() != 0: + print("\n\nNumber of agents out of bounds: ", torch.sum(dists == 0).item()) loss = torch.relu(dists) return loss.sum() def gen_target_sdf(cfg): # setup target shape for outer loop loss fn - area_limits = np.array(cfg["setup"]["area_limits"]) + vis_limits = np.array(cfg["setup"]["vis_limits"]) cell_size = 0.05 - img_size = tuple(np.rint((area_limits[1] - area_limits[0]) / cell_size).astype(int)) + img_size = tuple(np.rint((vis_limits[1] - vis_limits[0]) / cell_size).astype(int)) img = gen_char_img( cfg["outer_optim"]["target_char"], dilate=True, @@ -70,10 +72,18 @@ def gen_target_sdf(cfg): occ_map = torch.flip( occ_map, [0] ) # flip vertically so y axis is upwards wrt character + # pad to expand area + area_limits = np.array(cfg["setup"]["area_limits"]) + padded_size = tuple( + np.rint((area_limits[1] - area_limits[0]) / cell_size).astype(int) + ) + pad = int((padded_size[0] - img_size[0]) / 2) + larger_occ_map = torch.zeros(padded_size) + larger_occ_map[pad:-pad, pad:-pad] = occ_map sdf = th.eb.SignedDistanceField2D( th.Variable(torch.Tensor(area_limits[0][None, :])), th.Variable(torch.Tensor([cell_size])), - occupancy_map=th.Variable(occ_map[None, :]), + occupancy_map=th.Variable(larger_occ_map[None, :]), ) return sdf @@ -151,6 +161,59 @@ def _copy_impl(self, new_name: Optional[str] = None) -> "TwoAgentsCollision": ) +# custom factor for GNN +class GNNTargets(th.CostFunction): + def __init__( + self, + weight: th.CostWeight, + agents: List[th.Point2], + gnn_err_fn: Callable, + name: Optional[str] = None, + ): + super().__init__(weight, name=name) + self.agents = agents + self.n_agents = len(agents) + self._gnn_err_fn = gnn_err_fn + # 
skips data checks + for agent in self.agents: + setattr(self, agent.name, agent) + self.register_optim_vars([v.name for v in agents]) + + # no error when distance exceeds radius + def error(self) -> torch.Tensor: + return self._gnn_err_fn(self.agents) + + # Cannot use autodiff for jacobians as we want the factor to be + # independent for each agent. i.e. GNN is implemented as many prior factors + def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: + batch_size = self.agents[0].shape[0] + jacs = torch.zeros( + batch_size, + self.n_agents, + self.dim(), + 2, + dtype=self.agents[0].dtype, + device=self.agents[0].device, + ) + jacs[:, torch.arange(self.n_agents), 2 * torch.arange(self.n_agents), 0] = 1.0 + jacs[ + :, torch.arange(self.n_agents), 2 * torch.arange(self.n_agents) + 1, 1 + ] = 1.0 + jac_list = [jacs[:, i] for i in range(self.n_agents)] + return jac_list, self.error() + + def dim(self) -> int: + return self.n_agents * 2 + + def _copy_impl(self, new_name: Optional[str] = None) -> "GNNTargets": + return GNNTargets( + self.weight.copy(), + [agent.copy() for agent in self.agents], + self._gnn_err_fn, + name=new_name, + ) + + def setup_problem(cfg: omegaconf.OmegaConf, gnn_err_fn): dtype = torch.float32 n_agents = cfg["setup"]["num_agents"] @@ -196,32 +259,15 @@ def setup_problem(cfg: omegaconf.OmegaConf, gnn_err_fn): ) objective.add(collision_cf) - # learned factors, encouraging a square formation - # target_weight = th.ScaleCostWeight( - # torch.tensor([cfg["setup"]["origin_weight"] * 10], dtype=dtype) - # ) - # for i in range(n_agents): - # target = th.Point2( - # tensor=torch.normal(torch.zeros(2), cfg["setup"]["init_std"]), - # name=f"target_{i}", - # ) - # target_cf = th.Difference( - # positions[i], - # target, - # target_weight, - # name=f"formation_target_{i}", - # ) - # objective.add(target_cf) - - # GNN factor - MLP that takes in all current belief means and outputs all targets + # GNN factor - GNN takes in all current belief means and outputs all targets target_weight = th.ScaleCostWeight( - torch.tensor([cfg["setup"]["origin_weight"] * 10], dtype=dtype) + torch.tensor([cfg["setup"]["gnn_target_weight"]], dtype=dtype) ) - gnn_cf = th.AutoDiffCostFunction( - optim_vars=positions, - err_fn=gnn_err_fn, - dim=n_agents, - cost_weight=target_weight, + gnn_cf = GNNTargets( + weight=target_weight, + agents=positions, + gnn_err_fn=gnn_err_fn, + name="gnn_factor", ) objective.add(gnn_cf) @@ -284,20 +330,12 @@ def __init__(self, cfg): # network outputs offset for target from agent position # cost is zero when offset is zero, i.e. 
agent is at the target - def _gnn_err_fn(self, optim_vars: List[th.Manifold], aux_vars): - assert len(aux_vars) == 0 - positions = optim_vars - - batch_size = positions[0].shape[0] + def _gnn_err_fn(self, positions: List[th.Manifold]): flattened_pos = torch.cat( [pos.tensor.unsqueeze(1) for pos in positions], dim=1 ).flatten(1, 2) - offsets = self.gnn(flattened_pos) - offsets = offsets.reshape(batch_size, len(positions), 2) - err = offsets.norm(dim=-1) - - return err + return offsets def forward(self, track_history=False): @@ -331,7 +369,7 @@ def forward(self, track_history=False): target_hist = target_hist.permute(0, 2, 3, 1) # time back to last dim for i in range(target_hist.shape[1]): - history[f"target_{i}"] = target_hist[:, i] + history[f"agent_{i}"] + history[f"target_{i}"] = -target_hist[:, i] + history[f"agent_{i}"] return outputs, history @@ -346,7 +384,7 @@ def main(cfg: omegaconf.OmegaConf): model.gnn.parameters(), lr=cfg["outer_optim"]["lr"] ) - viewer = SwarmViewer(cfg["setup"]["collision_radius"], cfg["setup"]["area_limits"]) + viewer = SwarmViewer(cfg["setup"]["collision_radius"], cfg["setup"]["vis_limits"]) losses = [] for epoch in range(cfg["outer_optim"]["num_epochs"]): @@ -354,8 +392,7 @@ def main(cfg: omegaconf.OmegaConf): start_time = time.time_ns() outer_optimizer.zero_grad() - track_history = epoch % 5 == 0 - + track_history = False # epoch % 20 == 0 outputs, history = model.forward(track_history=track_history) loss = target_char_loss(outputs, sdf) @@ -373,6 +410,10 @@ def main(cfg: omegaconf.OmegaConf): print("Loss values:", losses) + import ipdb + + ipdb.set_trace() + # outputs visualisations # viewer.vis_outer_targets_optim( # targets_history, @@ -389,13 +430,15 @@ def main(cfg: omegaconf.OmegaConf): "out_video_file": "outputs/swarm/inner_mlp.gif", "outer_optim_video_file": "outputs/swarm/outer_targets_mlp.gif", "setup": { - "num_agents": 100, + "num_agents": 50, "init_std": 1.0, "agent_radius": 0.1, "collision_radius": 1.0, "origin_weight": 0.1, "collision_weight": 1.0, - "area_limits": [[-3, -3], [3, 3]], + "gnn_target_weight": 10.0, + "area_limits": [[-20, -20], [20, 20]], + "vis_limits": [[-3, -3], [3, 3]], }, "optim": { "max_iters": 20, @@ -413,7 +456,7 @@ def main(cfg: omegaconf.OmegaConf): }, }, "outer_optim": { - "num_epochs": 50, + "num_epochs": 100, "lr": 2e-2, "optimizer": "adam", "target_char": "A", diff --git a/theseus/optimizer/gbp/swarm_viewer.py b/theseus/optimizer/gbp/swarm_viewer.py index ee526913e..f98893810 100644 --- a/theseus/optimizer/gbp/swarm_viewer.py +++ b/theseus/optimizer/gbp/swarm_viewer.py @@ -127,6 +127,12 @@ def draw_next(self): # draw target shape as background if self.target_sdf is not None: sdf = self.target_sdf.sdf_data.tensor[0].transpose(0, 1) + sdf_size = self.target_sdf.cell_size.tensor.item() * sdf.shape[0] + area_size = self.range[1, 0] - self.range[0, 0] + crop = np.round((1 - area_size / sdf_size) * sdf.shape[0] / 2).astype( + int + ) + sdf = sdf[crop:-crop, crop:-crop] sdf = torch.flip( sdf, [1] ) # flip vertically so y is increasing going up From 8a75767a43b00ce5fb69bd46b17d0b775c5e116c Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Mon, 26 Sep 2022 11:29:27 +0100 Subject: [PATCH 47/64] implicit backward mode for GBP using GN step --- theseus/optimizer/gbp/backward_analysis.py | 160 ++++++++++++++ theseus/optimizer/gbp/bundle_adjustment.py | 229 ++++++++++++++++----- theseus/optimizer/gbp/gbp.py | 64 ++++-- theseus/optimizer/gbp/pgo_test.py | 5 + 4 files changed, 386 insertions(+), 72 deletions(-) create mode 100644 
theseus/optimizer/gbp/backward_analysis.py diff --git a/theseus/optimizer/gbp/backward_analysis.py b/theseus/optimizer/gbp/backward_analysis.py new file mode 100644 index 000000000..3c6c2d7de --- /dev/null +++ b/theseus/optimizer/gbp/backward_analysis.py @@ -0,0 +1,160 @@ +import numpy as np +import os +import json + +import matplotlib.pylab as plt + + +def plot_timing_memory(root): + dirs = os.listdir(root) + dirs.remove("figs") + + timings = {} + memory = {} + + for direc in dirs: + + with open(os.path.join(root, direc, "timings.txt"), "r") as f: + timings[direc] = json.load(f) + with open(os.path.join(root, direc, "memory.txt"), "r") as f: + memory[direc] = json.load(f) + + fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(15, 3)) + fig.subplots_adjust(hspace=0.0, wspace=0.3) + + exps = ["full", "implicit", "truncated_5", "truncated_10"] + labels = ["Unroll", "Implicit", "Trunc-5", "Trunc-10"] + colors = ["C0", "C1", "C2", "C3"] + markers = [".", "v", "o", "s"] + inner_iters = [25, 50, 100, 150, 200, 500] + + for i, exp in enumerate(exps): + + fwd_times = [] + bwd_times = [] + fwd_memory = [] + bwd_memory = [] + for iters in inner_iters: + key = f"{str(iters)}_{exp}" + fwd_times.append(np.mean(timings[key]["fwd"]) / 1e3) + bwd_times.append(np.mean(timings[key]["bwd"]) / 1e3) + fwd_memory.append(np.mean(memory[key]["fwd"])) + bwd_memory.append(np.mean(memory[key]["bwd"])) + + col = colors[i] + m = markers[i] + ax[0].plot(inner_iters, fwd_times, color=col, marker=m, label=labels[i]) + ax[1].plot(inner_iters, bwd_times, color=col, marker=m) + ax[2].plot(inner_iters, fwd_memory, color=col, marker=m) + ax[3].plot(inner_iters, bwd_memory, color=col, marker=m) + + title_fontsize = 11 + ax[0].title.set_text("Forward time") + ax[1].title.set_text("Backward time") + ax[2].title.set_text("Forward memory") + ax[3].title.set_text("Backward memory") + ax[0].title.set_size(title_fontsize) + ax[1].title.set_size(title_fontsize) + ax[2].title.set_size(title_fontsize) + ax[3].title.set_size(title_fontsize) + + ax[0].set_xlabel("Inner loop iterations") + ax[1].set_xlabel("Inner loop iterations") + ax[2].set_xlabel("Inner loop iterations") + ax[3].set_xlabel("Inner loop iterations") + ax[0].set_ylabel("Time (seconds)") + ax[1].set_ylabel("Time (seconds)") + ax[2].set_ylabel("Memory (MBs)") + ax[3].set_ylabel("Memory (MBs)") + + ax[0].legend( + loc="lower center", + bbox_to_anchor=(2.5, -0.5), + fancybox=True, + ncol=4, + fontsize=10, + ) + + # plt.tight_layout() + plt.subplots_adjust(bottom=0.3) + plt.show() + + +def plot_loss_traj(root, ref_loss=None): + + exps = ["full", "implicit", "truncated_5", "truncated_10"] + labels = ["Unroll", "Implicit", "Trunc-5", "Trunc-10"] + colors = ["C0", "C1", "C2", "C3"] + inner_iters = [25, 50, 100, 150, 200, 500] + + fig_loss, ax_loss = plt.subplots(nrows=1, ncols=len(inner_iters), figsize=(20, 3)) + fig_loss.subplots_adjust(hspace=0.0, wspace=0.5) + + fig_traj, ax_traj = plt.subplots(nrows=len(inner_iters), ncols=4, figsize=(20, 15)) + fig_traj.subplots_adjust(hspace=0.75, wspace=0.4) + + for i, iters in enumerate(inner_iters): + + for j, exp in enumerate(exps): + direc = f"{str(iters)}_{exp}" + + # plot sweep curves + if j == 0: + sweep_radii = np.loadtxt(os.path.join(root, direc, "sweep_radius.txt")) + sweep_loss = np.loadtxt(os.path.join(root, direc, "sweep_loss.txt")) + for k in range(len(exps)): + ax_traj[i, k].plot(sweep_radii, sweep_loss) + ax_traj[i, k].title.set_text(labels[k]) + + # plot trajectory over epochs + loss_traj = np.loadtxt(os.path.join(root, 
direc, "optim_loss.txt")) + radius_traj = np.loadtxt(os.path.join(root, direc, "optim_radius.txt")) + ax_traj[i, j].scatter( + radius_traj, + loss_traj, + c=range(len(loss_traj)), + cmap=plt.get_cmap("viridis"), + ) + + # plot loss over epochs + label = labels[j] if i == 0 else None + if ref_loss is not None: + loss_traj = np.array(loss_traj) * ref_loss + ref_loss + ax_loss[i].plot(loss_traj, color=colors[j], marker=None, label=label) + + ax_loss[0].legend( + loc="lower center", + bbox_to_anchor=(3.5, -0.5), + fancybox=True, + ncol=4, + fontsize=10, + ) + + title_fontsize = 11 + for i, iters in enumerate(inner_iters): + ax_loss[i].title.set_text(f"Train loss ({iters} inner GBP steps)") + ax_loss[i].title.set_size(title_fontsize) + ax_loss[i].set_xlabel("Epoch") + ax_loss[i].set_ylabel("Camera Loss") + + for j in range(4): + ax_traj[i, j].set_xlabel("Huber loss radius") + if j == 0: + ax_traj[i, j].set_ylabel(f"{iters} inner steps\n\n\nCamera Loss") + else: + ax_traj[i, j].set_ylabel("Camera Loss") + + fig_loss.subplots_adjust(bottom=0.3) + plt.show() + + +if __name__ == "__main__": + + root = ( + "/home/joe/projects/mpSLAM/theseus/theseus/optimizer/gbp/" + + "outputs/loss_radius_exp/backward_analysis/" + ) + + plot_timing_memory(root) + + plot_loss_traj(root, ref_loss=None) # 49.87 diff --git a/theseus/optimizer/gbp/bundle_adjustment.py b/theseus/optimizer/gbp/bundle_adjustment.py index e9a4cb666..8af1ad83c 100644 --- a/theseus/optimizer/gbp/bundle_adjustment.py +++ b/theseus/optimizer/gbp/bundle_adjustment.py @@ -37,14 +37,53 @@ "synchronous": GBPSchedule.SYNCHRONOUS, } +BACKWARD_MODE = { + "full": th.BackwardMode.FULL, + "implicit": th.BackwardMode.IMPLICIT, + "truncated": th.BackwardMode.TRUNCATED, + "dlm": th.BackwardMode.DLM, +} + + +def start_timing(): + if torch.cuda.is_available(): + torch.cuda.synchronize() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + else: + start = time.perf_counter() + end = None + return start, end + + +def end_timing(start, end): + if torch.cuda.is_available(): + torch.cuda.synchronize() + end.record() + # Waits for everything to finish running + torch.cuda.synchronize() + elapsed_time = start.elapsed_time(end) + else: + end = time.perf_counter() + elapsed_time = end - start + # Convert to milliseconds to have the same units + # as torch.cuda.Event.elapsed_time + elapsed_time = elapsed_time * 1000 + return elapsed_time -def save_res_loss_rad(save_dir, cfg, sweep_radii, sweep_losses, radius_vals, losses): + +def save_res_loss_rad( + save_dir, cfg, radius_vals, losses, sweep_radii=None, sweep_losses=None +): with open(f"{save_dir}/config.txt", "w") as f: json.dump(cfg, f, indent=4) # sweep values - np.savetxt(f"{save_dir}/sweep_radius.txt", sweep_radii) - np.savetxt(f"{save_dir}/sweep_loss.txt", sweep_losses) + if sweep_radii is not None: + np.savetxt(f"{save_dir}/sweep_radius.txt", sweep_radii) + if sweep_losses is not None: + np.savetxt(f"{save_dir}/sweep_loss.txt", sweep_losses) # optim trajectory np.savetxt(f"{save_dir}/optim_radius.txt", radius_vals) @@ -241,7 +280,8 @@ def setup_layer(cfg: omegaconf.OmegaConf): "track_err_history": True, "track_state_history": cfg["optim"]["track_state_history"], "verbose": True, - "backward_mode": th.BackwardMode.FULL, + "backward_mode": BACKWARD_MODE[cfg["optim"]["backward_mode"]], + "backward_num_iterations": cfg["optim"]["backward_num_iterations"], } if isinstance(optimizer, GaussianBeliefPropagation): gbp_args = cfg["optim"]["gbp_settings"].copy() @@ 
-305,7 +345,7 @@ def run_inner( """ Save for nesterov experiments """ - save_dir = os.getcwd() + "/outputs/nesterov/bal/" + save_dir = os.getcwd() + "/outputs/nesterov/synthetic_large/" if cfg["optim"]["gbp_settings"]["nesterov"]: save_dir += "1/" else: @@ -362,7 +402,13 @@ def run_inner( # np.savetxt(save_file, np.array(ares)) -def run_outer(cfg: omegaconf.OmegaConf): +def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): + + torch.manual_seed(cfg["seed"]) + np.random.seed(cfg["seed"]) + random.seed(cfg["seed"]) + + print(f"\nRunning experiment. Save directory: {out_dir}\n") ( theseus_optim, @@ -387,55 +433,80 @@ def run_outer(cfg: omegaconf.OmegaConf): print(f"CAMERA LOSS (no learning): {camera_loss_ref: .3f}") print_histogram(ba, theseus_inputs, "Input histogram:") - import matplotlib.pylab as plt - - sweep_radii = torch.linspace(0.01, 5.0, 20) - sweep_losses = [] - with torch.set_grad_enabled(False): - for r in sweep_radii: - theseus_inputs["log_loss_radius"][0] = r - - print(theseus_inputs["log_loss_radius"]) - - theseus_outputs, info = theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=optim_arg, - ) - cam_loss = camera_loss(ba, camera_pose_vars) - loss = (cam_loss - camera_loss_ref) / camera_loss_ref - sweep_losses.append(torch.sum(loss.detach()).item()) + # import matplotlib.pylab as plt + sweep_radii, sweep_losses = None, None + if do_sweep: + sweep_radii = torch.linspace(0.01, 5.0, 20, dtype=torch.float64) + sweep_losses = [] + sweep_arg = optim_arg.copy() + sweep_arg["verbose"] = False + with torch.set_grad_enabled(False): + for radius in sweep_radii: + radius = radius.to(cfg["device"]) + theseus_inputs["log_loss_radius"] = radius.unsqueeze(0).unsqueeze(0) + + theseus_outputs, info = theseus_optim.forward( + input_tensors=theseus_inputs, + optimizer_kwargs=sweep_arg, + ) + cam_loss = camera_loss(ba, camera_pose_vars) + loss = (cam_loss - camera_loss_ref) / camera_loss_ref + sweep_losses.append(torch.sum(loss.detach()).item()) + print( + f"SWEEP radius {radius}, camera loss {cam_loss.detach().item():.3f}," + f" loss {sweep_losses[-1]:.3f}, ref loss {camera_loss_ref:.3f}" + ) - plt.plot(sweep_radii, sweep_losses) - plt.xlabel("Log loss radius") - plt.ylabel("(Camera loss - reference loss) / reference loss") + # plt.plot(sweep_radii, sweep_losses) + # plt.xlabel("Log loss radius") + # plt.ylabel("(Camera loss - reference loss) / reference loss") + # plt.show() losses = [] radius_vals = [] - theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() + theseus_inputs["log_loss_radius"] = ( + loss_radius_tensor.unsqueeze(1).clone().to(cfg["device"]) + ) + + times: Dict = {"fwd": [], "bwd": []} + memory: Dict = {"fwd": [], "bwd": []} for epoch in range(cfg["outer"]["num_epochs"]): print(f" ******************* EPOCH {epoch} ******************* ") start_time = time.time_ns() model_optimizer.zero_grad() - theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() + theseus_inputs["log_loss_radius"] = ( + loss_radius_tensor.unsqueeze(1).clone().to(cfg["device"]) + ) + start, end = start_timing() + torch.cuda.reset_peak_memory_stats() theseus_outputs, info = theseus_optim.forward( input_tensors=theseus_inputs, optimizer_kwargs=optim_arg, ) + times["fwd"].append(end_timing(start, end)) + memory["fwd"].append(torch.cuda.max_memory_allocated() / 1048576) cam_loss = camera_loss(ba, camera_pose_vars) loss = (cam_loss - camera_loss_ref) / camera_loss_ref + + start, end = start_timing() + torch.cuda.reset_peak_memory_stats() 
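# ---------------------------------------------------------------------------
# Hedged sketch: the timing / peak-memory measurement pattern used in this loop.
# A standalone helper (illustrative name, plain PyTorch) wrapping a callable
# with the same recipe as above: CUDA events plus synchronization for wall
# time, reset/read of the peak-allocated-memory counter, and a perf_counter
# fallback on CPU.
import time
import torch


def measure(fn, *args, **kwargs):
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        t0 = time.perf_counter()
    out = fn(*args, **kwargs)
    if use_cuda:
        end.record()
        torch.cuda.synchronize()
        elapsed_ms = start.elapsed_time(end)
        peak_mb = torch.cuda.max_memory_allocated() / 1048576
    else:
        elapsed_ms = (time.perf_counter() - t0) * 1000
        peak_mb = float("nan")  # no peak-memory counter on CPU
    return out, elapsed_ms, peak_mb
# ---------------------------------------------------------------------------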
loss.backward() + times["bwd"].append(end_timing(start, end)) + memory["bwd"].append(torch.cuda.max_memory_allocated() / 1048576) radius_vals.append(loss_radius_tensor.data.item()) - print(loss_radius_tensor.grad) + # correct for implicit gradients step size != 1 + if cfg["optim"]["backward_mode"] == "implicit": + loss_radius_tensor.grad /= theseus_optim.optimizer.implicit_step_size model_optimizer.step() loss_value = torch.sum(loss.detach()).item() losses.append(loss_value) end_time = time.time_ns() # print_histogram(ba, theseus_outputs, "Output histogram:") - print(f"camera loss {cam_loss} and ref loss {camera_loss_ref}") + print(f"camera loss {cam_loss.detach().item()} and ref loss {camera_loss_ref}") print( f"Epoch: {epoch} Loss: {loss_value} " # f"Lin system damping {lin_system_damping}" @@ -447,24 +518,54 @@ def run_outer(cfg: omegaconf.OmegaConf): print("Loss values:", losses) now = datetime.now() - time_str = now.strftime("%m-%d-%y_%H-%M-%S") - save_dir = os.getcwd() + "/outputs/loss_radius_exp/" + time_str + if out_dir is None: + out_dir = now.strftime("%m-%d-%y_%H-%M-%S") + save_dir = os.getcwd() + "/outputs/loss_radius_exp/" + out_dir os.mkdir(save_dir) + with open(f"{save_dir}/config.txt", "w") as f: + json.dump(cfg, f, indent=4) - save_res_loss_rad(save_dir, cfg, sweep_radii, sweep_losses, radius_vals, losses) + print("\n=== Runtimes") + k1, k2 = "fwd", "bwd" + print(f"Forward: {np.mean(times[k1]):.2e} s +/- {np.std(times[k1]):.2e} s") + print(f"Backward (FULL): {np.mean(times[k2]):.2e} s +/- {np.std(times[k2]):.2e} s") - plt.scatter(radius_vals, losses, c=range(len(losses)), cmap=plt.get_cmap("viridis")) - plt.title(cfg["optim"]["optimizer_cls"] + " - " + time_str) - plt.show() + print("\n=== Memory") + k1, k2 = "fwd", "bwd" + print(f"Forward: {np.mean(memory[k1]):.2e} MB +/- {np.std(memory[k1]):.2e} MB") + print( + f"Backward (FULL): {np.mean(memory[k2]):.2e} MB +/- {np.std(memory[k2]):.2e} MB" + ) + + with open(f"{save_dir}/timings.txt", "w") as f: + json.dump(times, f, indent=4) + with open(f"{save_dir}/memory.txt", "w") as f: + json.dump(memory, f, indent=4) + + with open(f"{save_dir}/ref_loss.txt", "w") as f: + f.write(f"{camera_loss_ref:.5f}") + + save_res_loss_rad( + save_dir, + cfg, + radius_vals, + losses, + sweep_radii=sweep_radii, + sweep_losses=sweep_losses, + ) + + # plt.scatter(radius_vals, losses, c=range(len(losses)), cmap=plt.get_cmap("viridis")) + # plt.title(cfg["optim"]["optimizer_cls"] + " - " + dir_name) + # plt.show() if __name__ == "__main__": - cfg = { + cfg: Dict = { "seed": 1, - "device": "cpu", - # "bal_file": None, - "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", + "device": "cuda", + "bal_file": None, + # "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", "synthetic": { "num_cameras": 10, "num_points": 100, @@ -472,11 +573,13 @@ def run_outer(cfg: omegaconf.OmegaConf): "track_locality": 0.2, }, "optim": { - "max_iters": 300, + "max_iters": 100, "vectorize": True, "optimizer_cls": "gbp", # "optimizer_cls": "gauss_newton", # "optimizer_cls": "levenberg_marquardt", + "backward_mode": "implicit", + "backward_num_iterations": 10, "track_state_history": True, "regularize": True, "ratio_known_cameras": 0.1, @@ -487,21 +590,43 @@ def run_outer(cfg: omegaconf.OmegaConf): "dropout": 0.0, "schedule": "synchronous", "lin_system_damping": 1.0e-2, - "nesterov": True, + "nesterov": False, }, }, "outer": { - "num_epochs": 15, - "lr": 1e2, # 5.0e-1, + "num_epochs": 20, + "lr": 5.0e1, # 5.0e-1, "optimizer": "sgd", }, } - torch.manual_seed(cfg["seed"]) - 
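# ---------------------------------------------------------------------------
# Hedged sketch: why the implicit-mode gradient is divided by the step size.
# A 1-D toy (closed-form inner problem, illustrative values) of the trick used
# above: solve the inner problem without gradients, take one small damped
# solver step of size alpha with gradients enabled, backpropagate the outer
# loss, then divide the parameter gradient by alpha to recover the implicit
# gradient.
import torch

theta = torch.tensor(1.0, requires_grad=True)  # outer parameter

# inner problem: x* = argmin_x 0.5 * (x - theta)^2, solved without gradients
with torch.no_grad():
    x_star = theta.clone()  # exact fixed point here; GBP iterations in general

# one small gradient/Newton step at the (detached) fixed point, with grad enabled
alpha = 1e-4
x_star = x_star.detach()
x_out = x_star - alpha * (x_star - theta)  # inner gradient is (x - theta), Hessian is 1

loss = (x_out - 3.0) ** 2  # outer loss on the inner solution
loss.backward()
implicit_grad = theta.grad / alpha  # undo the alpha scaling, as done above
print(implicit_grad.item())  # ~ -4.0, the true dL/dtheta = 2 * (theta - 3)
# ---------------------------------------------------------------------------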
np.random.seed(cfg["seed"]) - random.seed(cfg["seed"]) - - args = setup_layer(cfg) - run_inner(*args) - - # run_outer(cfg) + # args = setup_layer(cfg) + # run_inner(*args) + + # run_outer(cfg, "implicit_test", do_sweep=False) + + for max_iters in [25, 50, 100, 150, 200, 500]: + for backward_mode in ["implicit"]: + cfg_copy = cfg.copy() + cfg_copy["optim"]["max_iters"] = max_iters + cfg_copy["optim"]["backward_mode"] = backward_mode + + dir_name = str(max_iters) + "_" + backward_mode + + if backward_mode == "truncated": + for backward_num_iterations in [5, 10]: + cfg_copy["optim"][ + "backward_num_iterations" + ] = backward_num_iterations + dir_name = ( + str(max_iters) + + "_" + + backward_mode + + "_" + + str(cfg["optim"]["backward_num_iterations"]) + ) + + run_outer(cfg_copy, dir_name) + else: + do_sweep = True if backward_mode == "full" else False + run_outer(cfg_copy, dir_name, do_sweep=do_sweep) diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index 308d771ff..bf0f2bae3 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp/gbp.py @@ -185,7 +185,7 @@ def __init__( self.batch_size, device=device, dtype=torch.int ) - self.lm_damping = lin_system_damping.repeat(self.batch_size) + self.lm_damping = lin_system_damping.repeat(self.batch_size).to(device) self.min_damping = torch.Tensor([1e-4]).to(dtype=dtype, device=device) self.max_damping = torch.Tensor([1e2]).to(dtype=dtype, device=device) self.last_err: torch.Tensor = None @@ -263,7 +263,7 @@ def linearize( # damp precision matrix damped_D = self.lm_damping[:, None, None] * torch.eye( lam.shape[1], device=lam.device, dtype=lam.dtype - ).unsqueeze(0).repeat(self.batch_size, 1, 1) + ).unsqueeze(0).repeat(self.batch_size, 1, 1).to(self.lm_damping.device) lam = lam + damped_D self.potential_eta[do_lin] = eta[do_lin] @@ -1071,6 +1071,7 @@ def _optimize_impl( schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), nesterov: bool = False, + implicit_step_size: float = 1e-4, **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): @@ -1107,6 +1108,11 @@ def _optimize_impl( f"GBP optimizer. Iteration: 0. 
" f"Error: {info.last_err.mean().item()}" ) + assert backward_mode in [ + BackwardMode.FULL, + BackwardMode.IMPLICIT, + BackwardMode.TRUNCATED, + ] if backward_mode == BackwardMode.FULL: self._optimize_loop( num_iter=self.params.max_iterations, @@ -1130,7 +1136,7 @@ def _optimize_impl( elif backward_mode in [BackwardMode.IMPLICIT, BackwardMode.TRUNCATED]: if backward_mode == BackwardMode.IMPLICIT: - backward_num_iterations = 1 + backward_num_iterations = 0 else: if "backward_num_iterations" not in kwargs: raise ValueError( @@ -1167,23 +1173,41 @@ def _optimize_impl( grad_loop_info = self._init_info( track_best_solution, track_err_history, track_state_history ) - grad_iters_done = self._optimize_loop( - num_iter=backward_num_iterations, - info=grad_loop_info, - verbose=verbose, - truncated_grad_loop=True, - relin_threshold=relin_threshold, - ftov_msg_damping=ftov_msg_damping, - dropout=dropout, - schedule=schedule, - lin_system_damping=lin_system_damping, - nesterov=nesterov, - clear_messages=False, - **kwargs, - ) - - # Adds grad_loop_info results to original info - self._merge_infos(grad_loop_info, no_grad_iters_done, grad_iters_done, info) + if backward_mode == BackwardMode.TRUNCATED: + grad_iters_done = self._optimize_loop( + num_iter=backward_num_iterations, + info=grad_loop_info, + verbose=verbose, + truncated_grad_loop=True, + relin_threshold=relin_threshold, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + nesterov=nesterov, + clear_messages=False, + **kwargs, + ) + # Adds grad_loop_info results to original info + self._merge_infos( + grad_loop_info, no_grad_iters_done, grad_iters_done, info + ) + else: + # use Gauss-Newton update to compute implicit gradient + self.implicit_step_size = implicit_step_size + gauss_newton_optimizer = th.GaussNewton(self.objective) + gauss_newton_optimizer.linear_solver.linearization.linearize() + delta = gauss_newton_optimizer.linear_solver.solve() + self.objective.retract_optim_vars( + delta * implicit_step_size, + gauss_newton_optimizer.linear_solver.linearization.ordering, + force_update=True, + ) + if verbose: + err = self.objective.error_squared_norm() / 2 + print( + f"Nonlinear optimizer implcit step. 
Error: {err.mean().item()}" + ) return info else: diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index 920df7dd9..4f8fee43b 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -172,6 +172,7 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): "damping": 0.0, "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, + "implicit_step_size": 1e-5, } outputs_gbp, info = theseus_optim.forward(inputs, optim_arg) @@ -179,6 +180,8 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): out_gbp_tensor = torch.cat(list(outputs_gbp.values())) loss = torch.norm(gt_poses_tensor - out_gbp_tensor) loss.backward() + if backward_mode == th.BackwardMode.IMPLICIT: + meas_std_tensor.grad /= optimizer.implicit_step_size print("loss", loss.item()) print("grad", meas_std_tensor.grad.item()) @@ -191,3 +194,5 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): gbp_solve_pgo(backward_mode=th.BackwardMode.FULL, max_iterations=20) gbp_solve_pgo(backward_mode=th.BackwardMode.TRUNCATED, max_iterations=20) + +gbp_solve_pgo(backward_mode=th.BackwardMode.IMPLICIT, max_iterations=20) From 75c315714a623842e2be94172744b7f8eaf04659 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 30 Sep 2022 09:00:16 +0100 Subject: [PATCH 48/64] implicit derivatives using gbp and plot backward modes against time --- theseus/optimizer/gbp/backward_analysis.py | 60 ++++++++++++++++------ theseus/optimizer/gbp/bundle_adjustment.py | 12 +++-- theseus/optimizer/gbp/gbp.py | 41 +++++++++++++-- theseus/optimizer/gbp/pgo_test.py | 16 ++++-- 4 files changed, 102 insertions(+), 27 deletions(-) diff --git a/theseus/optimizer/gbp/backward_analysis.py b/theseus/optimizer/gbp/backward_analysis.py index 3c6c2d7de..6d16049a6 100644 --- a/theseus/optimizer/gbp/backward_analysis.py +++ b/theseus/optimizer/gbp/backward_analysis.py @@ -85,11 +85,16 @@ def plot_loss_traj(root, ref_loss=None): exps = ["full", "implicit", "truncated_5", "truncated_10"] labels = ["Unroll", "Implicit", "Trunc-5", "Trunc-10"] colors = ["C0", "C1", "C2", "C3"] - inner_iters = [25, 50, 100, 150, 200, 500] + inner_iters = [150, 200, 500] # [25, 50, 100, 150, 200, 500] fig_loss, ax_loss = plt.subplots(nrows=1, ncols=len(inner_iters), figsize=(20, 3)) fig_loss.subplots_adjust(hspace=0.0, wspace=0.5) + fig_loss_t, ax_loss_t = plt.subplots( + nrows=1, ncols=len(inner_iters), figsize=(20, 3) + ) + fig_loss_t.subplots_adjust(hspace=0.0, wspace=0.5) + fig_traj, ax_traj = plt.subplots(nrows=len(inner_iters), ncols=4, figsize=(20, 15)) fig_traj.subplots_adjust(hspace=0.75, wspace=0.4) @@ -116,26 +121,50 @@ def plot_loss_traj(root, ref_loss=None): cmap=plt.get_cmap("viridis"), ) - # plot loss over epochs + # plot loss over epochs or over total time label = labels[j] if i == 0 else None if ref_loss is not None: loss_traj = np.array(loss_traj) * ref_loss + ref_loss + with open(os.path.join(root, direc, "timings.txt"), "r") as f: + timings = json.load(f) + step_times = [ + timings["fwd"][i] + timings["bwd"][i] + for i in range(len(timings["fwd"])) + ] + step_times = np.array(step_times) / 1000 + cum_times = np.cumsum(step_times) ax_loss[i].plot(loss_traj, color=colors[j], marker=None, label=label) + ax_loss_t[i].plot( + cum_times, loss_traj, color=colors[j], marker=None, label=label + ) - ax_loss[0].legend( - loc="lower center", - bbox_to_anchor=(3.5, -0.5), - fancybox=True, - ncol=4, - fontsize=10, - ) + for ax in [ax_loss, ax_loss_t]: + ax[0].legend( + loc="lower center", + bbox_to_anchor=(2.0, -0.5), # (3.5, -0.5) + 
fancybox=True, + ncol=4, + fontsize=10, + ) + + # align y axes + ymin, ymax = 1e10, -1e10 + for ax in ax_loss: + ymin = min(ymin, ax.get_ylim()[0]) + ymax = max(ymax, ax.get_ylim()[1]) + for ax in ax_loss: + ax.set_ylim((ymin, ymax)) + for ax in ax_loss_t: + ax.set_ylim((ymin, ymax)) title_fontsize = 11 for i, iters in enumerate(inner_iters): - ax_loss[i].title.set_text(f"Train loss ({iters} inner GBP steps)") - ax_loss[i].title.set_size(title_fontsize) + for ax in [ax_loss, ax_loss_t]: + ax[i].title.set_text(f"Train loss ({iters} inner GBP steps)") + ax[i].title.set_size(title_fontsize) + ax[i].set_ylabel("Camera Loss") ax_loss[i].set_xlabel("Epoch") - ax_loss[i].set_ylabel("Camera Loss") + ax_loss_t[i].set_xlabel("Time (seconds)") for j in range(4): ax_traj[i, j].set_xlabel("Huber loss radius") @@ -145,16 +174,17 @@ def plot_loss_traj(root, ref_loss=None): ax_traj[i, j].set_ylabel("Camera Loss") fig_loss.subplots_adjust(bottom=0.3) + fig_loss_t.subplots_adjust(bottom=0.3) plt.show() if __name__ == "__main__": root = ( - "/home/joe/projects/mpSLAM/theseus/theseus/optimizer/gbp/" + "/home/joe/projects/theseus/theseus/optimizer/gbp/" + "outputs/loss_radius_exp/backward_analysis/" ) - plot_timing_memory(root) + # plot_timing_memory(root) - plot_loss_traj(root, ref_loss=None) # 49.87 + plot_loss_traj(root, ref_loss=None) diff --git a/theseus/optimizer/gbp/bundle_adjustment.py b/theseus/optimizer/gbp/bundle_adjustment.py index 8af1ad83c..48e3d4367 100644 --- a/theseus/optimizer/gbp/bundle_adjustment.py +++ b/theseus/optimizer/gbp/bundle_adjustment.py @@ -172,7 +172,8 @@ def load_problem(cfg: omegaconf.OmegaConf): def setup_layer(cfg: omegaconf.OmegaConf): ba = load_problem(cfg) - print("Optimizer:", cfg["optim"]["optimizer_cls"], "\n") + print("Optimizer:", cfg["optim"]["optimizer_cls"]) + print("Backward mode:", cfg["optim"]["backward_mode"], "\n") # param that control transition from squared loss to huber radius_tensor = torch.tensor([1.0], dtype=torch.float64) @@ -499,7 +500,9 @@ def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): radius_vals.append(loss_radius_tensor.data.item()) # correct for implicit gradients step size != 1 if cfg["optim"]["backward_mode"] == "implicit": - loss_radius_tensor.grad /= theseus_optim.optimizer.implicit_step_size + if theseus_optim.optimizer.implicit_method == "gauss_newton": + loss_radius_tensor.grad /= theseus_optim.optimizer.implicit_step_size + print("\ngrad is: ", loss_radius_tensor.grad, "\n") model_optimizer.step() loss_value = torch.sum(loss.detach()).item() losses.append(loss_value) @@ -573,7 +576,7 @@ def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): "track_locality": 0.2, }, "optim": { - "max_iters": 100, + "max_iters": 200, "vectorize": True, "optimizer_cls": "gbp", # "optimizer_cls": "gauss_newton", @@ -591,6 +594,7 @@ def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): "schedule": "synchronous", "lin_system_damping": 1.0e-2, "nesterov": False, + "implicit_method": "gbp", }, }, "outer": { @@ -605,7 +609,7 @@ def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): # run_outer(cfg, "implicit_test", do_sweep=False) - for max_iters in [25, 50, 100, 150, 200, 500]: + for max_iters in [150, 200, 500]: # 25, 50, 100, for backward_mode in ["implicit"]: cfg_copy = cfg.copy() cfg_copy["optim"]["max_iters"] = max_iters diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp/gbp.py index bf0f2bae3..f08bd846d 100644 --- a/theseus/optimizer/gbp/gbp.py +++ 
b/theseus/optimizer/gbp/gbp.py @@ -235,7 +235,6 @@ def linearize( if torch.sum(do_lin) > 0: # if any factor in the batch needs relinearization J, error = self.cf.weighted_jacobians_error() - J_stk = torch.cat(J, dim=-1) lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) @@ -936,6 +935,8 @@ def _create_factors_beliefs(self, lin_system_damping): "We require at least one unary cost function to act as a prior." "This is because Gaussian Belief Propagation is performing Bayesian inference." ) + if self.objective.vectorized: + self.objective.update_vectorization_if_needed() self._linearize_factors() self.n_individual_factors = ( @@ -963,17 +964,21 @@ def _optimize_loop( lin_system_damping: torch.Tensor, nesterov: bool, clear_messages: bool = True, + implicit_gbp_loop: bool = False, **kwargs, ): # we only create the factors and beliefs right before runnig GBP as they are # not automatically updated when objective.update is called. if clear_messages: self._create_factors_beliefs(lin_system_damping) + if implicit_gbp_loop: + relin_threshold = 1e10 # no relinearisation + if self.objective.vectorized: + self.objective.update_vectorization_if_needed() + self._linearize_factors() if schedule == GBPSchedule.SYNCHRONOUS: - ftov_schedule = synchronous_schedule( - self.params.max_iterations, self.n_edges - ) + ftov_schedule = synchronous_schedule(num_iter, self.n_edges) if nesterov: nest_lambda, nest_gamma = next_nesterov_params(0.0) @@ -1072,6 +1077,7 @@ def _optimize_impl( lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), nesterov: bool = False, implicit_step_size: float = 1e-4, + implicit_method: str = "gbp", **kwargs, ) -> NonlinearOptimizerInfo: with torch.no_grad(): @@ -1136,6 +1142,13 @@ def _optimize_impl( elif backward_mode in [BackwardMode.IMPLICIT, BackwardMode.TRUNCATED]: if backward_mode == BackwardMode.IMPLICIT: + self.implicit_method = implicit_method + implicit_methods = ["gauss_newton", "gbp"] + if implicit_method not in implicit_methods: + raise ValueError( + f"implicit_method must be one of {implicit_methods}, " + f"but got {implicit_method}" + ) backward_num_iterations = 0 else: if "backward_num_iterations" not in kwargs: @@ -1192,7 +1205,7 @@ def _optimize_impl( self._merge_infos( grad_loop_info, no_grad_iters_done, grad_iters_done, info ) - else: + elif implicit_method == "gauss_newton": # use Gauss-Newton update to compute implicit gradient self.implicit_step_size = implicit_step_size gauss_newton_optimizer = th.GaussNewton(self.objective) @@ -1208,6 +1221,24 @@ def _optimize_impl( print( f"Nonlinear optimizer implcit step. 
Error: {err.mean().item()}" ) + elif implicit_method == "gbp": + # solve normal equation in a distributed way + max_lin_solve_iters = 1000 + grad_iters_done = self._optimize_loop( + num_iter=max_lin_solve_iters, + info=grad_loop_info, + verbose=verbose, + truncated_grad_loop=True, + relin_threshold=1e10, + ftov_msg_damping=ftov_msg_damping, + dropout=dropout, + schedule=schedule, + lin_system_damping=lin_system_damping, + nesterov=nesterov, + clear_messages=False, + implicit_gbp_loop=True, + **kwargs, + ) return info else: diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py index 4f8fee43b..88cab8c2a 100644 --- a/theseus/optimizer/gbp/pgo_test.py +++ b/theseus/optimizer/gbp/pgo_test.py @@ -142,7 +142,7 @@ def linear_solve_pgo(): print("outputs\n", outputs_linsolve) -def gbp_solve_pgo(backward_mode, max_iterations=20): +def gbp_solve_pgo(backward_mode, max_iterations=20, implicit_method="gbp"): print("\n\nWith GBP...") print("backward mode:", backward_mode, "\n") @@ -173,15 +173,18 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): "dropout": 0.0, "schedule": GBPSchedule.SYNCHRONOUS, "implicit_step_size": 1e-5, + "implicit_method": implicit_method, } outputs_gbp, info = theseus_optim.forward(inputs, optim_arg) out_gbp_tensor = torch.cat(list(outputs_gbp.values())) loss = torch.norm(gt_poses_tensor - out_gbp_tensor) + loss.backward() if backward_mode == th.BackwardMode.IMPLICIT: - meas_std_tensor.grad /= optimizer.implicit_step_size + if optimizer.implicit_method == "gauss_newton": + meas_std_tensor.grad /= optimizer.implicit_step_size print("loss", loss.item()) print("grad", meas_std_tensor.grad.item()) @@ -195,4 +198,11 @@ def gbp_solve_pgo(backward_mode, max_iterations=20): gbp_solve_pgo(backward_mode=th.BackwardMode.TRUNCATED, max_iterations=20) -gbp_solve_pgo(backward_mode=th.BackwardMode.IMPLICIT, max_iterations=20) +gbp_solve_pgo( + backward_mode=th.BackwardMode.IMPLICIT, max_iterations=20, implicit_method="gbp" +) +gbp_solve_pgo( + backward_mode=th.BackwardMode.IMPLICIT, + max_iterations=20, + implicit_method="gauss_newton", +) From 53e4b910e01941d62ef152c2d6f0728475f650c8 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 30 Dec 2022 18:42:08 +0000 Subject: [PATCH 49/64] moved into optimizer, removed experiments --- theseus/optimizer/__init__.py | 1 + theseus/optimizer/{gbp => }/gbp.py | 8 +- theseus/optimizer/gbp/__init__.py | 8 - theseus/optimizer/gbp/ba_viewer.py | 209 ---- theseus/optimizer/gbp/backward_analysis.py | 190 ---- theseus/optimizer/gbp/bundle_adjustment.py | 629 ------------ theseus/optimizer/gbp/gbp_baseline.py | 775 --------------- theseus/optimizer/gbp/gbp_euclidean.py | 1047 -------------------- theseus/optimizer/gbp/jax_torch_poc.py | 488 --------- theseus/optimizer/gbp/pgo_test.py | 208 ---- theseus/optimizer/gbp/plot_ba_err.py | 68 -- theseus/optimizer/gbp/swarm.py | 470 --------- theseus/optimizer/gbp/swarm_viewer.py | 226 ----- theseus/optimizer/gbp/vectorize_poc.py | 119 --- 14 files changed, 5 insertions(+), 4441 deletions(-) rename theseus/optimizer/{gbp => }/gbp.py (99%) delete mode 100644 theseus/optimizer/gbp/__init__.py delete mode 100644 theseus/optimizer/gbp/ba_viewer.py delete mode 100644 theseus/optimizer/gbp/backward_analysis.py delete mode 100644 theseus/optimizer/gbp/bundle_adjustment.py delete mode 100644 theseus/optimizer/gbp/gbp_baseline.py delete mode 100644 theseus/optimizer/gbp/gbp_euclidean.py delete mode 100644 theseus/optimizer/gbp/jax_torch_poc.py delete mode 100644 
theseus/optimizer/gbp/pgo_test.py delete mode 100644 theseus/optimizer/gbp/plot_ba_err.py delete mode 100644 theseus/optimizer/gbp/swarm.py delete mode 100644 theseus/optimizer/gbp/swarm_viewer.py delete mode 100644 theseus/optimizer/gbp/vectorize_poc.py diff --git a/theseus/optimizer/__init__.py b/theseus/optimizer/__init__.py index 9860df0f3..34461a7d0 100644 --- a/theseus/optimizer/__init__.py +++ b/theseus/optimizer/__init__.py @@ -9,3 +9,4 @@ from .optimizer import Optimizer, OptimizerInfo from .sparse_linearization import SparseLinearization from .variable_ordering import VariableOrdering +from .gbp import GaussianBeliefPropagation, GBPSchedule diff --git a/theseus/optimizer/gbp/gbp.py b/theseus/optimizer/gbp.py similarity index 99% rename from theseus/optimizer/gbp/gbp.py rename to theseus/optimizer/gbp.py index 09198a8f4..2da9509de 100644 --- a/theseus/optimizer/gbp/gbp.py +++ b/theseus/optimizer/gbp.py @@ -19,7 +19,7 @@ import theseus.constants from theseus.core import CostFunction, Objective from theseus.geometry import Manifold -from theseus.optimizer import Optimizer, VariableOrdering +from theseus.optimizer import Optimizer, VariableOrdering, ManifoldGaussian from theseus.optimizer.nonlinear.nonlinear_optimizer import ( BackwardMode, NonlinearOptimizerInfo, @@ -106,7 +106,7 @@ def synchronous_schedule(max_iters, n_edges) -> torch.Tensor: # Initialises message precision to zero -class Message(th.ManifoldGaussian): +class Message(ManifoldGaussian): def __init__( self, mean: Sequence[Manifold], @@ -885,9 +885,9 @@ def _pass_fac_to_var_messages( def _create_factors_beliefs(self, lin_system_damping): self.factors: List[Factor] = [] - self.beliefs: List[th.ManifoldGaussian] = [] + self.beliefs: List[ManifoldGaussian] = [] for var in self.ordering: - self.beliefs.append(th.ManifoldGaussian([var])) + self.beliefs.append(ManifoldGaussian([var])) if self.objective.vectorized: cf_iterator = iter(self.objective.vectorized_cost_fns) diff --git a/theseus/optimizer/gbp/__init__.py b/theseus/optimizer/gbp/__init__.py deleted file mode 100644 index 4e58cbe99..000000000 --- a/theseus/optimizer/gbp/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
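
With this commit, GaussianBeliefPropagation and GBPSchedule are re-exported from theseus.optimizer (see the __init__.py hunk above). A minimal usage sketch, not part of the patch, assuming an already-built th.Objective named `objective` and a matching input dict `theseus_inputs`, mirroring how setup_layer wires things up:

    import torch
    import theseus as th
    from theseus.optimizer import GaussianBeliefPropagation, GBPSchedule

    # Build the GBP optimizer and wrap it in a TheseusLayer, as setup_layer does.
    optimizer = GaussianBeliefPropagation(objective, max_iterations=100, vectorize=True)
    layer = th.TheseusLayer(optimizer, vectorize=True)

    # Forward pass; the kwargs follow the optim_arg dict used in the experiments.
    outputs, info = layer.forward(
        input_tensors=theseus_inputs,
        optimizer_kwargs={
            "verbose": True,
            "backward_mode": th.BackwardMode.IMPLICIT,
            "implicit_method": "gbp",
            "schedule": GBPSchedule.SYNCHRONOUS,
            "lin_system_damping": torch.tensor([1e-2], dtype=torch.float64),
        },
    )
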
- -from .ba_viewer import BAViewer -from .swarm_viewer import SwarmViewer -from .gbp import GaussianBeliefPropagation, GBPSchedule diff --git a/theseus/optimizer/gbp/ba_viewer.py b/theseus/optimizer/gbp/ba_viewer.py deleted file mode 100644 index 04686c1a2..000000000 --- a/theseus/optimizer/gbp/ba_viewer.py +++ /dev/null @@ -1,209 +0,0 @@ -import threading - -import numpy as np -import pyglet -import torch -import trimesh -import trimesh.viewer - - -def draw_camera( - transform, fov, resolution, color=(0.0, 1.0, 0.0, 0.8), marker_height=12.0 -): - camera = trimesh.scene.Camera(fov=fov, resolution=resolution) - marker = trimesh.creation.camera_marker(camera, marker_height=marker_height) - marker[0].apply_transform(transform) - marker[1].apply_transform(transform) - marker[1].colors = (color,) * len(marker[1].entities) - - return marker - - -class BAViewer(trimesh.viewer.SceneViewer): - def __init__( - self, - state_history, - msg_history=None, - cam_to_world=False, - flip_z=True, - gt_cameras=None, - gt_points=None, - ): - self._it = 0 - self.state_history = state_history - self.msg_history = msg_history - self.cam_to_world = cam_to_world - self.flip_z = flip_z - self.lock = threading.Lock() - - self.num_iters = (~list(state_history.values())[0].isinf()[0, 0, 0]).sum() - - pts = [] - for k, state in state_history.items(): - if "Pt" in k: - pts.append(state[:, :, 0]) - extents = torch.cat(pts).max(dim=0)[0] - torch.cat(pts).min(dim=0)[0] - self.marker_height = extents.max().item() / 50 - - scene = trimesh.Scene() - self.scene = scene - - if gt_cameras is not None: - for i, cam in enumerate(gt_cameras): - camera = self.make_cam(cam.pose.tensor[0].cpu()) - self.scene.add_geometry(camera[1], geom_name=f"gt_cam_{i}") - - if gt_points is not None: - pts = torch.cat([pt.tensor.cpu() for pt in gt_points]) - pc = trimesh.PointCloud(pts, [0, 255, 0, 200]) - self.scene.add_geometry(pc, geom_name="gt_points") - - self.next_iteration() - scene.set_camera() - super(BAViewer, self).__init__(scene=scene, resolution=(1080, 720)) - - def on_key_press(self, symbol, modifiers): - """ - Call appropriate functions given key presses. - """ - magnitude = 10 - if symbol == pyglet.window.key.W: - self.toggle_wireframe() - elif symbol == pyglet.window.key.Z: - self.reset_view() - elif symbol == pyglet.window.key.C: - self.toggle_culling() - elif symbol == pyglet.window.key.A: - self.toggle_axis() - elif symbol == pyglet.window.key.G: - self.toggle_grid() - elif symbol == pyglet.window.key.Q: - self.on_close() - elif symbol == pyglet.window.key.M: - self.maximize() - elif symbol == pyglet.window.key.F: - self.toggle_fullscreen() - elif symbol == pyglet.window.key.P: - print(self.scene.camera_transform) - elif symbol == pyglet.window.key.N: - if self._it + 1 < self.num_iters: - self._it += 1 - print("Iteration", self._it) - self.next_iteration() - else: - print("No more iterations to view") - - if symbol in [ - pyglet.window.key.LEFT, - pyglet.window.key.RIGHT, - pyglet.window.key.DOWN, - pyglet.window.key.UP, - ]: - self.view["ball"].down([0, 0]) - if symbol == pyglet.window.key.LEFT: - self.view["ball"].drag([-magnitude, 0]) - elif symbol == pyglet.window.key.RIGHT: - self.view["ball"].drag([magnitude, 0]) - elif symbol == pyglet.window.key.DOWN: - self.view["ball"].drag([0, -magnitude]) - elif symbol == pyglet.window.key.UP: - self.view["ball"].drag([0, magnitude]) - self.scene.camera_transform[...] 
= self.view["ball"].pose - - def make_cam(self, pose, color=(0.0, 1.0, 0.0, 0.8)): - T = torch.vstack( - ( - pose, - torch.tensor([[0.0, 0.0, 0.0, 1.0]], dtype=pose.dtype), - ) - ) - if not self.cam_to_world: - T = np.linalg.inv(T) - if self.flip_z: - T[:3, 2] *= -1.0 - camera = draw_camera( - T, - self.scene.camera.fov, - self.scene.camera.resolution, - color=color, - marker_height=self.marker_height, - ) - return camera - - def next_iteration(self): - with self.lock: - points = [] - n_cams, n_pts = 0, 0 - for state in self.state_history.values(): - state = state[..., self._it].cpu() - if state.ndim == 3: - camera = self.make_cam(state[0], color=(0.0, 0.0, 1.0, 0.8)) - self.scene.delete_geometry(f"cam_{n_cams}") - self.scene.add_geometry(camera[1], geom_name=f"cam_{n_cams}") - n_cams += 1 - elif state.shape[1] == 3: - points.append(state) - - # cov = torch.linalg.inv(belief.precision[0]) - # ellipse = make_ellipse(point[0], cov) - # ellipse.visual.vertex_colors[:] = [255, 0, 0, 100] - - # self.scene.delete_geometry(f"ellipse_{n_pts}") - # self.scene.add_geometry(ellipse, geom_name=f"ellipse_{n_pts}") - n_pts += 1 - - points = torch.cat(points) - points_tm = trimesh.PointCloud(points) - self.scene.delete_geometry("points") - self.scene.add_geometry(points_tm, geom_name="points") - - if self.msg_history: - for msg in self.msg_history[self._it]: - if msg.precision.count_nonzero() != 0: - if msg.mean[0].dof() == 3 and "Reprojection" in msg.name: - ellipse = make_ellipse( - msg.mean[0][0], torch.linalg.inv(msg.precision[0]) - ) - if f"ellipse_{msg.name}" in self.scene.geometry: - self.scene.delete_geometry(f"ellipse_{msg.name}") - self.scene.add_geometry( - ellipse, geom_name=f"ellipse_{msg.name}" - ) - - if self._it != 0: - self._update_vertex_list() - - -def make_ellipse(mean, cov, do_lines=False, color=None): - # eigvals_torch, eigvecs_torch = torch.linalg.eigh(cov) - eigvals, eigvecs = np.linalg.eigh(cov) # eigenvecs are columns - # print("eigvals", eigvals) # , eigvals_torch.numpy()) - eigvals = eigvals / 10 - signs = np.sign(eigvals) - eigvals = np.clip(np.abs(eigvals), 1.0, 100, eigvals) * signs - - if do_lines: - points = [] - for i, eigvalue in enumerate(eigvals): - disp = eigvalue * eigvecs[:, i] - points.extend([mean + disp, mean - disp]) - - paths = torch.cat(points).reshape(3, 2, 3) - lines = trimesh.load_path(paths) - - return lines - - else: - rotation = np.eye(4) - rotation[:3, :3] = eigvecs - - ellipse = trimesh.creation.icosphere() - ellipse.apply_scale(eigvals) - ellipse.apply_transform(rotation) - ellipse.apply_translation(mean) - if color is None: - color = trimesh.visual.random_color() - ellipse.visual.vertex_colors = color - ellipse.visual.vertex_colors[:, 3] = 100 - - return ellipse diff --git a/theseus/optimizer/gbp/backward_analysis.py b/theseus/optimizer/gbp/backward_analysis.py deleted file mode 100644 index 6d16049a6..000000000 --- a/theseus/optimizer/gbp/backward_analysis.py +++ /dev/null @@ -1,190 +0,0 @@ -import numpy as np -import os -import json - -import matplotlib.pylab as plt - - -def plot_timing_memory(root): - dirs = os.listdir(root) - dirs.remove("figs") - - timings = {} - memory = {} - - for direc in dirs: - - with open(os.path.join(root, direc, "timings.txt"), "r") as f: - timings[direc] = json.load(f) - with open(os.path.join(root, direc, "memory.txt"), "r") as f: - memory[direc] = json.load(f) - - fig, ax = plt.subplots(nrows=1, ncols=4, figsize=(15, 3)) - fig.subplots_adjust(hspace=0.0, wspace=0.3) - - exps = ["full", "implicit", "truncated_5", 
"truncated_10"] - labels = ["Unroll", "Implicit", "Trunc-5", "Trunc-10"] - colors = ["C0", "C1", "C2", "C3"] - markers = [".", "v", "o", "s"] - inner_iters = [25, 50, 100, 150, 200, 500] - - for i, exp in enumerate(exps): - - fwd_times = [] - bwd_times = [] - fwd_memory = [] - bwd_memory = [] - for iters in inner_iters: - key = f"{str(iters)}_{exp}" - fwd_times.append(np.mean(timings[key]["fwd"]) / 1e3) - bwd_times.append(np.mean(timings[key]["bwd"]) / 1e3) - fwd_memory.append(np.mean(memory[key]["fwd"])) - bwd_memory.append(np.mean(memory[key]["bwd"])) - - col = colors[i] - m = markers[i] - ax[0].plot(inner_iters, fwd_times, color=col, marker=m, label=labels[i]) - ax[1].plot(inner_iters, bwd_times, color=col, marker=m) - ax[2].plot(inner_iters, fwd_memory, color=col, marker=m) - ax[3].plot(inner_iters, bwd_memory, color=col, marker=m) - - title_fontsize = 11 - ax[0].title.set_text("Forward time") - ax[1].title.set_text("Backward time") - ax[2].title.set_text("Forward memory") - ax[3].title.set_text("Backward memory") - ax[0].title.set_size(title_fontsize) - ax[1].title.set_size(title_fontsize) - ax[2].title.set_size(title_fontsize) - ax[3].title.set_size(title_fontsize) - - ax[0].set_xlabel("Inner loop iterations") - ax[1].set_xlabel("Inner loop iterations") - ax[2].set_xlabel("Inner loop iterations") - ax[3].set_xlabel("Inner loop iterations") - ax[0].set_ylabel("Time (seconds)") - ax[1].set_ylabel("Time (seconds)") - ax[2].set_ylabel("Memory (MBs)") - ax[3].set_ylabel("Memory (MBs)") - - ax[0].legend( - loc="lower center", - bbox_to_anchor=(2.5, -0.5), - fancybox=True, - ncol=4, - fontsize=10, - ) - - # plt.tight_layout() - plt.subplots_adjust(bottom=0.3) - plt.show() - - -def plot_loss_traj(root, ref_loss=None): - - exps = ["full", "implicit", "truncated_5", "truncated_10"] - labels = ["Unroll", "Implicit", "Trunc-5", "Trunc-10"] - colors = ["C0", "C1", "C2", "C3"] - inner_iters = [150, 200, 500] # [25, 50, 100, 150, 200, 500] - - fig_loss, ax_loss = plt.subplots(nrows=1, ncols=len(inner_iters), figsize=(20, 3)) - fig_loss.subplots_adjust(hspace=0.0, wspace=0.5) - - fig_loss_t, ax_loss_t = plt.subplots( - nrows=1, ncols=len(inner_iters), figsize=(20, 3) - ) - fig_loss_t.subplots_adjust(hspace=0.0, wspace=0.5) - - fig_traj, ax_traj = plt.subplots(nrows=len(inner_iters), ncols=4, figsize=(20, 15)) - fig_traj.subplots_adjust(hspace=0.75, wspace=0.4) - - for i, iters in enumerate(inner_iters): - - for j, exp in enumerate(exps): - direc = f"{str(iters)}_{exp}" - - # plot sweep curves - if j == 0: - sweep_radii = np.loadtxt(os.path.join(root, direc, "sweep_radius.txt")) - sweep_loss = np.loadtxt(os.path.join(root, direc, "sweep_loss.txt")) - for k in range(len(exps)): - ax_traj[i, k].plot(sweep_radii, sweep_loss) - ax_traj[i, k].title.set_text(labels[k]) - - # plot trajectory over epochs - loss_traj = np.loadtxt(os.path.join(root, direc, "optim_loss.txt")) - radius_traj = np.loadtxt(os.path.join(root, direc, "optim_radius.txt")) - ax_traj[i, j].scatter( - radius_traj, - loss_traj, - c=range(len(loss_traj)), - cmap=plt.get_cmap("viridis"), - ) - - # plot loss over epochs or over total time - label = labels[j] if i == 0 else None - if ref_loss is not None: - loss_traj = np.array(loss_traj) * ref_loss + ref_loss - with open(os.path.join(root, direc, "timings.txt"), "r") as f: - timings = json.load(f) - step_times = [ - timings["fwd"][i] + timings["bwd"][i] - for i in range(len(timings["fwd"])) - ] - step_times = np.array(step_times) / 1000 - cum_times = np.cumsum(step_times) - 
ax_loss[i].plot(loss_traj, color=colors[j], marker=None, label=label) - ax_loss_t[i].plot( - cum_times, loss_traj, color=colors[j], marker=None, label=label - ) - - for ax in [ax_loss, ax_loss_t]: - ax[0].legend( - loc="lower center", - bbox_to_anchor=(2.0, -0.5), # (3.5, -0.5) - fancybox=True, - ncol=4, - fontsize=10, - ) - - # align y axes - ymin, ymax = 1e10, -1e10 - for ax in ax_loss: - ymin = min(ymin, ax.get_ylim()[0]) - ymax = max(ymax, ax.get_ylim()[1]) - for ax in ax_loss: - ax.set_ylim((ymin, ymax)) - for ax in ax_loss_t: - ax.set_ylim((ymin, ymax)) - - title_fontsize = 11 - for i, iters in enumerate(inner_iters): - for ax in [ax_loss, ax_loss_t]: - ax[i].title.set_text(f"Train loss ({iters} inner GBP steps)") - ax[i].title.set_size(title_fontsize) - ax[i].set_ylabel("Camera Loss") - ax_loss[i].set_xlabel("Epoch") - ax_loss_t[i].set_xlabel("Time (seconds)") - - for j in range(4): - ax_traj[i, j].set_xlabel("Huber loss radius") - if j == 0: - ax_traj[i, j].set_ylabel(f"{iters} inner steps\n\n\nCamera Loss") - else: - ax_traj[i, j].set_ylabel("Camera Loss") - - fig_loss.subplots_adjust(bottom=0.3) - fig_loss_t.subplots_adjust(bottom=0.3) - plt.show() - - -if __name__ == "__main__": - - root = ( - "/home/joe/projects/theseus/theseus/optimizer/gbp/" - + "outputs/loss_radius_exp/backward_analysis/" - ) - - # plot_timing_memory(root) - - plot_loss_traj(root, ref_loss=None) diff --git a/theseus/optimizer/gbp/bundle_adjustment.py b/theseus/optimizer/gbp/bundle_adjustment.py deleted file mode 100644 index 8f834e34a..000000000 --- a/theseus/optimizer/gbp/bundle_adjustment.py +++ /dev/null @@ -1,629 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
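
The run_outer training loop in bundle_adjustment.py (removed below) divides parameter gradients by implicit_step_size when backward_mode is "implicit". A toy sketch, not part of the patch, of why that rescaling recovers the true gradient; the inner solve and the map 3.0 * theta are hypothetical stand-ins:

    import torch

    theta = torch.tensor([2.0], requires_grad=True)
    implicit_step_size = 1e-4

    # The inner optimizer runs under torch.no_grad(), so x_star carries no graph.
    with torch.no_grad():
        x_star = 3.0 * theta  # stand-in for the converged solution x*(theta)

    # A single differentiable update of size implicit_step_size is applied at the
    # solution, so the gradient reaching theta comes out scaled by the step size.
    x = x_star + implicit_step_size * (3.0 * theta - x_star)
    loss = x.sum()
    loss.backward()

    theta.grad /= implicit_step_size  # undo the scaling; recovers d(loss)/d(theta)
    print(theta.grad)                 # prints approximately tensor([3.])
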
-import random -from typing import Dict, List - -import numpy as np -import omegaconf -import time -import torch - -import os -import json -from datetime import datetime - -import theseus as th -from theseus.core import Vectorize -import theseus.utils.examples as theg -from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule - -# from theseus.optimizer.gbp import BAViewer - - -OPTIMIZER_CLASS = { - "gbp": GaussianBeliefPropagation, - "gauss_newton": th.GaussNewton, - "levenberg_marquardt": th.LevenbergMarquardt, -} - -OUTER_OPTIMIZER_CLASS = { - "sgd": torch.optim.SGD, - "adam": torch.optim.Adam, -} - -GBP_SCHEDULE = { - "synchronous": GBPSchedule.SYNCHRONOUS, -} - - -def start_timing(): - if torch.cuda.is_available(): - torch.cuda.synchronize() - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - else: - start = time.perf_counter() - end = None - return start, end - - -def end_timing(start, end): - if torch.cuda.is_available(): - torch.cuda.synchronize() - end.record() - # Waits for everything to finish running - torch.cuda.synchronize() - elapsed_time = start.elapsed_time(end) - else: - end = time.perf_counter() - elapsed_time = end - start - # Convert to milliseconds to have the same units - # as torch.cuda.Event.elapsed_time - elapsed_time = elapsed_time * 1000 - return elapsed_time - - -def save_res_loss_rad( - save_dir, cfg, radius_vals, losses, sweep_radii=None, sweep_losses=None -): - with open(f"{save_dir}/config.txt", "w") as f: - json.dump(cfg, f, indent=4) - - # sweep values - if sweep_radii is not None: - np.savetxt(f"{save_dir}/sweep_radius.txt", sweep_radii) - if sweep_losses is not None: - np.savetxt(f"{save_dir}/sweep_loss.txt", sweep_losses) - - # optim trajectory - np.savetxt(f"{save_dir}/optim_radius.txt", radius_vals) - np.savetxt(f"{save_dir}/optim_loss.txt", losses) - - -def print_histogram( - ba: theg.BundleAdjustmentDataset, var_dict: Dict[str, torch.Tensor], msg: str -): - print(msg) - histogram = theg.ba_histogram( - cameras=[ - theg.Camera( - th.SE3(tensor=var_dict[c.pose.name]), - c.focal_length, - c.calib_k1, - c.calib_k2, - ) - for c in ba.cameras - ], - points=[th.Point3(tensor=var_dict[pt.name]) for pt in ba.points], - observations=ba.observations, - ) - for line in histogram.split("\n"): - print(line) - - -def camera_loss( - ba: theg.BundleAdjustmentDataset, camera_pose_vars: List[th.LieGroup] -) -> torch.Tensor: - loss: torch.Tensor = 0 # type:ignore - for i in range(len(ba.cameras)): - cam_pose = camera_pose_vars[i].copy() - cam_pose.to(ba.gt_cameras[i].pose.device) - camera_loss = th.local(cam_pose, ba.gt_cameras[i].pose).norm(dim=1).cpu() - loss += camera_loss - return loss - - -# Assumes the weight of the cost functions are 1 -def average_repojection_error(objective, values_dict=None) -> float: - if values_dict is not None: - objective.update(values_dict) - if objective._vectorized is False: - Vectorize(objective) - reproj_norms = [] - for cost_function in objective._get_iterator(): - if "Reprojection" in cost_function.name: - # should equal error as weight is 1 - # need to call weighted_error as error is not cached - err = cost_function.weighted_error().norm(dim=1) - reproj_norms.append(err) - - are = torch.tensor(reproj_norms).mean().item() - return are - - -def load_problem(cfg: omegaconf.OmegaConf): - # create (or load) dataset - if cfg["bal_file"] is None: - ba = theg.BundleAdjustmentDataset.generate_synthetic( - num_cameras=cfg["synthetic"]["num_cameras"], - 
num_points=cfg["synthetic"]["num_points"], - average_track_length=cfg["synthetic"]["average_track_length"], - track_locality=cfg["synthetic"]["track_locality"], - feat_random=1.5, - prob_feat_is_outlier=0.02, - outlier_feat_random=70, - cam_pos_rand=5.0, - cam_rot_rand=0.9, - point_rand=10.0, - ) - else: - cams, points, obs = theg.BundleAdjustmentDataset.load_bal_dataset( - cfg["bal_file"], drop_obs=0.0 - ) - ba = theg.BundleAdjustmentDataset(cams, points, obs) - - print("Cameras:", len(ba.cameras)) - print("Points:", len(ba.points)) - print("Observations:", len(ba.observations), "\n") - - return ba - - -def setup_layer(cfg: omegaconf.OmegaConf): - ba = load_problem(cfg) - - print("Optimizer:", cfg["optim"]["optimizer_cls"]) - print("Backward mode:", cfg["optim"]["backward_mode"], "\n") - - # param that control transition from squared loss to huber - radius_tensor = torch.tensor([1.0], dtype=torch.float64) - log_loss_radius = th.Vector(tensor=radius_tensor, name="log_loss_radius") - - # Set up objective - print("Setting up objective") - t0 = time.time() - dtype = torch.float64 - objective = th.Objective(dtype=dtype) - dummy_objective = th.Objective(dtype=dtype) # for computing are - - weight = th.ScaleCostWeight(torch.tensor(1.0).to(dtype=ba.cameras[0].pose.dtype)) - for i, obs in enumerate(ba.observations): - # print(i, len(ba.observations)) - cam = ba.cameras[obs.camera_index] - cost_function = th.eb.Reprojection( - camera_pose=cam.pose, - world_point=ba.points[obs.point_index], - focal_length=cam.focal_length, - calib_k1=cam.calib_k1, - calib_k2=cam.calib_k2, - image_feature_point=obs.image_feature_point, - weight=weight, - ) - robust_cost_function = th.RobustCostFunction( - cost_function, - th.HuberLoss, - log_loss_radius, - name=f"robust_{cost_function.name}", - ) - objective.add(robust_cost_function) - dummy_objective.add(cost_function) - - # Add regularization - if cfg["optim"]["regularize"]: - # zero_point3 = th.Point3(dtype=dtype, name="zero_point") - # identity_se3 = th.SE3(dtype=dtype, name="zero_se3") - w = np.sqrt(cfg["optim"]["reg_w"]) - damping_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) - for name, var in objective.optim_vars.items(): - target: th.Manifold - if isinstance(var, th.SE3): - target = var.copy(new_name="target_" + var.name) - # target = identity_se3 - objective.add( - th.Difference(var, target, damping_weight, name=f"reg_{name}") - ) - # elif isinstance(var, th.Point3): - # target = var.copy(new_name="target_" + var.name) - # # target = zero_point3 - # else: - # assert False - # objective.add( - # th.Difference(var, target, damping_weight, name=f"reg_{name}") - # ) - - camera_pose_vars: List[th.LieGroup] = [ - objective.optim_vars[c.pose.name] for c in ba.cameras # type: ignore - ] - if cfg["optim"]["ratio_known_cameras"] > 0.0 and ba.gt_cameras is not None: - w = 1000.0 - camera_weight = th.ScaleCostWeight(w * torch.ones(1, dtype=dtype)) - for i in range(len(ba.cameras)): - if np.random.rand() > cfg["optim"]["ratio_known_cameras"]: - continue - print("fixing cam", i) - objective.add( - th.Difference( - camera_pose_vars[i], - ba.gt_cameras[i].pose, - camera_weight, - name=f"camera_diff_{i}", - ) - ) - print("done in:", time.time() - t0) - - # Create optimizer and theseus layer - vectorize = cfg["optim"]["vectorize"] - optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( - objective, - max_iterations=cfg["optim"]["max_iters"], - vectorize=vectorize, - # linearization_cls=th.SparseLinearization, - # linear_solver_cls=th.LUCudaSparseSolver, - 
) - theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) - - if cfg["device"] == "cuda": - cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" - theseus_optim.to(cfg["device"]) - dummy_objective.to(cfg["device"]) - print("Device:", cfg["device"]) - - # create damping parameter - lin_system_damping = torch.nn.Parameter( - torch.tensor( - [cfg["optim"]["gbp_settings"]["lin_system_damping"]], dtype=torch.float64 - ) - ) - lin_system_damping.to(device=cfg["device"]) - - optim_arg = { - "track_best_solution": False, - "track_err_history": True, - "track_state_history": cfg["optim"]["track_state_history"], - "verbose": True, - "backward_mode": cfg["optim"]["backward_mode"], - "backward_num_iterations": cfg["optim"]["backward_num_iterations"], - } - if isinstance(optimizer, GaussianBeliefPropagation): - gbp_args = cfg["optim"]["gbp_settings"].copy() - gbp_args["lin_system_damping"] = lin_system_damping - gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] - optim_arg = {**optim_arg, **gbp_args} - - theseus_inputs = {} - for cam in ba.cameras: - theseus_inputs[cam.pose.name] = cam.pose.tensor.clone() - for pt in ba.points: - theseus_inputs[pt.name] = pt.tensor.clone() - - return ( - theseus_optim, - theseus_inputs, - optim_arg, - ba, - dummy_objective, - camera_pose_vars, - lin_system_damping, - ) - - -def run_inner( - theseus_optim, - theseus_inputs, - optim_arg, - ba, - dummy_objective, - camera_pose_vars, - lin_system_damping, -): - if ba.gt_cameras is not None: - with torch.no_grad(): - camera_loss_ref = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS: {camera_loss_ref: .3f}") - are = average_repojection_error(dummy_objective, values_dict=theseus_inputs) - print("Average reprojection error (pixels): ", are) - print_histogram(ba, theseus_inputs, "Input histogram:") - - with torch.no_grad(): - theseus_outputs, info = theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=optim_arg, - ) - - if ba.gt_cameras is not None: - loss = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS: (loss, ref loss) {loss:.3f} {camera_loss_ref: .3f}") - - are = average_repojection_error(dummy_objective, values_dict=theseus_outputs) - print("Average reprojection error (pixels): ", are) - print_histogram(ba, theseus_outputs, "Final histogram:") - - # if info.state_history is not None: - # BAViewer( - # info.state_history, gt_cameras=ba.gt_cameras, gt_points=ba.gt_points - # ) # , msg_history=optimizer.ftov_msgs_history) - - """ - Save for nesterov experiments - """ - save_dir = os.getcwd() + "/outputs/nesterov/synthetic_large/" - if cfg["optim"]["gbp_settings"]["nesterov"]: - save_dir += "1/" - else: - save_dir += "0/" - os.mkdir(save_dir) - with open(f"{save_dir}/config.txt", "w") as f: - json.dump(cfg, f, indent=4) - np.savetxt(save_dir + "/error_history.txt", info.err_history[0].cpu().numpy()) - - """ - Save for bal sequences - """ - - # if cfg["bal_file"] is not None: - # save_dir = os.path.join(os.getcwd(), "outputs") - # if not os.path.exists(save_dir): - # os.mkdir(save_dir) - # err_history = info.err_history[0].cpu().numpy() - # save_file = os.path.join( - # save_dir, - # f"{cfg['optim']['optimizer_cls']}_err_{cfg['bal_file'].split('/')[-1]}", - # ) - # np.savetxt(save_file, err_history) - - # # get average reprojection error for each iteration - # if info.state_history is not None: - # ares = [] - # iters = ( - # info.converged_iter - # if info.converged_iter != -1 - # else cfg["optim"]["max_iters"] - # ) - # for i in range(iters): - # 
t0 = time.time() - # values_dict = {} - # for name, state in info.state_history.items(): - # values_dict[name] = ( - # state[..., i].to(dtype=torch.float64).to(dummy_objective.device) - # ) - # are = average_repojection_error(dummy_objective, values_dict=values_dict) - # ares.append(are) - # print(i, "-- ARE:", are, " -- time", time.time() - t0) - # are = average_repojection_error(dummy_objective, values_dict=theseus_outputs) - # ares.append(are) - - # if cfg["bal_file"] is not None: - # save_dir = os.path.join(os.getcwd(), "outputs") - # if not os.path.exists(save_dir): - # os.mkdir(save_dir) - # save_file = os.path.join( - # save_dir, - # f"{cfg['optim']['optimizer_cls']}_are_{cfg['bal_file'].split('/')[-1]}", - # ) - # np.savetxt(save_file, np.array(ares)) - - -def run_outer(cfg: omegaconf.OmegaConf, out_dir=None, do_sweep=False): - - torch.manual_seed(cfg["seed"]) - np.random.seed(cfg["seed"]) - random.seed(cfg["seed"]) - - print(f"\nRunning experiment. Save directory: {out_dir}\n") - - ( - theseus_optim, - theseus_inputs, - optim_arg, - ba, - dummy_objective, - camera_pose_vars, - lin_system_damping, - ) = setup_layer(cfg) - - loss_radius_tensor = torch.nn.Parameter(torch.tensor([3.0], dtype=torch.float64)) - model_optimizer = OUTER_OPTIMIZER_CLASS[cfg["outer"]["optimizer"]]( - [loss_radius_tensor], lr=cfg["outer"]["lr"] - ) - # model_optimizer = torch.optim.Adam([lin_system_damping], lr=cfg["outer"]["lr"]) - - theseus_inputs["log_loss_radius"] = loss_radius_tensor.unsqueeze(1).clone() - - with torch.no_grad(): - camera_loss_ref = camera_loss(ba, camera_pose_vars).item() - print(f"CAMERA LOSS (no learning): {camera_loss_ref: .3f}") - print_histogram(ba, theseus_inputs, "Input histogram:") - - # import matplotlib.pylab as plt - sweep_radii, sweep_losses = None, None - if do_sweep: - sweep_radii = torch.linspace(0.01, 5.0, 20, dtype=torch.float64) - sweep_losses = [] - sweep_arg = optim_arg.copy() - sweep_arg["verbose"] = False - with torch.set_grad_enabled(False): - for radius in sweep_radii: - radius = radius.to(cfg["device"]) - theseus_inputs["log_loss_radius"] = radius.unsqueeze(0).unsqueeze(0) - - theseus_outputs, info = theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=sweep_arg, - ) - cam_loss = camera_loss(ba, camera_pose_vars) - loss = (cam_loss - camera_loss_ref) / camera_loss_ref - sweep_losses.append(torch.sum(loss.detach()).item()) - print( - f"SWEEP radius {radius}, camera loss {cam_loss.detach().item():.3f}," - f" loss {sweep_losses[-1]:.3f}, ref loss {camera_loss_ref:.3f}" - ) - - # plt.plot(sweep_radii, sweep_losses) - # plt.xlabel("Log loss radius") - # plt.ylabel("(Camera loss - reference loss) / reference loss") - # plt.show() - - losses = [] - radius_vals = [] - theseus_inputs["log_loss_radius"] = ( - loss_radius_tensor.unsqueeze(1).clone().to(cfg["device"]) - ) - - times: Dict = {"fwd": [], "bwd": []} - memory: Dict = {"fwd": [], "bwd": []} - - for epoch in range(cfg["outer"]["num_epochs"]): - print(f" ******************* EPOCH {epoch} ******************* ") - start_time = time.time_ns() - model_optimizer.zero_grad() - theseus_inputs["log_loss_radius"] = ( - loss_radius_tensor.unsqueeze(1).clone().to(cfg["device"]) - ) - - start, end = start_timing() - torch.cuda.reset_peak_memory_stats() - theseus_outputs, info = theseus_optim.forward( - input_tensors=theseus_inputs, - optimizer_kwargs=optim_arg, - ) - times["fwd"].append(end_timing(start, end)) - memory["fwd"].append(torch.cuda.max_memory_allocated() / 1048576) - - cam_loss = 
camera_loss(ba, camera_pose_vars) - loss = (cam_loss - camera_loss_ref) / camera_loss_ref - - start, end = start_timing() - torch.cuda.reset_peak_memory_stats() - loss.backward() - times["bwd"].append(end_timing(start, end)) - memory["bwd"].append(torch.cuda.max_memory_allocated() / 1048576) - radius_vals.append(loss_radius_tensor.data.item()) - # correct for implicit gradients step size != 1 - if cfg["optim"]["backward_mode"] == "implicit": - if theseus_optim.optimizer.implicit_method == "gauss_newton": - loss_radius_tensor.grad /= theseus_optim.optimizer.implicit_step_size - print("\ngrad is: ", loss_radius_tensor.grad, "\n") - model_optimizer.step() - loss_value = torch.sum(loss.detach()).item() - losses.append(loss_value) - end_time = time.time_ns() - - # print_histogram(ba, theseus_outputs, "Output histogram:") - print(f"camera loss {cam_loss.detach().item()} and ref loss {camera_loss_ref}") - print( - f"Epoch: {epoch} Loss: {loss_value} " - # f"Lin system damping {lin_system_damping}" - f"Kernel Radius: exp({loss_radius_tensor.data.item()})=" - f"{torch.exp(loss_radius_tensor.data).item()}" - ) - print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") - - print("Loss values:", losses) - - now = datetime.now() - if out_dir is None: - out_dir = now.strftime("%m-%d-%y_%H-%M-%S") - save_dir = os.getcwd() + "/outputs/loss_radius_exp/" + out_dir - os.mkdir(save_dir) - with open(f"{save_dir}/config.txt", "w") as f: - json.dump(cfg, f, indent=4) - - print("\n=== Runtimes") - k1, k2 = "fwd", "bwd" - print(f"Forward: {np.mean(times[k1]):.2e} s +/- {np.std(times[k1]):.2e} s") - print(f"Backward (unroll): {np.mean(times[k2]):.2e} s +/- {np.std(times[k2]):.2e} s") - - print("\n=== Memory") - k1, k2 = "fwd", "bwd" - print(f"Forward: {np.mean(memory[k1]):.2e} MB +/- {np.std(memory[k1]):.2e} MB") - print( - f"Backward (unroll): {np.mean(memory[k2]):.2e} MB +/- {np.std(memory[k2]):.2e} MB" - ) - - with open(f"{save_dir}/timings.txt", "w") as f: - json.dump(times, f, indent=4) - with open(f"{save_dir}/memory.txt", "w") as f: - json.dump(memory, f, indent=4) - - with open(f"{save_dir}/ref_loss.txt", "w") as f: - f.write(f"{camera_loss_ref:.5f}") - - save_res_loss_rad( - save_dir, - cfg, - radius_vals, - losses, - sweep_radii=sweep_radii, - sweep_losses=sweep_losses, - ) - - # plt.scatter(radius_vals, losses, c=range(len(losses)), cmap=plt.get_cmap("viridis")) - # plt.title(cfg["optim"]["optimizer_cls"] + " - " + dir_name) - # plt.show() - - -if __name__ == "__main__": - - cfg: Dict = { - "seed": 1, - "device": "cuda", - "bal_file": None, - # "bal_file": "/mnt/sda/bal/problem-21-11315-pre.txt", - "synthetic": { - "num_cameras": 10, - "num_points": 100, - "average_track_length": 8, - "track_locality": 0.2, - }, - "optim": { - "max_iters": 200, - "vectorize": True, - "optimizer_cls": "gbp", - # "optimizer_cls": "gauss_newton", - # "optimizer_cls": "levenberg_marquardt", - "backward_mode": "implicit", - "backward_num_iterations": 10, - "track_state_history": True, - "regularize": True, - "ratio_known_cameras": 0.1, - "reg_w": 1e-7, - "gbp_settings": { - "relin_threshold": 1e-8, - "ftov_msg_damping": 0.0, - "dropout": 0.0, - "schedule": "synchronous", - "lin_system_damping": 1.0e-2, - "nesterov": False, - "implicit_method": "gbp", - }, - }, - "outer": { - "num_epochs": 20, - "lr": 5.0e1, # 5.0e-1, - "optimizer": "sgd", - }, - } - - # args = setup_layer(cfg) - # run_inner(*args) - - # run_outer(cfg, "implicit_test", do_sweep=False) - - for max_iters in [150, 200, 500]: # 25, 50, 100, - for 
backward_mode in ["implicit"]: - cfg_copy = cfg.copy() - cfg_copy["optim"]["max_iters"] = max_iters - cfg_copy["optim"]["backward_mode"] = backward_mode - - dir_name = str(max_iters) + "_" + backward_mode - - if backward_mode == "truncated": - for backward_num_iterations in [5, 10]: - cfg_copy["optim"][ - "backward_num_iterations" - ] = backward_num_iterations - dir_name = ( - str(max_iters) - + "_" - + backward_mode - + "_" - + str(cfg["optim"]["backward_num_iterations"]) - ) - - run_outer(cfg_copy, dir_name) - else: - do_sweep = True if backward_mode == "unroll" else False - run_outer(cfg_copy, dir_name, do_sweep=do_sweep) diff --git a/theseus/optimizer/gbp/gbp_baseline.py b/theseus/optimizer/gbp/gbp_baseline.py deleted file mode 100644 index ea81e1a62..000000000 --- a/theseus/optimizer/gbp/gbp_baseline.py +++ /dev/null @@ -1,775 +0,0 @@ -from typing import Callable, List, Optional, Union - -import matplotlib.pylab as plt -import numpy as np - -""" -Defines squared loss functions that correspond to Gaussians. -Robust losses are implemented by scaling the Gaussian covariance. -""" - - -class Gaussian: - def __init__( - self, - dim, - eta=None, - lam=None, - ): - self.dim = dim - - if eta is not None and eta.shape == (dim,): - self.eta = eta - else: - self.eta = np.zeros(dim) - - if lam is not None and lam.shape == (dim, dim): - self.lam = lam - else: - self.lam = np.zeros([dim, dim]) - - def mean(self) -> np.ndarray: - return np.matmul(np.linalg.inv(self.lam), self.eta) - - def cov(self) -> np.ndarray: - return np.linalg.inv(self.lam) - - def mean_and_cov(self) -> List[np.ndarray]: - cov = self.cov() - mean = np.matmul(cov, self.eta) - return [mean, cov] - - def set_with_cov_form(self, mean: np.ndarray, cov: np.ndarray) -> None: - self.lam = np.linalg.inv(cov) - self.eta = np.matmul(self.lam, mean) - - -class GBPSettings: - def __init__( - self, - damping: float = 0.0, - beta: float = 0.1, - num_undamped_iters: int = 5, - min_linear_iters: int = 10, - dropout: float = 0.0, - reset_iters_since_relin: List[int] = [], - ): - # Parameters for damping the eta component of the message - self.damping = damping - # Number of undamped iterations after relin before damping is on - self.num_undamped_iters = num_undamped_iters - - self.dropout = dropout - - # Parameters for just in time factor relinearisation. - # Threshold absolute distance between linpoint - # and adjacent belief means for relinearisation. - self.beta = beta - # Minimum number of linear iterations before - # a factor is allowed to realinearise. - self.min_linear_iters = min_linear_iters - self.reset_iters_since_relin = reset_iters_since_relin - - def get_damping(self, iters_since_relin: int) -> float: - if iters_since_relin > self.num_undamped_iters: - return self.damping - else: - return 0.0 - - -class SquaredLoss: - def __init__(self, dofs: int, diag_cov: Union[float, np.ndarray]): - """ - dofs: dofs of the measurement - cov: diagonal elements of covariance matrix - """ - if isinstance(diag_cov, np.ndarray): - assert diag_cov.shape[0] == dofs - mat = np.zeros([dofs, dofs]) - mat[range(dofs), range(dofs)] = diag_cov - self.cov = mat - self.effective_cov = mat.copy() - - def get_effective_cov(self, residual: np.ndarray) -> None: - """ - Returns the covariance of the Gaussian (squared loss) - that matches the loss at the error value. 
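
The Gaussian container in the file being removed here stores the information form (eta = lam @ mean, lam = inverse covariance). A quick numpy round-trip sketch, not part of the file, of the conversion that set_with_cov_form and mean_and_cov implement:

    import numpy as np

    # Moments form: mean mu and covariance Sigma.
    mu = np.array([1.0, -2.0])
    Sigma = np.array([[2.0, 0.3], [0.3, 1.0]])

    # Information form, as stored by Gaussian.
    lam = np.linalg.inv(Sigma)
    eta = lam @ mu

    # Recovering the moments, as Gaussian.mean_and_cov() does.
    cov = np.linalg.inv(lam)
    mean = cov @ eta
    assert np.allclose(mean, mu) and np.allclose(cov, Sigma)
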
- """ - self.effective_cov = self.cov.copy() - - def robust(self) -> bool: - return not np.equal(self.cov, self.effective_cov) - - -class HuberLoss(SquaredLoss): - def __init__( - self, dofs: int, diag_cov: Union[float, np.ndarray], stds_transition: float - ): - """ - stds_transition: num standard deviations from minimum at - which quadratic loss transitions to linear. - """ - SquaredLoss.__init__(self, dofs, diag_cov) - self.stds_transition = stds_transition - - def get_effective_cov(self, residual: np.ndarray) -> None: - energy = residual @ np.linalg.inv(self.cov) @ residual - mahalanobis_dist = np.sqrt(energy) - if mahalanobis_dist > self.stds_transition: - denom = ( - 2 * self.stds_transition * mahalanobis_dist - self.stds_transition**2 - ) - self.effective_cov = self.cov * mahalanobis_dist**2 / denom - else: - self.effective_cov = self.cov.copy() - - -class MeasModel: - def __init__( - self, - meas_fn: Callable, - jac_fn: Callable, - loss: SquaredLoss, - *args, - ): - self._meas_fn = meas_fn - self._jac_fn = jac_fn - self.loss = loss - self.args = args - self.linear = True - - def jac_fn(self, x: np.ndarray) -> np.ndarray: - return self._jac_fn(x, *self.args) - - def meas_fn(self, x: np.ndarray) -> np.ndarray: - return self._meas_fn(x, *self.args) - - -def lin_meas_fn(x): - length = int(x.shape[0] / 2) - J = np.concatenate((-np.eye(length), np.eye(length)), axis=1) - return J @ x - - -def lin_jac_fn(x): - length = int(x.shape[0] / 2) - return np.concatenate((-np.eye(length), np.eye(length)), axis=1) - - -class LinearDisplacementModel(MeasModel): - def __init__(self, loss: SquaredLoss) -> None: - MeasModel.__init__(self, lin_meas_fn, lin_jac_fn, loss) - self.linear = True - - -""" -Main GBP functions. -Defines classes for variable nodes, factor nodes and edges and factor graph. 
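
A small numeric check, not part of the file, of the covariance rescaling in HuberLoss.get_effective_cov above: past the transition point, the quadratic energy computed with the rescaled covariance reproduces the linear Huber cost t*M - t^2/2, where M is the Mahalanobis distance and t = stds_transition:

    import numpy as np

    t = 1.5                      # stds_transition
    Sigma = np.diag([0.5, 0.5])  # diag_cov written as a matrix
    r = np.array([2.0, 1.0])     # a residual well past the transition

    M = np.sqrt(r @ np.linalg.inv(Sigma) @ r)      # Mahalanobis distance
    Sigma_eff = Sigma * M**2 / (2 * t * M - t**2)  # HuberLoss rescaling

    quadratic_energy = 0.5 * r @ np.linalg.inv(Sigma_eff) @ r
    huber_energy = t * M - 0.5 * t**2              # linear branch of the Huber loss
    assert np.isclose(quadratic_energy, huber_energy)
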
-""" - - -class FactorGraph: - def __init__( - self, - gbp_settings: GBPSettings = GBPSettings(), - ): - self.var_nodes: List[VariableNode] = [] - self.factors: List[Factor] = [] - self.gbp_settings = gbp_settings - - def add_var_node( - self, - dofs: int, - prior_mean: Optional[np.ndarray] = None, - prior_diag_cov: Optional[Union[float, np.ndarray]] = None, - ) -> None: - variableID = len(self.var_nodes) - self.var_nodes.append(VariableNode(variableID, dofs)) - if prior_mean is not None and prior_diag_cov is not None: - prior_cov = np.zeros([dofs, dofs]) - prior_cov[range(dofs), range(dofs)] = prior_diag_cov - self.var_nodes[-1].prior.set_with_cov_form(prior_mean, prior_cov) - self.var_nodes[-1].update_belief() - - def add_factor( - self, - adj_var_ids: List[int], - measurement: np.ndarray, - meas_model: MeasModel, - ) -> None: - factorID = len(self.factors) - adj_var_nodes = [self.var_nodes[i] for i in adj_var_ids] - self.factors.append(Factor(factorID, adj_var_nodes, measurement, meas_model)) - for var in adj_var_nodes: - var.adj_factors.append(self.factors[-1]) - - def update_all_beliefs(self) -> None: - for var_node in self.var_nodes: - var_node.update_belief() - - def compute_all_messages(self, apply_dropout: bool = True) -> None: - for factor in self.factors: - dropout_off = apply_dropout and np.random.rand() > self.gbp_settings.dropout - if dropout_off or not apply_dropout: - damping = self.gbp_settings.get_damping(factor.iters_since_relin) - factor.compute_messages(damping) - - def linearise_all_factors(self) -> None: - for factor in self.factors: - factor.compute_factor() - - def robustify_all_factors(self) -> None: - for factor in self.factors: - factor.robustify_loss() - - def jit_linearisation(self) -> None: - """ - Check for all factors that the current estimate - is close to the linearisation point. - If not, relinearise the factor distribution. - Relinearisation is only allowed at a maximum frequency - of once every min_linear_iters iterations. - """ - for factor in self.factors: - if not factor.meas_model.linear: - adj_belief_means = factor.get_adj_means() - factor.iters_since_relin += 1 - diff_cond = ( - np.linalg.norm(factor.linpoint - adj_belief_means) - > self.gbp_settings.beta - ) - iters_cond = ( - factor.iters_since_relin >= self.gbp_settings.min_linear_iters - ) - if diff_cond and iters_cond: - factor.compute_factor() - - def synchronous_iteration(self) -> None: - self.robustify_all_factors() - self.jit_linearisation() # For linear factors, no compute is done - self.compute_all_messages() - self.update_all_beliefs() - - def random_message(self) -> None: - """ - Sends messages to all adjacent nodes from a random factor. 
- """ - self.robustify_all_factors() - self.jit_linearisation() # For linear factors, no compute is done - ix = np.random.randint(len(self.factors)) - factor = self.factors[ix] - damping = self.gbp_settings.get_damping(factor.iters_since_relin) - factor.compute_messages(damping) - self.update_all_beliefs() - - def gbp_solve( - self, - n_iters: Optional[int] = 20, - converged_threshold: Optional[float] = 1e-6, - include_priors: bool = True, - ) -> None: - energy_log = [self.energy()] - print(f"\nInitial Energy {energy_log[0]:.5f}") - - i = 0 - count = 0 - not_converged = True - - while not_converged and i < n_iters: - self.synchronous_iteration() - if i in self.gbp_settings.reset_iters_since_relin: - for f in self.factors: - f.iters_since_relin = 1 - - energy_log.append(self.energy(include_priors=include_priors)) - print(f"Iter {i+1} --- " f"Energy {energy_log[-1]:.5f} --- ") - i += 1 - if abs(energy_log[-2] - energy_log[-1]) < converged_threshold: - count += 1 - if count == 3: - not_converged = False - else: - count = 0 - - def energy( - self, eval_point: np.ndarray = None, include_priors: bool = True - ) -> float: - """ - Computes the sum of all of the squared errors in the graph - using the appropriate local loss function. - """ - if eval_point is None: - energy = sum([factor.get_energy() for factor in self.factors]) - else: - var_dofs = np.ndarray([v.dofs for v in self.var_nodes]) - var_ix = np.concatenate([np.ndarray([0]), np.cumsum(var_dofs, axis=0)[:-1]]) - energy = 0.0 - for f in self.factors: - local_eval_point = np.concatenate( - [ - eval_point[var_ix[v.variableID] : var_ix[v.variableID] + v.dofs] - for v in f.adj_var_nodes - ] - ) - energy += f.get_energy(local_eval_point) - if include_priors: - prior_energy = sum([var.get_prior_energy() for var in self.var_nodes]) - energy += prior_energy - return energy - - def get_joint_dim(self) -> int: - return sum([var.dofs for var in self.var_nodes]) - - def get_joint(self) -> Gaussian: - """ - Get the joint distribution over all variables in the information form - If nonlinear factors, it is taken at the current linearisation point. 
- """ - dim = self.get_joint_dim() - joint = Gaussian(dim) - - # Priors - var_ix = [0] * len(self.var_nodes) - counter = 0 - for var in self.var_nodes: - var_ix[var.variableID] = int(counter) - joint.eta[counter : counter + var.dofs] += var.prior.eta - joint.lam[ - counter : counter + var.dofs, counter : counter + var.dofs - ] += var.prior.lam - counter += var.dofs - - # Other factors - for factor in self.factors: - factor_ix = 0 - for adj_var_node in factor.adj_var_nodes: - vID = adj_var_node.variableID - # Diagonal contribution of factor - joint.eta[ - var_ix[vID] : var_ix[vID] + adj_var_node.dofs - ] += factor.factor.eta[factor_ix : factor_ix + adj_var_node.dofs] - joint.lam[ - var_ix[vID] : var_ix[vID] + adj_var_node.dofs, - var_ix[vID] : var_ix[vID] + adj_var_node.dofs, - ] += factor.factor.lam[ - factor_ix : factor_ix + adj_var_node.dofs, - factor_ix : factor_ix + adj_var_node.dofs, - ] - other_factor_ix = 0 - for other_adj_var_node in factor.adj_var_nodes: - if other_adj_var_node.variableID > adj_var_node.variableID: - other_vID = other_adj_var_node.variableID - # Off diagonal contributions of factor - joint.lam[ - var_ix[vID] : var_ix[vID] + adj_var_node.dofs, - var_ix[other_vID] : var_ix[other_vID] - + other_adj_var_node.dofs, - ] += factor.factor.lam[ - factor_ix : factor_ix + adj_var_node.dofs, - other_factor_ix : other_factor_ix + other_adj_var_node.dofs, - ] - joint.lam[ - var_ix[other_vID] : var_ix[other_vID] - + other_adj_var_node.dofs, - var_ix[vID] : var_ix[vID] + adj_var_node.dofs, - ] += factor.factor.lam[ - other_factor_ix : other_factor_ix + other_adj_var_node.dofs, - factor_ix : factor_ix + adj_var_node.dofs, - ] - other_factor_ix += other_adj_var_node.dofs - factor_ix += adj_var_node.dofs - - return joint - - def MAP(self) -> np.ndarray: - return self.get_joint().mean() - - def dist_from_MAP(self) -> np.ndarray: - return np.linalg.norm(self.get_joint().mean() - self.belief_means()) - - def belief_means(self) -> np.ndarray: - """Get an array containing all current estimates of belief means.""" - return np.concatenate([var.belief.mean() for var in self.var_nodes]) - - def belief_covs(self) -> List[np.ndarray]: - """Get a list of all belief covariances.""" - covs = [var.belief.cov() for var in self.var_nodes] - return covs - - def print(self, brief=False) -> None: - print("\nFactor Graph:") - print(f"# Variable nodes: {len(self.var_nodes)}") - if not brief: - for i, var in enumerate(self.var_nodes): - print( - f"Variable {i}: connects to factors {[f.factorID for f in var.adj_factors]}" - ) - print(f" dofs: {var.dofs}") - print(f" prior mean: {var.prior.mean()}") - print( - f" prior covariance: diagonal sigma {np.diag(var.prior.cov())}" - ) - print(f"# Factors: {len(self.factors)}") - if not brief: - for i, factor in enumerate(self.factors): - if factor.meas_model.linear: - print("Linear", end=" ") - else: - print("Nonlinear", end=" ") - print(f"Factor {i}: connects to variables {factor.adj_vIDs}") - print( - f" measurement model: {type(factor.meas_model).__name__}," - f" {type(factor.meas_model.loss).__name__}," - f" diagonal sigma {np.diag(factor.meas_model.loss.effective_cov)}" - ) - print(f" measurement: {factor.measurement}") - print("\n") - - -class VariableNode: - def __init__(self, id: int, dofs: int): - self.variableID = id - self.dofs = dofs - self.adj_factors: List[Factor] = [] - # prior factor, implemented as part of variable node - self.prior = Gaussian(dofs) - self.belief = Gaussian(dofs) - - def update_belief(self) -> None: - """ - Update local belief 
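# Hedged illustration of what get_joint() and MAP() compute for the two-variable
# chain sketched above: the joint information form is the sum of prior and factor
# contributions, and the MAP estimate is lam^-1 @ eta (numbers are illustrative).
import numpy as np

lam_prior = np.diag([1e4, 1e4, 1.0, 1.0])             # strong anchor on x0, weak prior on x1
eta_prior = lam_prior @ np.array([0.0, 0.0, 1.1, 0.1])
J = np.hstack([-np.eye(2), np.eye(2)])                 # linear displacement Jacobian
lam_meas = 1.0 / 0.01                                  # precision of the [1, 0] measurement
lam_joint = lam_prior + lam_meas * J.T @ J
eta_joint = eta_prior + lam_meas * (J.T @ np.array([1.0, 0.0]))
map_estimate = np.linalg.solve(lam_joint, eta_joint)   # roughly [0, 0, 1, 0]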
estimate by taking product - of all incoming messages along all edges. - """ - # message from prior factor - self.belief.eta = self.prior.eta.copy() - self.belief.lam = self.prior.lam.copy() - # messages from other adjacent variables - for factor in self.adj_factors: - message_ix = factor.adj_vIDs.index(self.variableID) - self.belief.eta += factor.messages[message_ix].eta - self.belief.lam += factor.messages[message_ix].lam - - def get_prior_energy(self) -> float: - energy = 0.0 - if self.prior.lam[0, 0] != 0.0: - residual = self.belief.mean() - self.prior.mean() - energy += 0.5 * residual @ self.prior.lam @ residual - return energy - - -class Factor: - def __init__( - self, - id: int, - adj_var_nodes: List[VariableNode], - measurement: np.ndarray, - meas_model: MeasModel, - ) -> None: - - self.factorID = id - - self.adj_var_nodes = adj_var_nodes - self.dofs = sum([var.dofs for var in adj_var_nodes]) - self.adj_vIDs = [var.variableID for var in adj_var_nodes] - self.messages = [Gaussian(var.dofs) for var in adj_var_nodes] - - self.factor = Gaussian(self.dofs) - self.linpoint = np.zeros(self.dofs) - - self.measurement = measurement - self.meas_model = meas_model - - # For smarter GBP implementations - self.iters_since_relin = 0 - - self.compute_factor() - - def get_adj_means(self) -> np.ndarray: - adj_belief_means = [var.belief.mean() for var in self.adj_var_nodes] - return np.concatenate(adj_belief_means) - - def get_residual(self, eval_point: np.ndarray = None) -> np.ndarray: - """Compute the residual vector.""" - if eval_point is None: - eval_point = self.get_adj_means() - return self.meas_model.meas_fn(eval_point) - self.measurement - - def get_energy(self, eval_point: np.ndarray = None) -> float: - """Computes the squared error using the appropriate loss function.""" - residual = self.get_residual(eval_point) - inf_mat = np.linalg.inv(self.meas_model.loss.effective_cov) - return 0.5 * residual @ inf_mat @ residual - - def robust(self) -> bool: - return self.meas_model.loss.robust() - - def compute_factor(self) -> None: - """ - Compute the factor at current adjacente beliefs using robust. - If measurement model is linear then factor will always be - the same regardless of linearisation point. - """ - self.linpoint = self.get_adj_means() - J = self.meas_model.jac_fn(self.linpoint) - pred_measurement = self.meas_model.meas_fn(self.linpoint) - self.meas_model.loss.get_effective_cov(pred_measurement - self.measurement) - effective_lam = np.linalg.inv(self.meas_model.loss.effective_cov) - self.factor.lam = J.T @ effective_lam @ J - self.factor.eta = ( - (J.T @ effective_lam) - @ (J @ self.linpoint + self.measurement - pred_measurement) - ).flatten() - self.iters_since_relin = 0 - - def robustify_loss(self) -> None: - """ - Rescale the variance of the noise in the Gaussian - measurement model if necessary and update the factor - correspondingly. 
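# Hedged sketch of the linearisation performed in compute_factor above, for a toy
# nonlinear range measurement h(x) = ||x|| (not a model defined in this file):
# eta = J^T L (J x0 + z - h(x0)) and lam = J^T L J, with L the effective precision.
import numpy as np

def h(x):
    return np.array([np.linalg.norm(x)])

def J_h(x):
    return (x / np.linalg.norm(x))[None, :]

x0 = np.array([1.0, 1.0])          # linearisation point (adjacent belief means)
z = np.array([1.5])                # measurement
L = np.array([[100.0]])            # inverse of the effective measurement covariance
J = J_h(x0)
factor_lam = J.T @ L @ J
factor_eta = (J.T @ L @ (J @ x0 + z - h(x0))).flatten()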
- """ - old_effective_cov = self.meas_model.loss.effective_cov[0, 0] - self.meas_model.loss.get_effective_cov(self.get_residual()) - self.factor.eta *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] - self.factor.lam *= old_effective_cov / self.meas_model.loss.effective_cov[0, 0] - - def compute_messages(self, damping: float = 0.0) -> None: - """Compute all outgoing messages from the factor.""" - messages_eta, messages_lam = [], [] - - sdim = 0 - for v in range(len(self.adj_vIDs)): - eta_factor = self.factor.eta.copy() - lam_factor = self.factor.lam.copy() - - # Take product of factor with incoming messages - start = 0 - for var in range(len(self.adj_vIDs)): - if var != v: - var_dofs = self.adj_var_nodes[var].dofs - eta_mess = ( - self.adj_var_nodes[var].belief.eta - self.messages[var].eta - ) - lam_mess = ( - self.adj_var_nodes[var].belief.lam - self.messages[var].lam - ) - eta_factor[start : start + var_dofs] += eta_mess - lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess - start += self.adj_var_nodes[var].dofs - - # Divide up parameters of distribution - dofs = self.adj_var_nodes[v].dofs - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - axis=1, - ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno - messages_eta.append( - (1 - damping) * new_message_eta + damping * self.messages[v].eta - ) - messages_lam.append( - (1 - damping) * new_message_lam + damping * self.messages[v].lam - ) - sdim += self.adj_var_nodes[v].dofs - - for v in range(len(self.adj_vIDs)): - self.messages[v].lam = messages_lam[v] - self.messages[v].eta = messages_eta[v] - - -""" -Visualisation function -""" - - -def draw(i): - fig, ax = plt.subplots(figsize=(7, 6)) - fig.set_tight_layout(True) - plt.title(i) - - # plot beliefs - means = fg.belief_means().reshape([size * size, 2]) - plt.scatter(means[:, 0], means[:, 1], color="blue") - for j, cov in enumerate(fg.belief_covs()): - circle = plt.Circle( - (means[j, 0], means[j, 1]), - np.sqrt(cov[0, 0]), - linewidth=0.5, - color="blue", - fill=False, - ) - ax.add_patch(circle) - - # plot true marginals - plt.scatter(map_soln[:, 0], map_soln[:, 1], color="g") - for j, cov in enumerate(marg_covs): - circle = plt.Circle( - (map_soln[j, 0], map_soln[j, 1]), - np.sqrt(marg_covs[j]), - linewidth=0.5, - color="g", - fill=False, - ) - ax.add_patch(circle) - - # draw lines for factors - for f in fg.factors: - bels = np.array([means[f.adj_vIDs[0]], means[f.adj_vIDs[1]]]) - plt.plot(bels[:, 0], bels[:, 1], color="black", linewidth=0.3) - - # draw lines for belief error - for i in range(len(means)): - xs = [means[i, 0], map_soln[i, 0]] - ys = [means[i, 1], map_soln[i, 1]] - plt.plot(xs, ys, color="grey", linewidth=0.3, linestyle="dashed") - - plt.axis("scaled") - plt.xlim([-1, size]) - plt.ylim([-1, size]) - - # convert to image - 
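# The core of compute_messages above is a marginalisation of the joint factor
# (plus incoming messages) onto the receiving variable via a Schur complement;
# a toy standalone version with two 2-dof variables:
import numpy as np

eta = np.array([1.0, 0.5, -0.2, 0.3])
lam = np.array([
    [4.0, 0.0, -1.0, 0.0],
    [0.0, 4.0, 0.0, -1.0],
    [-1.0, 0.0, 3.0, 0.0],
    [0.0, -1.0, 0.0, 3.0],
])
d = 2                                    # message to the first variable
eo, eno = eta[:d], eta[d:]
loo, lono = lam[:d, :d], lam[:d, d:]
lnoo, lnono = lam[d:, :d], lam[d:, d:]
msg_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo
msg_eta = eo - lono @ np.linalg.inv(lnono) @ eno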
ax.axis("off") - fig.tight_layout(pad=0) - ax.margins(0) - fig.canvas.draw() - img = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) - img = img.reshape(fig.canvas.get_width_height()[::-1] + (3,)) - plt.close() - - return img - - -if __name__ == "__main__": - - np.random.seed(1) - - size = 3 - dim = 2 - - prior_noise_std = 0.2 - - gbp_settings = GBPSettings( - damping=0.0, - beta=0.1, - num_undamped_iters=1, - min_linear_iters=10, - dropout=0.0, - ) - - # GBP library soln ------------------------------------------ - - noise_cov = np.array([0.01, 0.01]) - - prior_sigma = np.array([1.3**2, 1.3**2]) - prior_noise_std = 0.2 - - fg = FactorGraph(gbp_settings) - - init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) - meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) - - for i in range(size): - for j in range(size): - init = np.array([j, i]) - noise_init = init_noises[j + i * size] - init = init + noise_init - sigma = prior_sigma - if i == 0 and j == 0: - init = np.array([j, i]) - sigma = np.array([0.0001, 0.0001]) - print(init, sigma) - fg.add_var_node(2, init, sigma) - - m = 0 - for i in range(size): - for j in range(size): - if j < size - 1: - meas = np.array([1.0, 0.0]) - meas += meas_noises[m] - fg.add_factor( - [i * size + j, i * size + j + 1], - meas, - LinearDisplacementModel(SquaredLoss(dim, noise_cov)), - ) - m += 1 - if i < size - 1: - meas = np.array([0.0, 1.0]) - meas += meas_noises[m] - fg.add_factor( - [i * size + j, (i + 1) * size + j], - meas, - LinearDisplacementModel(SquaredLoss(dim, noise_cov)), - ) - m += 1 - - fg.print(brief=True) - - # # for vis --------------- - - joint = fg.get_joint() - marg_covs = np.diag(joint.cov())[::2] - map_soln = fg.MAP().reshape([size * size, 2]) - - # # run gbp --------------- - - gbp_settings = GBPSettings( - damping=0.0, - beta=0.1, - num_undamped_iters=1, - min_linear_iters=10, - dropout=0.0, - ) - - # fg.compute_all_messages() - - # i = 0 - n_iters = 20 - while i <= n_iters: - # img = draw(i) - # cv2.imshow('img', img) - # cv2.waitKey(1) - - print(f"Iter {i} --- Energy {fg.energy():.5f}") - - # fg.random_message() - fg.synchronous_iteration() - i += 1 - - # for f in fg.factors: - # for m in f.messages: - # print(np.linalg.inv(m.lam) @ m.eta) - - print(fg.belief_means()) diff --git a/theseus/optimizer/gbp/gbp_euclidean.py b/theseus/optimizer/gbp/gbp_euclidean.py deleted file mode 100644 index 8505fbb7d..000000000 --- a/theseus/optimizer/gbp/gbp_euclidean.py +++ /dev/null @@ -1,1047 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. 
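# A small driver in the spirit of the manual loop above: iterate synchronous GBP
# until the energy plateaus, then compare the beliefs against the exact MAP
# (for a linear-Gaussian graph like this grid the two should agree closely).
import numpy as np

def run_until_converged(fg, max_iters=50, tol=1e-6):
    prev = fg.energy()
    for _ in range(max_iters):
        fg.synchronous_iteration()
        cur = fg.energy()
        if abs(prev - cur) < tol:
            break
        prev = cur
    return np.linalg.norm(fg.belief_means() - fg.MAP())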
- -import abc -import math -from dataclasses import dataclass -from itertools import count -from typing import Dict, List, Optional, Sequence, Tuple - -import numpy as np -import torch - -import theseus as th -import theseus.constants -from theseus.core import CostFunction, Objective -from theseus.geometry import Manifold -from theseus.optimizer import Optimizer, VariableOrdering -from theseus.optimizer.nonlinear.nonlinear_optimizer import ( - BackwardMode, - NonlinearOptimizerInfo, - NonlinearOptimizerStatus, -) - -""" -TODO - - add class for message schedule - - damping for lie algebra vars - - solving inverse problem to compute message mean -""" - - -""" -Utitily functions -""" - - -@dataclass -class GBPOptimizerParams: - abs_err_tolerance: float - rel_err_tolerance: float - max_iterations: int - - def update(self, params_dict): - for param, value in params_dict.items(): - if hasattr(self, param): - setattr(self, param, value) - else: - raise ValueError(f"Invalid nonlinear optimizer parameter {param}.") - - -class ManifoldGaussian: - _ids = count(0) - - def __init__( - self, - mean: Sequence[Manifold], - precision: Optional[torch.Tensor] = None, - name: Optional[str] = None, - ): - self._id = next(ManifoldGaussian._ids) - if name: - self.name = name - else: - self.name = f"{self.__class__.__name__}__{self._id}" - - dof = 0 - for v in mean: - dof += v.dof() - self._dof = dof - - self.mean = mean - self.precision = torch.zeros(mean[0].shape[0], self.dof, self.dof).to( - dtype=mean[0].dtype, device=mean[0].device - ) - - @property - def dof(self) -> int: - return self._dof - - @property - def device(self) -> torch.device: - return self.precision[0].device - - @property - def dtype(self) -> torch.dtype: - return self.precision[0].dtype - - # calls to() on the internal tensors - def to(self, *args, **kwargs): - for var in self.mean: - var = var.to(*args, **kwargs) - self.precision = self.precision.to(*args, **kwargs) - - def copy(self, new_name: Optional[str] = None) -> "ManifoldGaussian": - if not new_name: - new_name = f"{self.name}_copy" - mean_copy = [var.copy() for var in self.mean] - return ManifoldGaussian(mean_copy, name=new_name) - - def __deepcopy__(self, memo): - if id(self) in memo: - return memo[id(self)] - the_copy = self.copy() - memo[id(self)] = the_copy - return the_copy - - def update( - self, - mean: Optional[Sequence[Manifold]] = None, - precision: Optional[torch.Tensor] = None, - ): - if mean is not None: - if len(mean) != len(self.mean): - raise ValueError( - f"Tried to update mean with sequence of different" - f"lenght to original mean sequence. Given {len(mean)}. " - f"Expected: {len(self.mean)}" - ) - for i in range(len(self.mean)): - self.mean[i].update(mean[i]) - - if precision is not None: - if precision.shape != self.precision.shape: - raise ValueError( - f"Tried to update precision with data " - f"incompatible with original tensor shape. Given {precision.shape}. " - f"Expected: {self.precision.shape}" - ) - if precision.dtype != self.dtype: - raise ValueError( - f"Tried to update using tensor of dtype {precision.dtype} but precision " - f"has dtype {self.dtype}." 
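# Hedged sketch of constructing a ManifoldGaussian as defined above, using a
# theseus SE2 variable as the mean; the precision lives in the tangent space at
# that mean (dof(SE2) = 3, batch size 1).
import torch
import theseus as th

pose = th.SE2(x_y_theta=torch.tensor([[0.0, 0.0, 0.1]]))
belief = ManifoldGaussian([pose], name="belief_x0")
belief.update(precision=torch.eye(3).unsqueeze(0))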
- ) - - self.precision = precision - - -class Marginal(ManifoldGaussian): - pass - - -class Message(ManifoldGaussian): - pass - - -class CostFunctionOrdering: - def __init__(self, objective: Objective, default_order: bool = True): - self.objective = objective - self._cf_order: List[CostFunction] = [] - self._cf_name_to_index: Dict[str, int] = {} - if default_order: - self._compute_default_order(objective) - - def _compute_default_order(self, objective: Objective): - assert not self._cf_order and not self._cf_name_to_index - cur_idx = 0 - for cf_name, cf in objective.cost_functions.items(): - if cf_name in self._cf_name_to_index: - continue - self._cf_order.append(cf) - self._cf_name_to_index[cf_name] = cur_idx - cur_idx += 1 - - def index_of(self, key: str) -> int: - return self._cf_name_to_index[key] - - def __getitem__(self, index) -> CostFunction: - return self._cf_order[index] - - def __iter__(self): - return iter(self._cf_order) - - def append(self, cf: CostFunction): - if cf in self._cf_order: - raise ValueError( - f"Cost Function {cf.name} has already been added to the order." - ) - if cf.name not in self.objective.cost_functions: - raise ValueError( - f"Cost Function {cf.name} is not a cost function for the objective." - ) - self._cf_order.append(cf) - self._cf_name_to_index[cf.name] = len(self._cf_order) - 1 - - def remove(self, cf: CostFunction): - self._cf_order.remove(cf) - del self._cf_name_to_index[cf.name] - - def extend(self, cfs: Sequence[CostFunction]): - for cf in cfs: - self.append(cf) - - @property - def complete(self): - return len(self._cf_order) == self.objective.size_variables() - - -""" -GBP functions -""" - - -# Compute the factor at current adjacent beliefs. -def compute_factor(cf, lie=True): - J, error = cf.weighted_jacobians_error() - J_stk = torch.cat(J, dim=-1) - - lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) - - optim_vars_stk = torch.cat([v.data for v in cf.optim_vars], dim=-1) - eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) - if lie is False: - eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) - eta = eta.squeeze(-1) - - return eta, lam - - -def pass_var_to_fac_messages( - ftov_msgs_eta, - ftov_msgs_lam, - var_ix_for_edges, - n_vars, - max_dofs, -): - belief_eta = torch.zeros(n_vars, max_dofs, dtype=ftov_msgs_eta.dtype) - belief_lam = torch.zeros(n_vars, max_dofs, max_dofs, dtype=ftov_msgs_eta.dtype) - - belief_eta = belief_eta.index_add(0, var_ix_for_edges, ftov_msgs_eta) - belief_lam = belief_lam.index_add(0, var_ix_for_edges, ftov_msgs_lam) - - vtof_msgs_eta = belief_eta[var_ix_for_edges] - ftov_msgs_eta - vtof_msgs_lam = belief_lam[var_ix_for_edges] - ftov_msgs_lam - - return vtof_msgs_eta, vtof_msgs_lam, belief_eta, belief_lam - - -def pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - adj_var_dofs_nested: List[List], -): - ftov_msgs_eta = torch.zeros_like(vtof_msgs_eta) - ftov_msgs_lam = torch.zeros_like(vtof_msgs_lam) - - start = 0 - for i in range(len(adj_var_dofs_nested)): - adj_var_dofs = adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - ftov_eta, ftov_lam = ftov_comp_mess( - adj_var_dofs, - potentials_eta[i], - potentials_lam[i], - vtof_msgs_eta[start : start + num_optim_vars], - vtof_msgs_lam[start : start + num_optim_vars], - ) - - ftov_msgs_eta[start : start + num_optim_vars] = torch.cat(ftov_eta) - ftov_msgs_lam[start : start + num_optim_vars] = torch.cat(ftov_lam) - - start += num_optim_vars - - return ftov_msgs_eta, ftov_msgs_lam - - -# 
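# Toy illustration of the index_add aggregation used in pass_var_to_fac_messages
# above: sum per-edge factor-to-variable messages into per-variable beliefs, then
# subtract each edge's own contribution to form the outgoing messages.
import torch

ftov_eta = torch.tensor([[1.0, 0.0], [0.5, 0.5], [0.0, 2.0]])  # 3 edges, max_dofs = 2
var_ix = torch.tensor([0, 0, 1])                               # edges 0, 1 -> var 0; edge 2 -> var 1
belief_eta = torch.zeros(2, 2).index_add(0, var_ix, ftov_eta)
vtof_eta = belief_eta[var_ix] - ftov_eta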
Transforms message gaussian to tangent plane at var -# if return_mean is True, return the (mean, lam) else return (eta, lam). -# Generalises the local function by transforming the covariance as well as mean. -def local_gaussian( - mess: Message, - var: th.LieGroup, - return_mean: bool = True, -) -> Tuple[torch.Tensor, torch.Tensor]: - # mean_tp is message mean in tangent space / plane at var - mean_tp = var.local(mess.mean[0]) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - - # lam_tp is the precision matrix in the tangent plane - lam_tp = torch.bmm(torch.bmm(jac[0].transpose(-1, -2), mess.precision), jac[0]) - - if return_mean: - return mean_tp, lam_tp - - else: - eta_tp = torch.matmul(lam_tp, mean_tp.unsqueeze(-1)).squeeze(-1) - return eta_tp, lam_tp - - -# Transforms Gaussian in the tangent plane at var to Gaussian where the mean -# is a group element and the precision matrix is defined in the tangent plane -# at the mean. -# Generalises the retract function by transforming the covariance as well as mean. -# out_gauss is the transformed Gaussian that is updated in place. -def retract_gaussian( - mean_tp: torch.Tensor, - precision_tp: torch.Tensor, - var: th.LieGroup, - out_gauss: ManifoldGaussian, -): - mean = var.retract(mean_tp) - - jac: List[torch.Tensor] = [] - th.exp_map(var, mean_tp, jacobians=jac) - inv_jac = torch.inverse(jac[0]) - precision = torch.bmm(torch.bmm(inv_jac.transpose(-1, -2), precision_tp), inv_jac) - - out_gauss.update(mean=[mean], precision=precision) - - -# Compute all outgoing messages from the factor. -def ftov_comp_mess( - adj_var_dofs, - potential_eta, - potential_lam, - vtof_msgs_eta, - vtof_msgs_lam, -): - num_optim_vars = len(adj_var_dofs) - messages_eta, messages_lam = [], [] - - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potential_eta.clone()[0] - lam_factor = potential_lam.clone()[0] - - # Take product of factor with incoming messages - start = 0 - for var in range(num_optim_vars): - var_dofs = adj_var_dofs[var] - if var != v: - eta_mess = vtof_msgs_eta[var] - lam_mess = vtof_msgs_lam[var] - eta_factor[start : start + var_dofs] += eta_mess - lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess - start += var_dofs - - # Divide up parameters of distribution - dofs = adj_var_dofs[v] - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), axis=1 - ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno - - messages_eta.append(new_message_eta[None, :]) - messages_lam.append(new_message_lam[None, :]) - - sdim += dofs - - return messages_eta, messages_lam - - -# Follows notation from https://arxiv.org/pdf/2202.03314.pdf - - -class GaussianBeliefPropagation(Optimizer, abc.ABC): - def __init__( - self, - objective: Objective, - abs_err_tolerance: float = 1e-10, - rel_err_tolerance: float = 
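# For the precision, the tangent-space transport in local_gaussian/retract_gaussian
# above reduces to conjugation by the exp-map Jacobian: lam_tp = J^T lam J on the
# way in, and the inverse-Jacobian conjugation on the way out.
import torch

lam = 4.0 * torch.eye(3).unsqueeze(0)            # message precision at the message mean
J = torch.tensor([[[1.0, 0.0, 0.0],
                   [0.0, 1.0, -0.2],
                   [0.0, 0.0, 1.0]]])            # toy Jacobian of the exp map at the base point
lam_tp = torch.bmm(torch.bmm(J.transpose(-1, -2), lam), J)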
1e-8, - max_iterations: int = 20, - ): - super().__init__(objective) - - # ordering is required to identify which messages to send where - self.ordering = VariableOrdering(objective, default_order=True) - self.cf_ordering = CostFunctionOrdering(objective) - - self.schedule = None - - self.params = GBPOptimizerParams( - abs_err_tolerance, rel_err_tolerance, max_iterations - ) - - self.n_edges = sum([cf.num_optim_vars() for cf in self.cf_ordering]) - self.max_dofs = max([var.dof() for var in self.ordering]) - - # create arrays for indexing the messages - var_ixs_nested = [ - [self.ordering.index_of(var.name) for var in cf.optim_vars] - for cf in self.cf_ordering - ] - var_ixs = [item for sublist in var_ixs_nested for item in sublist] - self.var_ix_for_edges = torch.tensor(var_ixs).long() - - self.adj_var_dofs_nested = [ - [var.shape[1] for var in cf.optim_vars] for cf in self.cf_ordering - ] - - lie_groups = False - for v in self.ordering: - if isinstance(v, th.LieGroup) and not isinstance(v, th.Vector): - lie_groups = True - self.lie_groups = lie_groups - print("lie groups:", self.lie_groups) - - """ - Copied and slightly modified from nonlinear optimizer class - """ - - def set_params(self, **kwargs): - self.params.update(kwargs) - - def _check_convergence(self, err: torch.Tensor, last_err: torch.Tensor): - assert not torch.is_grad_enabled() - if err.abs().mean() < theseus.constants.EPS: - return torch.ones_like(err).bool() - - abs_error = (last_err - err).abs() - rel_error = abs_error / last_err - return (abs_error < self.params.abs_err_tolerance).logical_or( - rel_error < self.params.rel_err_tolerance - ) - - def _maybe_init_best_solution( - self, do_init: bool = False - ) -> Optional[Dict[str, torch.Tensor]]: - if not do_init: - return None - solution_dict = {} - for var in self.ordering: - solution_dict[var.name] = var.data.detach().clone().cpu() - return solution_dict - - def _init_info( - self, track_best_solution: bool, track_err_history: bool, verbose: bool - ) -> NonlinearOptimizerInfo: - with torch.no_grad(): - last_err = self.objective.error_squared_norm() / 2 - best_err = last_err.clone() if track_best_solution else None - if track_err_history: - err_history = ( - torch.ones(self.objective.batch_size, self.params.max_iterations + 1) - * math.inf - ) - assert last_err.grad_fn is None - err_history[:, 0] = last_err.clone().cpu() - else: - err_history = None - return NonlinearOptimizerInfo( - best_solution=self._maybe_init_best_solution(do_init=track_best_solution), - last_err=last_err, - best_err=best_err, - status=np.array( - [NonlinearOptimizerStatus.START] * self.objective.batch_size - ), - converged_iter=torch.zeros_like(last_err, dtype=torch.long), - best_iter=torch.zeros_like(last_err, dtype=torch.long), - err_history=err_history, - ) - - def _update_info( - self, - info: NonlinearOptimizerInfo, - current_iter: int, - err: torch.Tensor, - converged_indices: torch.Tensor, - ): - info.converged_iter += 1 - converged_indices.long() - if info.err_history is not None: - assert err.grad_fn is None - info.err_history[:, current_iter + 1] = err.clone().cpu() - - if info.best_solution is not None: - # Only copy best solution if needed (None means track_best_solution=False) - assert info.best_err is not None - good_indices = err < info.best_err - info.best_iter[good_indices] = current_iter - for var in self.ordering: - info.best_solution[var.name][good_indices] = ( - var.data.detach().clone()[good_indices].cpu() - ) - - info.best_err = torch.minimum(info.best_err, err) - - 
converged_indices = self._check_convergence(err, info.last_err) - info.status[ - np.array(converged_indices.detach().cpu()) - ] = NonlinearOptimizerStatus.CONVERGED - - # Modifies the (no grad) info in place to add data of grad loop info - def _merge_infos( - self, - grad_loop_info: NonlinearOptimizerInfo, - num_no_grad_iter: int, - backward_num_iterations: int, - info: NonlinearOptimizerInfo, - ): - # Concatenate error histories - if info.err_history is not None: - info.err_history[:, num_no_grad_iter:] = grad_loop_info.err_history[ - :, : backward_num_iterations + 1 - ] - # Merge best solution and best error - if info.best_solution is not None: - best_solution = {} - best_err_no_grad = info.best_err - best_err_grad = grad_loop_info.best_err - idx_no_grad = best_err_no_grad < best_err_grad - best_err = torch.minimum(best_err_no_grad, best_err_grad) - for var_name in info.best_solution: - sol_no_grad = info.best_solution[var_name] - sol_grad = grad_loop_info.best_solution[var_name] - best_solution[var_name] = torch.where( - idx_no_grad, sol_no_grad, sol_grad - ) - info.best_solution = best_solution - info.best_err = best_err - - # Merge the converged status into the info from the detached loop, - M = info.status == NonlinearOptimizerStatus.MAX_ITERATIONS - assert np.all( - (grad_loop_info.status[M] == NonlinearOptimizerStatus.MAX_ITERATIONS) - | (grad_loop_info.status[M] == NonlinearOptimizerStatus.CONVERGED) - ) - info.status[M] = grad_loop_info.status[M] - info.converged_iter[M] = ( - info.converged_iter[M] + grad_loop_info.converged_iter[M] - ) - # If didn't coverge in either loop, remove misleading converged_iter value - info.converged_iter[ - M & (grad_loop_info.status == NonlinearOptimizerStatus.MAX_ITERATIONS) - ] = -1 - - """ - GBP specific functions - """ - - # Linearizes factors at current belief if beliefs have deviated - # from the linearization point by more than the threshold. 
- def _linearize( - self, - potentials_eta, - potentials_lam, - lin_points, - relin_threshold: float = None, - lie=False, - ): - do_lins = [] - for i, cf in enumerate(self.cf_ordering): - - do_lin = False - if relin_threshold is None: - do_lin = True - else: - lp_dists = [ - lp.local(cf.optim_var_at(j)).norm() - for j, lp in enumerate(lin_points[i]) - ] - do_lin = np.max(lp_dists) > relin_threshold - - do_lins.append(do_lin) - - if do_lin: - potential_eta, potential_lam = compute_factor(cf, lie=lie) - - potentials_eta[i] = potential_eta - potentials_lam[i] = potential_lam - - for j, var in enumerate(cf.optim_vars): - lin_points[i][j].update(var.data) - - # print(f"Linearised {np.sum(do_lins)} out of {len(do_lins)} factors.") - return potentials_eta, potentials_lam, lin_points - - def _pass_var_to_fac_messages( - self, - ftov_msgs, - vtof_msgs, - update_belief=True, - ): - for i, var in enumerate(self.ordering): - - # Collect all incoming messages in the tangent space at the current belief - taus = [] # message means - lams_tp = [] # message lams - for j, msg in enumerate(ftov_msgs): - if self.var_ix_for_edges[j] == i: - tau, lam_tp = local_gaussian(msg, var, return_mean=True) - taus.append(tau[None, ...]) - lams_tp.append(lam_tp[None, ...]) - - taus = torch.cat(taus) - lams_tp = torch.cat(lams_tp) - - lam_tau = lams_tp.sum(dim=0) - - # Compute outgoing messages - ix = 0 - for j, msg in enumerate(ftov_msgs): - if self.var_ix_for_edges[j] == i: - taus_inc = torch.cat((taus[:ix], taus[ix + 1 :])) - lams_inc = torch.cat((lams_tp[:ix], lams_tp[ix + 1 :])) - - lam_a = lams_inc.sum(dim=0) - if lam_a.count_nonzero() == 0: - vtof_msgs[j].mean[0].data[:] = 0.0 - vtof_msgs[j].precision = lam_a - else: - inv_lam_a = torch.inverse(lam_a) - sum_taus = torch.matmul(lams_inc, taus_inc.unsqueeze(-1)).sum( - dim=0 - ) - tau_a = torch.matmul(inv_lam_a, sum_taus).squeeze(-1) - retract_gaussian(tau_a, lam_a, var, vtof_msgs[j]) - ix += 1 - - # update belief mean and variance - # if no incoming messages then leave current belief unchanged - if update_belief and lam_tau.count_nonzero() != 0: - inv_lam_tau = torch.inverse(lam_tau) - sum_taus = torch.matmul(lams_tp, taus.unsqueeze(-1)).sum(dim=0) - tau = torch.matmul(inv_lam_tau, sum_taus).squeeze(-1) - - retract_gaussian(tau, lam_tau, var, self.beliefs[i]) - - def _pass_fac_to_var_messages( - self, - potentials_eta, - potentials_lam, - lin_points, - vtof_msgs, - ftov_msgs, - damping: torch.Tensor, - ): - start = 0 - for i in range(len(self.adj_var_dofs_nested)): - adj_var_dofs = self.adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - self._ftov_comp_mess( - potentials_eta[i], - potentials_lam[i], - lin_points[i], - vtof_msgs[start : start + num_optim_vars], - ftov_msgs[start : start + num_optim_vars], - damping[start : start + num_optim_vars], - ) - - start += num_optim_vars - - # Compute all outgoing messages from the factor. - def _ftov_comp_mess( - self, - potential_eta, - potential_lam, - lin_points, - vtof_msgs, - ftov_msgs, - damping, - ): - num_optim_vars = len(lin_points) - new_messages = [] - - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potential_eta.clone()[0] - lam_factor = potential_lam.clone()[0] - - # Take product of factor with incoming messages. - # Convert mesages to tangent space at linearisation point. 
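# Precision-weighted fusion of incoming messages in the tangent space, as done in
# _pass_var_to_fac_messages above before retracting back onto the manifold:
# lam = sum_i lam_i and tau = lam^-1 sum_i lam_i tau_i.
import torch

taus = torch.tensor([[[0.1, 0.0, 0.0]], [[0.0, 0.2, 0.0]]])    # two messages, batch 1, dof 3
lams = torch.stack([2.0 * torch.eye(3).unsqueeze(0), torch.eye(3).unsqueeze(0)])
lam_sum = lams.sum(dim=0)
tau_fused = torch.linalg.solve(
    lam_sum, torch.matmul(lams, taus.unsqueeze(-1)).sum(dim=0)
).squeeze(-1)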
- start = 0 - for i in range(num_optim_vars): - var_dofs = lin_points[i].dof() - if i != v: - eta_mess, lam_mess = local_gaussian( - vtof_msgs[i], lin_points[i], return_mean=False - ) - eta_factor[start : start + var_dofs] += eta_mess[0] - lam_factor[ - start : start + var_dofs, start : start + var_dofs - ] += lam_mess[0] - start += var_dofs - - # Divide up parameters of distribution - dofs = lin_points[v].dof() - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - axis=1, - ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_mess_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_mess_eta = eo - lono @ np.linalg.inv(lnono) @ eno - - # damping in tangent space at linearisation point - # prev_mess_eta, prev_mess_lam = local_gaussian( - # vtof_msgs[v], lin_points[v], return_mean=False) - # new_mess_eta = (1 - damping[v]) * new_mess_eta + damping[v] * prev_mess_eta[0] - # new_mess_lam = (1 - damping[v]) * new_mess_lam + damping[v] * prev_mess_lam[0] - - if new_mess_lam.count_nonzero() == 0: - new_mess = ManifoldGaussian([lin_points[v].copy()]) - new_mess.mean[0].data[:] = 0.0 - else: - new_mess_mean = torch.matmul(torch.inverse(new_mess_lam), new_mess_eta) - new_mess_mean = new_mess_mean[None, ...] - new_mess_lam = new_mess_lam[None, ...] 
- - new_mess = ManifoldGaussian([lin_points[v].copy()]) - retract_gaussian(new_mess_mean, new_mess_lam, lin_points[v], new_mess) - new_messages.append(new_mess) - - sdim += dofs - - # update messages - for v in range(num_optim_vars): - ftov_msgs[v].update( - mean=new_messages[v].mean, precision=new_messages[v].precision - ) - - return new_messages - - """ - Optimization loop functions - """ - - # loop for the iterative optimizer - def _optimize_loop( - self, - start_iter: int, - num_iter: int, - info: NonlinearOptimizerInfo, - verbose: bool, - truncated_grad_loop: bool, - relin_threshold: float = 0.1, - damping: float = 0.0, - dropout: float = 0.0, - **kwargs, - ): - # initialise messages with zeros - vtof_msgs_eta = torch.zeros( - self.n_edges, self.max_dofs, dtype=self.objective.dtype - ) - vtof_msgs_lam = torch.zeros( - self.n_edges, self.max_dofs, self.max_dofs, dtype=self.objective.dtype - ) - ftov_msgs_eta = vtof_msgs_eta.clone() - ftov_msgs_lam = vtof_msgs_lam.clone() - - # compute factor potentials for the first time - potentials_eta = [None] * self.objective.size_cost_functions() - potentials_lam = [None] * self.objective.size_cost_functions() - lin_points = [ - [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering - ] - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, relin_threshold=None - ) - - converged_indices = torch.zeros_like(info.last_err).bool() - for it_ in range(start_iter, start_iter + num_iter): - - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, relin_threshold=None - ) - - msgs_eta, msgs_lam = pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - self.adj_var_dofs_nested, - ) - - # damping - # damping = self.gbp_settings.get_damping(iters_since_relin) - damping_arr = torch.full([len(msgs_eta)], damping) - - # dropout can be implemented through damping - if dropout != 0.0: - dropout_ixs = torch.rand(len(msgs_eta)) < dropout - damping_arr[dropout_ixs] = 1.0 - - ftov_msgs_eta = (1 - damping_arr[:, None]) * msgs_eta + damping_arr[ - :, None - ] * ftov_msgs_eta - ftov_msgs_lam = (1 - damping_arr[:, None, None]) * msgs_lam + damping_arr[ - :, None, None - ] * ftov_msgs_lam - - ( - vtof_msgs_eta, - vtof_msgs_lam, - belief_eta, - belief_lam, - ) = pass_var_to_fac_messages( - ftov_msgs_eta, - ftov_msgs_lam, - self.var_ix_for_edges, - len(self.ordering._var_order), - self.max_dofs, - ) - - # update beliefs - belief_cov = torch.inverse(belief_lam) - belief_mean = torch.matmul(belief_cov, belief_eta.unsqueeze(-1)).squeeze() - for i, var in enumerate(self.ordering): - var.update(data=belief_mean[i][None, :]) - - # check for convergence - with torch.no_grad(): - err = self.objective.error_squared_norm() / 2 - self._update_info(info, it_, err, converged_indices) - if verbose: - print(f"GBP. Iteration: {it_+1}. 
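# Dropout implemented through damping, as in the optimisation loops in this file:
# edges selected for dropout keep their previous message unchanged by setting their
# damping factor to 1.0.
import torch

n_edges, dropout, damping = 5, 0.3, 0.2
damping_arr = torch.full([n_edges], damping)
damping_arr[torch.rand(n_edges) < dropout] = 1.0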
" f"Error: {err.mean().item()}") - converged_indices = self._check_convergence(err, info.last_err) - info.status[ - converged_indices.cpu().numpy() - ] = NonlinearOptimizerStatus.CONVERGED - if converged_indices.all(): - break # nothing else will happen at this point - info.last_err = err - - info.status[ - info.status == NonlinearOptimizerStatus.START - ] = NonlinearOptimizerStatus.MAX_ITERATIONS - return info - - # loop for the iterative optimizer - def _optimize_loop_lie( - self, - start_iter: int, - num_iter: int, - info: NonlinearOptimizerInfo, - verbose: bool, - truncated_grad_loop: bool, - relin_threshold: float = 0.1, - damping: float = 0.0, - dropout: float = 0.0, - **kwargs, - ): - # initialise messages with zeros - vtof_msgs: List[Message] = [] - ftov_msgs: List[Message] = [] - for cf in self.cf_ordering: - for var in cf.optim_vars: - vtof_msg_mu = var.copy(new_name=f"msg_{var.name}_to_{cf.name}") - # mean of initial message doesn't matter as long as precision is zero - vtof_msg_mu.data[:] = 0 - ftov_msg_mu = vtof_msg_mu.copy(new_name=f"msg_{cf.name}_to_{var.name}") - vtof_msgs.append(Message([vtof_msg_mu])) - ftov_msgs.append(Message([ftov_msg_mu])) - - # initialise gaussian for belief - self.beliefs: List[Marginal] = [] - for var in self.ordering: - self.beliefs.append(Marginal([var])) - - # compute factor potentials for the first time - potentials_eta = [None] * self.objective.size_cost_functions() - potentials_lam = [None] * self.objective.size_cost_functions() - lin_points = [ - [var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars] - for cf in self.cf_ordering - ] - potentials_eta, potentials_lam, lin_points = self._linearize( - potentials_eta, potentials_lam, lin_points, relin_threshold=None, lie=True - ) - - converged_indices = torch.zeros_like(info.last_err).bool() - for it_ in range(start_iter, start_iter + num_iter): - - potentials_eta, potentials_lam, self.lin_points = self._linearize( - potentials_eta, - potentials_lam, - lin_points, - relin_threshold=None, - lie=True, - ) - - # damping - # damping = self.gbp_settings.get_damping(iters_since_relin) - damping_arr = torch.full([self.n_edges], damping) - - # dropout can be implemented through damping - if dropout != 0.0: - dropout_ixs = torch.rand(self.n_edges) < dropout - damping_arr[dropout_ixs] = 1.0 - - self._pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - lin_points, - vtof_msgs, - ftov_msgs, - damping_arr, - ) - - self._pass_var_to_fac_messages( - ftov_msgs, - vtof_msgs, - update_belief=True, - ) - - # check for convergence - if it_ > 0: - with torch.no_grad(): - err = self.objective.error_squared_norm() / 2 - self._update_info(info, it_, err, converged_indices) - if verbose: - print( - f"GBP. Iteration: {it_+1}. 
" f"Error: {err.mean().item()}" - ) - converged_indices = self._check_convergence(err, info.last_err) - info.status[ - converged_indices.cpu().numpy() - ] = NonlinearOptimizerStatus.CONVERGED - if converged_indices.all() and it_ > 1: - break # nothing else will happen at this point - info.last_err = err - - info.status[ - info.status == NonlinearOptimizerStatus.START - ] = NonlinearOptimizerStatus.MAX_ITERATIONS - return info - - # `track_best_solution` keeps a **detached** copy (as in no gradient info) - # of the best variables found, but it is optional to avoid unnecessary copying - # if this is not needed - def _optimize_impl( - self, - track_best_solution: bool = False, - track_err_history: bool = False, - verbose: bool = False, - backward_mode: BackwardMode = BackwardMode.FULL, - damping: float = 0.0, - dropout: float = 0.0, - **kwargs, - ) -> NonlinearOptimizerInfo: - if damping > 1.0 or damping < 0.0: - raise NotImplementedError("Damping must be in between 0 and 1.") - if dropout > 1.0 or dropout < 0.0: - raise NotImplementedError("Dropout probability must be in between 0 and 1.") - - with torch.no_grad(): - info = self._init_info(track_best_solution, track_err_history, verbose) - - if verbose: - print( - f"GBP optimizer. Iteration: 0. " f"Error: {info.last_err.mean().item()}" - ) - - grad = False - if backward_mode == BackwardMode.FULL: - grad = True - - with torch.set_grad_enabled(grad): - - # if self.lie_groups: - info = self._optimize_loop_lie( - start_iter=0, - num_iter=self.params.max_iterations, - info=info, - verbose=verbose, - truncated_grad_loop=False, - damping=damping, - dropout=dropout, - **kwargs, - ) - # else: - # info = self._optimize_loop( - # start_iter=0, - # num_iter=self.params.max_iterations, - # info=info, - # verbose=verbose, - # truncated_grad_loop=False, - # damping=damping, - # dropout=dropout, - # **kwargs, - # ) - # If didn't coverge, remove misleading converged_iter value - info.converged_iter[ - info.status == NonlinearOptimizerStatus.MAX_ITERATIONS - ] = -1 - return info diff --git a/theseus/optimizer/gbp/jax_torch_poc.py b/theseus/optimizer/gbp/jax_torch_poc.py deleted file mode 100644 index fb4fe3e97..000000000 --- a/theseus/optimizer/gbp/jax_torch_poc.py +++ /dev/null @@ -1,488 +0,0 @@ -import time - -import jax -import jax.numpy as jnp -import numpy as np -import torch - - -def pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - adj_var_dofs_nested, -): - ftov_msgs_eta = [None] * len(vtof_msgs_eta) - ftov_msgs_lam = [None] * len(vtof_msgs_eta) - - start = 0 - for i in range(len(adj_var_dofs_nested)): - adj_var_dofs = adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - ftov_eta, ftov_lam = [], [] - - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potentials_eta[i].clone()[0] - lam_factor = potentials_lam[i].clone()[0] - - # Take product of factor with incoming messages - start_in = 0 - for var in range(num_optim_vars): - var_dofs = adj_var_dofs[var] - if var != v: - eta_mess = vtof_msgs_eta[var] - lam_mess = vtof_msgs_lam[var] - eta_factor[start_in : start_in + var_dofs] += eta_mess - lam_factor[ - start_in : start_in + var_dofs, start_in : start_in + var_dofs - ] += lam_mess - start_in += var_dofs - - # Divide up parameters of distribution - dofs = adj_var_dofs[v] - eo = eta_factor[sdim : sdim + dofs] - eno = np.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = np.concatenate( - ( - lam_factor[sdim : 
sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = np.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = np.concatenate( - ( - np.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - axis=1, - ), - np.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_message_lam = loo - lono @ np.linalg.inv(lnono) @ lnoo - new_message_eta = eo - lono @ np.linalg.inv(lnono) @ eno - - ftov_eta.append(new_message_eta[None, :]) - ftov_lam.append(new_message_lam[None, :]) - - sdim += dofs - - ftov_msgs_eta[start : start + num_optim_vars] = ftov_eta - ftov_msgs_lam[start : start + num_optim_vars] = ftov_lam - - start += num_optim_vars - - return ftov_msgs_eta, ftov_msgs_lam - - -@jax.jit -def pass_fac_to_var_messages_jax( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - adj_var_dofs_nested, -): - ftov_msgs_eta = [None] * len(vtof_msgs_eta) - ftov_msgs_lam = [None] * len(vtof_msgs_eta) - - start = 0 - for i in range(len(adj_var_dofs_nested)): - adj_var_dofs = adj_var_dofs_nested[i] - num_optim_vars = len(adj_var_dofs) - - ftov_eta, ftov_lam = [], [] - - sdim = 0 - for v in range(num_optim_vars): - eta_factor = potentials_eta[i][0] - lam_factor = potentials_lam[i][0] - - # Take product of factor with incoming messages - start_in = 0 - for var in range(num_optim_vars): - var_dofs = adj_var_dofs[var] - if var != v: - eta_mess = vtof_msgs_eta[var] - lam_mess = vtof_msgs_lam[var] - eta_factor = eta_factor.at[start_in : start_in + var_dofs].add( - eta_mess - ) - lam_factor = lam_factor.at[ - start_in : start_in + var_dofs, start_in : start_in + var_dofs - ].add(lam_mess) - start_in += var_dofs - - # Divide up parameters of distribution - dofs = adj_var_dofs[v] - eo = eta_factor[sdim : sdim + dofs] - eno = jnp.concatenate((eta_factor[:sdim], eta_factor[sdim + dofs :])) - - loo = lam_factor[sdim : sdim + dofs, sdim : sdim + dofs] - lono = jnp.concatenate( - ( - lam_factor[sdim : sdim + dofs, :sdim], - lam_factor[sdim : sdim + dofs, sdim + dofs :], - ), - axis=1, - ) - lnoo = jnp.concatenate( - ( - lam_factor[:sdim, sdim : sdim + dofs], - lam_factor[sdim + dofs :, sdim : sdim + dofs], - ), - axis=0, - ) - lnono = jnp.concatenate( - ( - jnp.concatenate( - (lam_factor[:sdim, :sdim], lam_factor[:sdim, sdim + dofs :]), - axis=1, - ), - jnp.concatenate( - ( - lam_factor[sdim + dofs :, :sdim], - lam_factor[sdim + dofs :, sdim + dofs :], - ), - axis=1, - ), - ), - axis=0, - ) - - new_message_lam = loo - lono @ jnp.linalg.inv(lnono) @ lnoo - new_message_eta = eo - lono @ jnp.linalg.inv(lnono) @ eno - - ftov_eta.append(new_message_eta[None, :]) - ftov_lam.append(new_message_lam[None, :]) - - sdim += dofs - - ftov_msgs_eta[start : start + num_optim_vars] = ftov_eta - ftov_msgs_lam[start : start + num_optim_vars] = ftov_lam - - start += num_optim_vars - - return ftov_msgs_eta, ftov_msgs_lam - - -if __name__ == "__main__": - - adj_var_dofs_nested = [ - [2], - [2], - [2], - [2], - [2], - [2], - [2], - [2], - [2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - [2, 2], - ] - - potentials_eta = [ - torch.tensor([[0.0, 0.0]]), - torch.tensor([[0.5292, -0.1270]]), - torch.tensor([[1.2858, -0.2724]]), - torch.tensor([[0.2065, 0.5016]]), - torch.tensor([[0.6295, 0.5622]]), - torch.tensor([[1.3565, 0.3479]]), 
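# JAX arrays are immutable, so the in-place "+=" updates of the torch version become
# functional .at[...].add(...) updates in pass_fac_to_var_messages_jax above; a
# minimal example of the idiom:
import jax.numpy as jnp

eta = jnp.zeros(4)
eta = eta.at[0:2].add(jnp.array([1.0, 2.0]))   # returns a new array with the slice updated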
- torch.tensor([[-0.0382, 1.1380]]), - torch.tensor([[0.7259, 1.0533]]), - torch.tensor([[1.1630, 1.0795]]), - torch.tensor([[-100.4221, -5.8282, 100.4221, 5.8282]]), - torch.tensor([[11.0062, -111.4472, -11.0062, 111.4472]]), - torch.tensor([[-109.0159, -5.0249, 109.0159, 5.0249]]), - torch.tensor([[-9.0086, -93.1627, 9.0086, 93.1627]]), - torch.tensor([[1.2289, -90.6423, -1.2289, 90.6423]]), - torch.tensor([[-97.3211, -5.3036, 97.3211, 5.3036]]), - torch.tensor([[6.9166, -96.0325, -6.9166, 96.0325]]), - torch.tensor([[-93.1283, 8.4521, 93.1283, -8.4521]]), - torch.tensor([[6.7125, -99.8733, -6.7125, 99.8733]]), - torch.tensor([[11.1731, -102.3442, -11.1731, 102.3442]]), - torch.tensor([[-116.5980, -7.4204, 116.5980, 7.4204]]), - torch.tensor([[-98.0816, 8.8763, 98.0816, -8.8763]]), - ] - potentials_lam = [ - torch.tensor([[[10000.0, 0.0], [0.0, 10000.0]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor([[[0.5917, 0.0000], [0.0000, 0.5917]]]), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - torch.tensor( - [ - [ - [100.0, 0.0, -100.0, 0.0], - [0.0, 100.0, 0.0, -100.0], - [-100.0, 0.0, 100.0, 0.0], - [0.0, -100.0, 0.0, 100.0], - ] - ] - ), - ] - - vtof_msgs_eta_list = [ - torch.tensor([[0.8536, -1.5929]]), - torch.tensor([[182.3461, 16.7745]]), - torch.tensor([[222.8854, 13.1250]]), - torch.tensor([[-10.1678, 202.9393]]), - torch.tensor([[200.4927, 213.6843]]), - torch.tensor([[264.5976, 132.6887]]), - torch.tensor([[-17.9007, 222.3988]]), - torch.tensor([[127.5813, 277.0478]]), - torch.tensor([[191.0187, 201.1600]]), - 
torch.tensor([[5.6620, -4.6983]]), - torch.tensor([[83.3856, 10.9277]]), - torch.tensor([[-4.8085, 3.1053]]), - torch.tensor([[0.9854, 93.0631]]), - torch.tensor([[153.1307, 16.3761]]), - torch.tensor([[98.1263, 3.3349]]), - torch.tensor([[129.7635, 5.8644]]), - torch.tensor([[140.2319, 158.5661]]), - torch.tensor([[127.3308, 9.2454]]), - torch.tensor([[187.8824, 92.8337]]), - torch.tensor([[-16.5414, 145.2973]]), - torch.tensor([[152.8149, 148.6686]]), - torch.tensor([[-4.1601, 169.0230]]), - torch.tensor([[-12.0344, 99.1287]]), - torch.tensor([[153.7062, 168.3496]]), - torch.tensor([[149.0974, 72.7772]]), - torch.tensor([[157.2429, 167.7175]]), - torch.tensor([[70.8858, 152.1307]]), - torch.tensor([[196.2848, 100.8102]]), - torch.tensor([[99.5512, 100.5530]]), - torch.tensor([[-5.9426, 125.5461]]), - torch.tensor([[87.5787, 197.8408]]), - torch.tensor([[98.8758, 207.2840]]), - torch.tensor([[93.7936, 102.7661]]), - ] - vtof_msgs_lam = [ - torch.tensor([[95.7949, 0.0000], [0.0000, 95.7949]]), - torch.tensor([[190.3769, 0.0000], [0.0000, 190.3769]]), - torch.tensor([[109.9605, 0.0000], [0.0000, 109.9605]]), - torch.tensor([[190.3769, 0.0000], [0.0000, 190.3769]]), - torch.tensor([[197.8604, 0.0000], [0.0000, 197.8604]]), - torch.tensor([[132.5915, 0.0000], [0.0000, 132.5915]]), - torch.tensor([[109.9605, 0.0000], [0.0000, 109.9605]]), - torch.tensor([[132.5915, 0.0000], [0.0000, 132.5915]]), - torch.tensor([[99.8496, 0.0000], [0.0000, 99.8496]]), - torch.tensor([[10047.8975, 0.0000], [0.0000, 10047.8975]]), - torch.tensor([[91.9540, 0.0000], [0.0000, 91.9540]]), - torch.tensor([[10047.8975, 0.0000], [0.0000, 10047.8975]]), - torch.tensor([[91.9540, 0.0000], [0.0000, 91.9540]]), - torch.tensor([[158.0642, 0.0000], [0.0000, 158.0642]]), - torch.tensor([[49.3043, 0.0000], [0.0000, 49.3043]]), - torch.tensor([[132.5106, 0.0000], [0.0000, 132.5106]]), - torch.tensor([[141.4631, 0.0000], [0.0000, 141.4631]]), - torch.tensor([[61.8396, 0.0000], [0.0000, 61.8396]]), - torch.tensor([[94.9975, 0.0000], [0.0000, 94.9975]]), - torch.tensor([[132.5106, 0.0000], [0.0000, 132.5106]]), - torch.tensor([[141.4631, 0.0000], [0.0000, 141.4631]]), - torch.tensor([[158.0642, 0.0000], [0.0000, 158.0642]]), - torch.tensor([[49.3043, 0.0000], [0.0000, 49.3043]]), - torch.tensor([[156.5110, 0.0000], [0.0000, 156.5110]]), - torch.tensor([[72.2502, 0.0000], [0.0000, 72.2502]]), - torch.tensor([[156.5110, 0.0000], [0.0000, 156.5110]]), - torch.tensor([[72.2502, 0.0000], [0.0000, 72.2502]]), - torch.tensor([[99.7104, 0.0000], [0.0000, 99.7104]]), - torch.tensor([[50.5165, 0.0000], [0.0000, 50.5165]]), - torch.tensor([[61.8396, 0.0000], [0.0000, 61.8396]]), - torch.tensor([[94.9975, 0.0000], [0.0000, 94.9975]]), - torch.tensor([[99.7104, 0.0000], [0.0000, 99.7104]]), - torch.tensor([[50.5165, 0.0000], [0.0000, 50.5165]]), - ] - vtof_msgs_eta = torch.cat(vtof_msgs_eta_list) - # vtof_msgs_lam = torch.cat([m[None, ...] 
for m in vtof_msgs_lam]) - - t1 = time.time() - times = [] - for i in range(100): - t_start = time.time() - ftov_msgs_eta, ftov_msgs_lam = pass_fac_to_var_messages( - potentials_eta, - potentials_lam, - vtof_msgs_eta, - vtof_msgs_lam, - adj_var_dofs_nested, - ) - times.append(time.time() - t_start) - - t2 = time.time() - print("-------------- TORCH --------------") - print("elapsed", t2 - t1) - print("min max mean", np.min(times), np.max(times), np.mean(times)) - - # print(ftov_msgs_eta) - # print(ftov_msgs_lam) - - potentials_eta_jax = [jnp.array(pe) for pe in potentials_eta] - potentials_lam_jax = [jnp.array(pe) for pe in potentials_lam] - vtof_msgs_eta_jax = jnp.array(vtof_msgs_eta) - vtof_msgs_lam_jax = [jnp.array(pe) for pe in vtof_msgs_lam] - - t1 = time.time() - times = [] - for i in range(10): - t_start = time.time() - ftov_msgs_eta_jax, ftov_msgs_lam_jax = pass_fac_to_var_messages_jax( - potentials_eta_jax, - potentials_lam_jax, - vtof_msgs_eta_jax, - vtof_msgs_lam_jax, - adj_var_dofs_nested, - ) - times.append(time.time() - t_start) - - t2 = time.time() - print("\n\n") - print("-------------- JAX --------------") - print("elapsed", t2 - t1) - print("min max mean", np.min(times), np.max(times), np.mean(times)) - - # print(ftov_msgs_eta_jax) - # print(ftov_msgs_lam_jax) diff --git a/theseus/optimizer/gbp/pgo_test.py b/theseus/optimizer/gbp/pgo_test.py deleted file mode 100644 index 1c39da6c8..000000000 --- a/theseus/optimizer/gbp/pgo_test.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - -import numpy as np -import torch - -import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule - -# This example illustrates the Gaussian Belief Propagation (GBP) optimizer -# for a 2D pose graph optimization problem. -# Linear problem where we are estimating the (x, y)position of 9 nodes, -# arranged in a 3x3 grid. -# Linear factors connect each node to its adjacent nodes. 
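# The grid wiring described above: node (i, j) is flattened to index i * size + j and
# connected to its right and down neighbours, giving 12 relative factors for a 3x3 grid.
size = 3
edges = []
for i in range(size):
    for j in range(size):
        if j < size - 1:
            edges.append((i * size + j, i * size + j + 1))
        if i < size - 1:
            edges.append((i * size + j, (i + 1) * size + j))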
- -np.random.seed(1) -torch.manual_seed(0) - -size = 3 -dim = 2 - -noise_cov = np.array([0.01, 0.01]) - -prior_noise_std = 0.2 -prior_sigma = np.array([1.3**2, 1.3**2]) - -init_noises = np.random.normal(np.zeros([size * size, 2]), prior_noise_std) -meas_noises = np.random.normal(np.zeros([100, 2]), np.sqrt(noise_cov[0])) - -# create theseus objective ------------------------------------- - - -def create_pgo(): - - objective = th.Objective() - inputs = {} - - n_poses = size * size - - # create variables - poses = [] - for i in range(n_poses): - poses.append(th.Vector(tensor=torch.rand(1, 2), name=f"x{i}")) - - # add prior cost constraints with VariableDifference cost - prior_std = 1.3 - anchor_std = 0.01 - prior_w = th.ScaleCostWeight(1 / prior_std, name="prior_weight") - anchor_w = th.ScaleCostWeight(1 / anchor_std, name="anchor_weight") - - gt_poses = [] - - p = 0 - for i in range(size): - for j in range(size): - init = torch.Tensor([j, i]) - gt_poses.append(init[None, :]) - if i == 0 and j == 0: - w = anchor_w - else: - # noise_init = torch.normal(torch.zeros(2), prior_noise_std) - init = init + torch.FloatTensor(init_noises[p]) - w = prior_w - - prior_target = th.Vector(tensor=init, name=f"prior_{p}") - inputs[f"x{p}"] = init[None, :] - inputs[f"prior_{p}"] = init[None, :] - - cf_prior = th.Difference(poses[p], prior_target, w, name=f"prior_cost_{p}") - - objective.add(cf_prior) - - p += 1 - - # Measurement cost functions - - meas_std_tensor = torch.nn.Parameter(torch.tensor([0.1])) - meas_w = th.ScaleCostWeight(1 / meas_std_tensor, name="prior_weight") - - m = 0 - for i in range(size): - for j in range(size): - if j < size - 1: - measurement = torch.Tensor([1.0, 0.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = i * size + j + 1 - - meas = th.Vector(tensor=measurement, name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] - - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas, meas_w, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 - - if i < size - 1: - measurement = torch.Tensor([0.0, 1.0]) - # measurement += torch.normal(torch.zeros(2), meas_std) - measurement += torch.FloatTensor(meas_noises[m]) - ix0 = i * size + j - ix1 = (i + 1) * size + j - - meas = th.Vector(tensor=measurement, name=f"meas_{m}") - inputs[f"meas_{m}"] = measurement[None, :] - - cf_meas = th.eb.Between( - poses[ix0], poses[ix1], meas, meas_w, name=f"meas_cost_{m}" - ) - objective.add(cf_meas) - m += 1 - - return objective, gt_poses, meas_std_tensor, inputs - - -def linear_solve_pgo(): - print("\n\nLinear solver...\n") - - objective, gt_poses, meas_std_tensor, inputs = create_pgo() - - # outer optimizer - gt_poses_tensor = torch.cat(gt_poses) - lr = 1e-3 - outer_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) - outer_optimizer.zero_grad() - - linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) - th_layer = th.TheseusLayer(linear_optimizer) - outputs_linsolve, _ = th_layer.forward(inputs, {"verbose": True}) - - out_ls_tensor = torch.cat(list(outputs_linsolve.values())) - loss = torch.norm(gt_poses_tensor - out_ls_tensor) - loss.backward() - - print("loss", loss.item()) - print("grad", meas_std_tensor.grad.item()) - - print("outputs\n", outputs_linsolve) - - -def gbp_solve_pgo(backward_mode, max_iterations=20, implicit_method="gbp"): - print("\n\nWith GBP...") - print("backward mode:", backward_mode, "\n") - - objective, gt_poses, meas_std_tensor, inputs = create_pgo() - 
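# The scripts above run a single forward/backward pass with respect to the
# measurement-noise parameter; a hedged sketch of what a full outer learning loop
# would look like (argument names follow linear_solve_pgo / gbp_solve_pgo above):
import torch

def outer_learning_loop(th_layer, inputs, gt_poses_tensor, outer_optimizer, n_epochs=10):
    for _ in range(n_epochs):
        outer_optimizer.zero_grad()
        outputs, _ = th_layer.forward(inputs, {"verbose": False})
        loss = torch.norm(gt_poses_tensor - torch.cat(list(outputs.values())))
        loss.backward()
        outer_optimizer.step()
    return loss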
- gt_poses_tensor = torch.cat(gt_poses) - lr = 1e-3 - outer_optimizer = torch.optim.Adam([meas_std_tensor], lr=lr) - outer_optimizer.zero_grad() - - vectorize = True - - optimizer = GaussianBeliefPropagation( - objective, - max_iterations=max_iterations, - vectorize=vectorize, - ) - theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) - - optim_arg = { - "verbose": True, - # "track_best_solution": True, - # "track_err_history": True, - "backward_mode": backward_mode, - "backward_num_iterations": 5, - "relin_threshold": 1e-8, - "damping": 0.0, - "dropout": 0.0, - "schedule": GBPSchedule.SYNCHRONOUS, - "implicit_step_size": 1e-5, - "implicit_method": implicit_method, - } - - outputs_gbp, info = theseus_optim.forward(inputs, optim_arg) - - out_gbp_tensor = torch.cat(list(outputs_gbp.values())) - loss = torch.norm(gt_poses_tensor - out_gbp_tensor) - - loss.backward() - if backward_mode == "implicit": - if optimizer.implicit_method == "gauss_newton": - meas_std_tensor.grad /= optimizer.implicit_step_size - - print("loss", loss.item()) - print("grad", meas_std_tensor.grad.item()) - - print("outputs\n", outputs_gbp) - - -linear_solve_pgo() - -gbp_solve_pgo(backward_mode="unroll", max_iterations=20) - -gbp_solve_pgo(backward_mode="truncated", max_iterations=20) - -gbp_solve_pgo( - backward_mode="implicit", max_iterations=20, implicit_method="gbp" -) -gbp_solve_pgo( - backward_mode="implicit", - max_iterations=20, - implicit_method="gauss_newton", -) diff --git a/theseus/optimizer/gbp/plot_ba_err.py b/theseus/optimizer/gbp/plot_ba_err.py deleted file mode 100644 index 1d5342977..000000000 --- a/theseus/optimizer/gbp/plot_ba_err.py +++ /dev/null @@ -1,68 +0,0 @@ -import numpy as np -import matplotlib.pylab as plt -import os - - -""" -Nesterov experiments -""" - - -def nesterov_plots(): - - root_dir = ( - "/home/joe/projects/mpSLAM/theseus/theseus/optimizer/gbp/outputs/nesterov/" - ) - exp_dir = root_dir + "bal/" - - err_normal = np.loadtxt(exp_dir + "0/error_history.txt") - err_nesterov_normalize = np.loadtxt(exp_dir + "normalize/error_history.txt") - err_nesterov_tp = np.loadtxt(exp_dir + "tangent_plane/error_history.txt") - - plt.plot(err_normal, label="Normal GBP") - plt.plot(err_nesterov_normalize, label="Nesterov acceleration - normalize") - plt.plot(err_nesterov_tp, label="Nesterov acceleration - lie algebra") - plt.legend() - plt.yscale("log") - plt.xlabel("Iterations") - plt.ylabel("Error") - plt.show() - - -""" -Comparing GBP to Levenberg Marquardt -""" - - -def gbp_vs_lm(): - - root_dir = "/home/joe/projects/theseus/theseus/optimizer/gbp/outputs" - err_files1 = [ - "gbp_problem-21-11315-pre.txt", - "levenberg_marquardt_problem-21-11315-pre.txt", - ] - err_files2 = [ - "gbp_problem-50-20431-pre.txt", - "levenberg_marquardt_problem-50-20431-pre.txt", - ] - - err_files = err_files1 - - for err_files in [err_files1, err_files2]: - - gbp_err = np.loadtxt(os.path.join(root_dir, err_files[0])) - lm_err = np.loadtxt(os.path.join(root_dir, err_files[1])) - - plt.plot(gbp_err, label="GBP") - plt.plot(lm_err, label="Levenberg Marquardt") - plt.xscale("log") - plt.title(err_files[0][4:]) - plt.xlabel("Iterations") - plt.ylabel("Total Energy") - plt.legend() - plt.show() - - -if __name__ == "__main__": - - nesterov_plots() diff --git a/theseus/optimizer/gbp/swarm.py b/theseus/optimizer/gbp/swarm.py deleted file mode 100644 index a340ee711..000000000 --- a/theseus/optimizer/gbp/swarm.py +++ /dev/null @@ -1,470 +0,0 @@ -import numpy as np -import random -import omegaconf -import time -from PIL 
import Image, ImageDraw, ImageFont, ImageFilter -from typing import Optional, Tuple, List, Callable - -import torch -import torch.nn as nn - -import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule -from theseus.optimizer.gbp import SwarmViewer - - -OPTIMIZER_CLASS = { - "gbp": GaussianBeliefPropagation, - "gauss_newton": th.GaussNewton, - "levenberg_marquardt": th.LevenbergMarquardt, -} - -OUTER_OPTIMIZER_CLASS = { - "sgd": torch.optim.SGD, - "adam": torch.optim.Adam, -} - -GBP_SCHEDULE = { - "synchronous": GBPSchedule.SYNCHRONOUS, -} - - -# create image from a character and font -def gen_char_img( - char, dilate=True, fontname="LiberationSerif-Bold.ttf", size=(200, 200) -): - img = Image.new("L", size, "white") - draw = ImageDraw.Draw(img) - fontsize = int(size[0] * 0.5) - font = ImageFont.truetype(fontname, fontsize) - char_displaysize = font.getsize(char) - offset = tuple((si - sc) // 2 for si, sc in zip(size, char_displaysize)) - draw.text((offset[0], offset[1] * 3 // 4), char, font=font, fill="#000") - - if dilate: - img = img.filter(ImageFilter.MinFilter(3)) - - return img - - -# all agents should be inside object (negative SDF values) -def target_char_loss(outputs, sdf): - positions = torch.cat(list(outputs.values())) - dists = sdf.signed_distance(positions)[0] - if torch.sum(dists == 0).item() != 0: - print("\n\nNumber of agents out of bounds: ", torch.sum(dists == 0).item()) - loss = torch.relu(dists) - return loss.sum() - - -def gen_target_sdf(cfg): - # setup target shape for outer loop loss fn - vis_limits = np.array(cfg["setup"]["vis_limits"]) - cell_size = 0.05 - img_size = tuple(np.rint((vis_limits[1] - vis_limits[0]) / cell_size).astype(int)) - img = gen_char_img( - cfg["outer_optim"]["target_char"], - dilate=True, - fontname="DejaVuSans-Bold.ttf", - size=img_size, - ) - occ_map = torch.Tensor(np.array(img) < 255) - occ_map = torch.flip( - occ_map, [0] - ) # flip vertically so y axis is upwards wrt character - # pad to expand area - area_limits = np.array(cfg["setup"]["area_limits"]) - padded_size = tuple( - np.rint((area_limits[1] - area_limits[0]) / cell_size).astype(int) - ) - pad = int((padded_size[0] - img_size[0]) / 2) - larger_occ_map = torch.zeros(padded_size) - larger_occ_map[pad:-pad, pad:-pad] = occ_map - sdf = th.eb.SignedDistanceField2D( - th.Variable(torch.Tensor(area_limits[0][None, :])), - th.Variable(torch.Tensor([cell_size])), - occupancy_map=th.Variable(larger_occ_map[None, :]), - ) - return sdf - - -def fc_block(in_f, out_f): - return nn.Sequential(nn.Linear(in_f, out_f), nn.ReLU()) - - -class SimpleMLP(nn.Module): - def __init__( - self, - input_dim=2, - output_dim=2, - hidden_dim=8, - hidden_layers=0, - scale_output=1.0, - ): - super(SimpleMLP, self).__init__() - # input is agent index - self.scale_output = scale_output - self.relu = nn.ReLU() - self.in_layer = nn.Linear(input_dim, hidden_dim) - hidden = [fc_block(hidden_dim, hidden_dim) for _ in range(hidden_layers)] - self.mid = nn.Sequential(*hidden) - self.out_layer = nn.Linear(hidden_dim, output_dim) - - def forward(self, x): - y = self.relu(self.in_layer(x)) - y = self.mid(y) - out = self.out_layer(y) * self.scale_output - return out - - -# custom factor for two agents collision -class TwoAgentsCollision(th.CostFunction): - def __init__( - self, - weight: th.CostWeight, - var1: th.Point2, - var2: th.Point2, - radius: th.Vector, - name: Optional[str] = None, - ): - super().__init__(weight, name=name) - self.var1 = var1 - self.var2 = var2 - self.radius = 
radius - # skips data checks - self.register_optim_vars(["var1", "var2"]) - self.register_aux_vars(["radius"]) - - # no error when distance exceeds radius - def error(self) -> torch.Tensor: - dist = torch.norm(self.var1.tensor - self.var2.tensor, dim=1, keepdim=True) - return torch.relu(1 - dist / self.radius.tensor) - - def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: - dist = torch.norm(self.var1.tensor - self.var2.tensor, dim=1, keepdim=True) - denom = dist * self.radius.tensor - jac = (self.var1.tensor - self.var2.tensor) / denom - jac = jac[:, None, :] - jac[dist > self.radius.tensor] = 0.0 - return [-jac, jac], self.error() - - def dim(self) -> int: - return 1 - - def _copy_impl(self, new_name: Optional[str] = None) -> "TwoAgentsCollision": - return TwoAgentsCollision( - self.weight.copy(), - self.var1.copy(), - self.var2.copy(), - self.radius.copy(), - name=new_name, - ) - - -# custom factor for GNN -class GNNTargets(th.CostFunction): - def __init__( - self, - weight: th.CostWeight, - agents: List[th.Point2], - gnn_err_fn: Callable, - name: Optional[str] = None, - ): - super().__init__(weight, name=name) - self.agents = agents - self.n_agents = len(agents) - self._gnn_err_fn = gnn_err_fn - # skips data checks - for agent in self.agents: - setattr(self, agent.name, agent) - self.register_optim_vars([v.name for v in agents]) - - # no error when distance exceeds radius - def error(self) -> torch.Tensor: - return self._gnn_err_fn(self.agents) - - # Cannot use autodiff for jacobians as we want the factor to be - # independent for each agent. i.e. GNN is implemented as many prior factors - def jacobians(self) -> Tuple[List[torch.Tensor], torch.Tensor]: - batch_size = self.agents[0].shape[0] - jacs = torch.zeros( - batch_size, - self.n_agents, - self.dim(), - 2, - dtype=self.agents[0].dtype, - device=self.agents[0].device, - ) - jacs[:, torch.arange(self.n_agents), 2 * torch.arange(self.n_agents), 0] = 1.0 - jacs[ - :, torch.arange(self.n_agents), 2 * torch.arange(self.n_agents) + 1, 1 - ] = 1.0 - jac_list = [jacs[:, i] for i in range(self.n_agents)] - return jac_list, self.error() - - def dim(self) -> int: - return self.n_agents * 2 - - def _copy_impl(self, new_name: Optional[str] = None) -> "GNNTargets": - return GNNTargets( - self.weight.copy(), - [agent.copy() for agent in self.agents], - self._gnn_err_fn, - name=new_name, - ) - - -def setup_problem(cfg: omegaconf.OmegaConf, gnn_err_fn): - dtype = torch.float32 - n_agents = cfg["setup"]["num_agents"] - - # create variables, one per agent - positions = [] - for i in range(n_agents): - init = torch.normal(torch.zeros(2), cfg["setup"]["init_std"]) - position = th.Point2(tensor=init, name=f"agent_{i}") - positions.append(position) - - objective = th.Objective(dtype=dtype) - - # prior factor drawing each robot to the origin - origin = th.Point2(name="origin") - origin_weight = th.ScaleCostWeight( - torch.tensor([cfg["setup"]["origin_weight"]], dtype=dtype) - ) - for i in range(n_agents): - origin_cf = th.Difference( - positions[i], - origin, - origin_weight, - name=f"origin_pull_{i}", - ) - objective.add(origin_cf) - - # create collision factors, fully connected - radius = th.Vector( - tensor=torch.tensor([cfg["setup"]["collision_radius"]]), name="radius" - ) - collision_weight = th.ScaleCostWeight( - torch.tensor([cfg["setup"]["collision_weight"]], dtype=dtype) - ) - for i in range(n_agents): - for j in range(i + 1, n_agents): - collision_cf = TwoAgentsCollision( - weight=collision_weight, - var1=positions[i], - 
var2=positions[j], - radius=radius, - name=f"collision_{i}_{j}", - ) - objective.add(collision_cf) - - # GNN factor - GNN takes in all current belief means and outputs all targets - target_weight = th.ScaleCostWeight( - torch.tensor([cfg["setup"]["gnn_target_weight"]], dtype=dtype) - ) - gnn_cf = GNNTargets( - weight=target_weight, - agents=positions, - gnn_err_fn=gnn_err_fn, - name="gnn_factor", - ) - objective.add(gnn_cf) - - return objective - - -class SwarmGBPAndGNN(nn.Module): - def __init__(self, cfg): - super().__init__() - - n_agents = cfg["setup"]["num_agents"] - self.gnn = SimpleMLP( - input_dim=2 * n_agents, - output_dim=2 * n_agents, - hidden_dim=64, - hidden_layers=2, - scale_output=1.0, - ) - - # setup objective, optimizer and theseus layer - objective = setup_problem(cfg, self._gnn_err_fn) - vectorize = cfg["optim"]["vectorize"] - optimizer = OPTIMIZER_CLASS[cfg["optim"]["optimizer_cls"]]( - objective, - max_iterations=cfg["optim"]["max_iters"], - vectorize=vectorize, - ) - self.layer = th.TheseusLayer(optimizer, vectorize=vectorize) - - # put on device - if cfg["device"] == "cuda": - cfg["device"] = "cuda" if torch.cuda.is_available() else "cpu" - self.gnn.to(cfg["device"]) - self.layer.to(cfg["device"]) - - # optimizer arguments - optim_arg = { - "track_best_solution": False, - "track_err_history": True, - "track_state_history": True, - "verbose": True, - "backward_mode": th.BackwardMode.UNROLL, - } - if isinstance(optimizer, GaussianBeliefPropagation): - gbp_args = cfg["optim"]["gbp_settings"].copy() - lin_system_damping = torch.tensor( - [cfg["optim"]["gbp_settings"]["lin_system_damping"]], - dtype=torch.float32, - ) - lin_system_damping.to(device=cfg["device"]) - gbp_args["lin_system_damping"] = lin_system_damping - gbp_args["schedule"] = GBP_SCHEDULE[gbp_args["schedule"]] - optim_arg = {**optim_arg, **gbp_args} - self.optim_arg = optim_arg - - # fixed inputs to theseus layer - self.inputs = {} - for agent in objective.optim_vars.values(): - self.inputs[agent.name] = agent.tensor.clone() - - # network outputs offset for target from agent position - # cost is zero when offset is zero, i.e. 
agent is at the target - def _gnn_err_fn(self, positions: List[th.Manifold]): - flattened_pos = torch.cat( - [pos.tensor.unsqueeze(1) for pos in positions], dim=1 - ).flatten(1, 2) - offsets = self.gnn(flattened_pos) - return offsets - - def forward(self, track_history=False): - - optim_arg = self.optim_arg.copy() - optim_arg["track_state_history"] = track_history - - outputs, info = self.layer.forward( - input_tensors=self.inputs, - optimizer_kwargs=optim_arg, - ) - - history = None - if track_history: - - history = info.state_history - - # recover target history - agent_histories = torch.cat( - [state_hist.unsqueeze(1) for state_hist in history.values()], dim=1 - ) - history["agent_0"] - - batch_size = agent_histories.shape[0] - ts = agent_histories.shape[-1] - agent_histories = agent_histories.permute( - 0, 3, 1, 2 - ) # time dim is second dim - agent_histories = agent_histories.flatten(-2, -1) - target_hist = self.gnn(agent_histories) - target_hist = target_hist.reshape(batch_size, ts, -1, 2) - target_hist = target_hist.permute(0, 2, 3, 1) # time back to last dim - - for i in range(target_hist.shape[1]): - history[f"target_{i}"] = -target_hist[:, i] + history[f"agent_{i}"] - - return outputs, history - - -def main(cfg: omegaconf.OmegaConf): - - sdf = gen_target_sdf(cfg) - - model = SwarmGBPAndGNN(cfg) - - outer_optimizer = OUTER_OPTIMIZER_CLASS[cfg["outer_optim"]["optimizer"]]( - model.gnn.parameters(), lr=cfg["outer_optim"]["lr"] - ) - - viewer = SwarmViewer(cfg["setup"]["collision_radius"], cfg["setup"]["vis_limits"]) - - losses = [] - for epoch in range(cfg["outer_optim"]["num_epochs"]): - print(f" ******************* EPOCH {epoch} ******************* ") - start_time = time.time_ns() - outer_optimizer.zero_grad() - - track_history = False # epoch % 20 == 0 - outputs, history = model.forward(track_history=track_history) - - loss = target_char_loss(outputs, sdf) - - loss.backward() - outer_optimizer.step() - losses.append(loss.detach().item()) - end_time = time.time_ns() - - print(f"Loss {losses[-1]}") - print(f"Epoch took {(end_time - start_time) / 1e9: .3f} seconds") - - if track_history: - viewer.vis_inner_optim(history, target_sdf=sdf, show_edges=False) - - print("Loss values:", losses) - - import ipdb - - ipdb.set_trace() - - # outputs visualisations - # viewer.vis_outer_targets_optim( - # targets_history, - # target_sdf=sdf, - # video_file=cfg["outer_optim_video_file"], - # ) - - -if __name__ == "__main__": - - cfg = { - "seed": 0, - "device": "cpu", - "out_video_file": "outputs/swarm/inner_mlp.gif", - "outer_optim_video_file": "outputs/swarm/outer_targets_mlp.gif", - "setup": { - "num_agents": 50, - "init_std": 1.0, - "agent_radius": 0.1, - "collision_radius": 1.0, - "origin_weight": 0.1, - "collision_weight": 1.0, - "gnn_target_weight": 10.0, - "area_limits": [[-20, -20], [20, 20]], - "vis_limits": [[-3, -3], [3, 3]], - }, - "optim": { - "max_iters": 20, - "vectorize": True, - "optimizer_cls": "gbp", - # "optimizer_cls": "gauss_newton", - # "optimizer_cls": "levenberg_marquardt", - "gbp_settings": { - "relin_threshold": 1e-8, - "ftov_msg_damping": 0.0, - "dropout": 0.0, - "schedule": "synchronous", - "lin_system_damping": 1.0e-2, - "nesterov": False, - }, - }, - "outer_optim": { - "num_epochs": 100, - "lr": 2e-2, - "optimizer": "adam", - "target_char": "A", - }, - } - - torch.manual_seed(cfg["seed"]) - np.random.seed(cfg["seed"]) - random.seed(cfg["seed"]) - - main(cfg) diff --git a/theseus/optimizer/gbp/swarm_viewer.py b/theseus/optimizer/gbp/swarm_viewer.py deleted file mode 
100644 index f98893810..000000000 --- a/theseus/optimizer/gbp/swarm_viewer.py +++ /dev/null @@ -1,226 +0,0 @@ -import numpy as np -import shutil -import os -import torch - -import pygame - - -class SwarmViewer: - def __init__( - self, - collision_radius, - area_limits, - ): - self.agent_cols = None - self.scale = 100 - self.agent_r_pix = collision_radius / 20 * self.scale - self.collision_radius = collision_radius - self.target_sdf = None - - self.range = np.array(area_limits) - self.h = (self.range[1, 1] - self.range[0, 1]) * self.scale - self.w = (self.range[1, 0] - self.range[0, 0]) * self.scale - - self.video_file = None - - pygame.init() - pygame.display.set_caption("Swarm") - self.myfont = pygame.font.SysFont("Jokerman", 40) - self.screen = pygame.display.set_mode([self.w, self.h]) - - def vis_target_step( - self, - targets_history, - target_sdf, - ): - self.state_history = targets_history - self.t = (~list(targets_history.values())[0].isinf()[0, 0]).sum().item() - 1 - self.num_iters = self.t + 1 - - self.targets = None - self.show_edges = False - self.width = 3 - self.target_sdf = target_sdf - - self.draw_next() - - def prepare_video(self, video_file): - self.video_file = video_file - if self.video_file is not None: - self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" - self.save_ix = 0 - if os.path.exists(self.tmp_dir): - shutil.rmtree(self.tmp_dir) - os.mkdir(self.tmp_dir) - - def vis_inner_optim( - self, - history, - target_sdf=None, - show_edges=True, - video_file=None, - ): - self.prepare_video(video_file) - - self.state_history = {k: v for k, v in history.items() if "agent" in k} - self.target_history = {k: v for k, v in history.items() if "target" in k} - - self.t = 0 - self.num_iters = (~list(self.state_history.values())[0].isinf()[0, 0]).sum() - - self.show_edges = show_edges - self.width = 0 - self.target_sdf = target_sdf - - self.run() - - def vis_outer_targets_optim( - self, - targets_history, - target_sdf=None, - video_file=None, - ): - self.state_history = targets_history - self.t = 0 - self.num_iters = list(targets_history.values())[0].shape[-1] - - self.video_file = video_file - if self.video_file is not None: - self.tmp_dir = "/".join(self.video_file.split("/")[:-1]) + "/tmp" - self.save_ix = 0 - if os.path.exists(self.tmp_dir): - shutil.rmtree(self.tmp_dir) - os.mkdir(self.tmp_dir) - - self.targets = None - self.show_edges = False - self.width = 3 - self.target_sdf = target_sdf - - self.run() - - def run(self): - self.draw_next() - - running = True - while running: - - # Did the user click the window close button? 
- for event in pygame.event.get(): - if event.type == pygame.QUIT: - running = False - if event.type == pygame.KEYDOWN: - if event.key == pygame.K_SPACE: - self.draw_next() - - def draw_next(self): - if self.agent_cols is None: - self.agent_cols = [ - tuple(np.random.choice(range(256), size=3)) - for i in range(len(self.state_history)) - ] - - if self.t < self.num_iters: - self.screen.fill((255, 255, 255)) - - # draw target shape as background - if self.target_sdf is not None: - sdf = self.target_sdf.sdf_data.tensor[0].transpose(0, 1) - sdf_size = self.target_sdf.cell_size.tensor.item() * sdf.shape[0] - area_size = self.range[1, 0] - self.range[0, 0] - crop = np.round((1 - area_size / sdf_size) * sdf.shape[0] / 2).astype( - int - ) - sdf = sdf[crop:-crop, crop:-crop] - sdf = torch.flip( - sdf, [1] - ) # flip vertically so y is increasing going up - repeats = self.screen.get_width() // sdf.shape[0] - sdf = torch.repeat_interleave(sdf, repeats, dim=0) - sdf = torch.repeat_interleave(sdf, repeats, dim=1) - sdf = sdf.detach().cpu().numpy() - bg_img = np.zeros([*sdf.shape, 3]) - bg_img[sdf > 0] = 255 - bg_img[sdf <= 0] = [144, 238, 144] - bg = pygame.surfarray.make_surface(bg_img) - self.screen.blit(bg, (0, 0)) - - # draw agents - for i, state in enumerate(self.state_history.values()): - pos = state[0, :, self.t].detach().cpu().numpy() - centre = self.pos_to_canvas(pos) - pygame.draw.circle( - self.screen, - self.agent_cols[i], - centre, - self.agent_r_pix, - self.width, - ) - - # draw edges between agents - if self.show_edges: - for i, state1 in enumerate(self.state_history.values()): - pos1 = state1[0, :, self.t].detach().cpu().numpy() - j = 0 - for state2 in self.state_history.values(): - if j <= i: - j += 1 - continue - pos2 = state2[0, :, self.t].detach().cpu().numpy() - dist = np.linalg.norm(pos1 - pos2) - if dist < self.collision_radius: - start = self.pos_to_canvas(pos1) - end = self.pos_to_canvas(pos2) - pygame.draw.line(self.screen, (0, 0, 0), start, end) - - # draw agents - for i, state in enumerate(self.target_history.values()): - pos = state[0, :, self.t].detach().cpu().numpy() - centre = self.pos_to_canvas(pos) - pygame.draw.circle( - self.screen, self.agent_cols[i], centre, self.agent_r_pix, 3 - ) - - # draw line between agent and target - - # draw text - ssshow = self.myfont.render( - f"t = {self.t} / {self.num_iters - 1}", True, (0, 0, 0) - ) - self.screen.blit(ssshow, (10, 10)) # choose location of text - - pygame.display.flip() - - if self.video_file: - self.save_image() - - self.t += 1 - - elif self.t == self.num_iters and self.video_file: - if self.video_file[-3:] == "mp4": - os.system( - f"ffmpeg -r 4 -i {self.tmp_dir}/%06d.png -vcodec mpeg4 -y {self.video_file}" - ) - elif self.video_file[-3:] == "gif": - os.system( - f"ffmpeg -i {self.tmp_dir}/%06d.png -vf palettegen {self.tmp_dir}/palette.png" - ) - os.system( - f"ffmpeg -r 4 -i {self.tmp_dir}/%06d.png -i {self.tmp_dir}/palette.png" - f" -lavfi paletteuse {self.video_file}" - ) - else: - raise ValueError("video file must be either mp4 or gif.") - shutil.rmtree(self.tmp_dir) - self.t += 1 - - def pos_to_canvas(self, pos): - x = (pos - self.range[0]) / (self.range[1] - self.range[0]) - x[1] = 1 - x[1] - return x * np.array([self.h, self.w]) - - def save_image(self): - fname = self.tmp_dir + f"/{self.save_ix:06d}.png" - pygame.image.save(self.screen, fname) - self.save_ix += 1 diff --git a/theseus/optimizer/gbp/vectorize_poc.py b/theseus/optimizer/gbp/vectorize_poc.py deleted file mode 100644 index 51b172161..000000000 --- 
a/theseus/optimizer/gbp/vectorize_poc.py +++ /dev/null @@ -1,119 +0,0 @@ -import torch - -import theseus as th -from theseus.optimizer.gbp import GaussianBeliefPropagation, GBPSchedule - -torch.manual_seed(0) - - -def generate_data(num_points=100, a=1, b=0.5, noise_factor=0.01): - # Generate data: 100 points sampled from the quadratic curve listed above - data_x = torch.rand((1, num_points)) - noise = torch.randn((1, num_points)) * noise_factor - data_y = a * data_x.square() + b + noise - return data_x, data_y - - -def generate_learning_data(num_points, num_models): - a, b = 3, 1 - data_batches = [] - for i in range(num_models): - b = b + 2 - data = generate_data(num_points, a, b) - data_batches.append(data) - return data_batches - - -num_models = 10 -data_batches = generate_learning_data(100, num_models) - - -# updated error function reflects change in 'a' -def quad_error_fn2(optim_vars, aux_vars): - [a, b] = optim_vars - x, y = aux_vars - est = a.tensor * x.tensor.square() + b.tensor - err = y.tensor - est - return err - - -# The theseus_inputs dictionary is also constructed similarly to before, -# but with data matching the new shapes of the variables -def construct_theseus_layer_inputs(): - theseus_inputs = {} - theseus_inputs.update( - { - "x": data_x, - "y": data_y, - "b": torch.ones((num_models, 1)), - "a": a_tensor, - } - ) - return theseus_inputs - - -# convert data_x, data_y into torch.tensors of shape [num_models, 100] -data_x = torch.stack([data_x.squeeze() for data_x, _ in data_batches]) -data_y = torch.stack([data_y.squeeze() for _, data_y in data_batches]) - -# construct one variable each of x, y of shape [num_models, 100] -x = th.Variable(data_x, name="x") -y = th.Variable(data_y, name="y") - -# construct a as before -a = th.Vector(tensor=torch.rand(num_models, 1), name="a") - -# construct one variable b, now of shape [num_models, 1] -b = th.Vector(tensor=torch.rand(num_models, 1), name="b") - -# Again, 'b' is the only optim_var, and 'a' is part of aux_vars along with x, y -aux_vars = [x, y] - -# cost function constructed as before -cost_function = th.AutoDiffCostFunction( - [a, b], quad_error_fn2, 100, aux_vars=aux_vars, name="quadratic_cost_fn" -) - -prior_weight = th.ScaleCostWeight(torch.ones(1)) -prior_a = th.Difference(a, th.Vector(1), prior_weight) -prior_b = th.Difference(b, th.Vector(1), prior_weight) - -# objective, optimizer and theseus layer constructed as before -objective = th.Objective() -objective.add(cost_function) -objective.add(prior_a) -objective.add(prior_b) - -print([cf.name for cf in objective.cost_functions.values()]) - -vectorize = True - -optimizer = GaussianBeliefPropagation( - objective, - max_iterations=50, # step_size=0.5, - vectorize=vectorize, -) - -theseus_optim = th.TheseusLayer(optimizer, vectorize=vectorize) - -a_tensor = torch.nn.Parameter(torch.rand(num_models, 1)) - - -optim_arg = { - "track_best_solution": True, - "track_err_history": True, - "verbose": True, - "backward_mode": th.BackwardMode.FULL, - "relin_threshold": 0.0000000001, - "damping": 0.5, - "dropout": 0.0, - "schedule": GBPSchedule.SYNCHRONOUS, - "lin_system_damping": 1e-5, -} - - -theseus_inputs = construct_theseus_layer_inputs() -print("inputs\n", theseus_inputs["a"], theseus_inputs["x"].shape) -updated_inputs, _ = theseus_optim.forward(theseus_inputs, optim_arg) - -print(updated_inputs) From 82fc597ec64cc86d82352e7d275accb63a6072f4 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 4 Jan 2023 10:18:16 +0000 Subject: [PATCH 50/64] Moved import, fixed single wrapper 
vectorization --- theseus/__init__.py | 1 + theseus/core/vectorizer.py | 8 ++++---- theseus/optimizer/gbp.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/theseus/__init__.py b/theseus/__init__.py index db25c3818..9560ae738 100644 --- a/theseus/__init__.py +++ b/theseus/__init__.py @@ -64,6 +64,7 @@ ) from .optimizer import ( # usort: skip DenseLinearization, + GaussianBeliefPropagation, Linearization, ManifoldGaussian, OptimizerInfo, diff --git a/theseus/core/vectorizer.py b/theseus/core/vectorizer.py index 315a66fa5..93bb4a51e 100644 --- a/theseus/core/vectorizer.py +++ b/theseus/core/vectorizer.py @@ -393,10 +393,10 @@ def _vectorize( } ret = [cf for cf_list in schema_dict.values() for cf in cf_list] for schema, cost_fn_wrappers in schema_dict.items(): - if len(cost_fn_wrappers) == 1: - self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode) - else: - self._handle_schema_vectorization(schema, cost_fn_wrappers, mode) + # if len(cost_fn_wrappers) == 1: + # self._handle_singleton_wrapper(schema, cost_fn_wrappers, mode) + # else: + self._handle_schema_vectorization(schema, cost_fn_wrappers, mode) return ret def _get_vectorized_error_iter(self) -> Iterable[_CostFunctionWrapper]: diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 2da9509de..dfa315ef6 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -893,7 +893,7 @@ def _create_factors_beliefs(self, lin_system_damping): cf_iterator = iter(self.objective.vectorized_cost_fns) self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_vectorized else: - cf_iterator = self.objective._get_iterator() + cf_iterator = iter(self.objective) self._pass_var_to_fac_messages = self._pass_var_to_fac_messages_loop # compute factor potentials for the first time @@ -1032,12 +1032,12 @@ def _optimize_loop( self.objective.update_vectorization_if_needed() t_vec = time.time() - t1 - t_tot = time.time() - t0 - if verbose: - print( - f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," - f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" - ) + # if verbose: + # t_tot = time.time() - t0 + # print( + # f"Timings ----- relin {t_relin:.4f}, ftov {t_ftov:.4f}, vtof {t_vtof:.4f}," + # f" vectorization {t_vec:.4f}, TOTAL {t_tot:.4f}" + # ) # check for convergence if it_ >= 0: @@ -1210,7 +1210,7 @@ def _optimize_impl( gauss_newton_optimizer = th.GaussNewton(self.objective) gauss_newton_optimizer.linear_solver.linearization.linearize() delta = gauss_newton_optimizer.linear_solver.solve() - self.objective.retract_optim_vars( + self.objective.retract_vars_sequence( delta * implicit_step_size, gauss_newton_optimizer.linear_solver.linearization.ordering, force_update=True, From 8a011549454db8b6225efce8eaed6a6e337b63e6 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 4 Jan 2023 10:54:32 +0000 Subject: [PATCH 51/64] update vectorization before truncated steps --- theseus/optimizer/gbp.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index dfa315ef6..79fd2d558 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -973,6 +973,9 @@ def _optimize_loop( # not automatically updated when objective.update is called. 
if clear_messages: self._create_factors_beliefs(lin_system_damping) + else: + self.objective.update_vectorization_if_needed() + if implicit_gbp_loop: relin_threshold = 1e10 # no relinearisation if self.objective.vectorized: From 0644e94875b2cb6078a31c0948979ce6f29020b0 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 4 Jan 2023 17:09:10 +0000 Subject: [PATCH 52/64] Moved bundle adjustment edits to experimental branch --- .../utils/examples/bundle_adjustment/data.py | 60 ++++++------------- .../utils/examples/bundle_adjustment/util.py | 23 ------- 2 files changed, 17 insertions(+), 66 deletions(-) diff --git a/theseus/utils/examples/bundle_adjustment/data.py b/theseus/utils/examples/bundle_adjustment/data.py index 42502f200..3c4a791b0 100644 --- a/theseus/utils/examples/bundle_adjustment/data.py +++ b/theseus/utils/examples/bundle_adjustment/data.py @@ -39,11 +39,6 @@ def to_params(self) -> List[float]: float(self.calib_k2[0, 0]), ] - def position(self) -> torch.Tensor: - R = self.pose.tensor[:, :, :3].squeeze(0) - t = self.pose.tensor[:, :, 3].squeeze(0) - return -R.T @ t - @staticmethod def from_params(params: List[float], name: str = "Cam") -> "Camera": r = th.SO3.exp_map(torch.tensor(params[:3], dtype=torch.float64).unsqueeze(0)) @@ -169,7 +164,7 @@ def __init__( self.gt_points = gt_points @staticmethod - def load_bal_dataset(path: str, drop_obs=0.0): + def load_bal_dataset(path: str): observations = [] cameras = [] points = [] @@ -177,41 +172,26 @@ def load_bal_dataset(path: str, drop_obs=0.0): num_cameras, num_points, num_observations = [ int(x) for x in out.readline().rstrip().split() ] - - fields = out.readline().rstrip().split() - intrinsics = None - if "." in fields[0]: - intrinsics = [ - (float(fields[0]) + float(fields[1])) / 2.0, - float(fields[2]), - float(fields[3]), - ] - for i in range(num_observations): - if i > 0 or intrinsics is not None: - fields = out.readline().rstrip().split() - if np.random.rand() > drop_obs: - feat = th.Point2( - tensor=torch.tensor( - [float(fields[2]), float(fields[3])], dtype=torch.float64 - ).unsqueeze(0), - name=f"Feat{i}", - ) - observations.append( - Observation( - camera_index=int(fields[0]), - point_index=int(fields[1]), - image_feature_point=feat, - ) + fields = out.readline().rstrip().split() + feat = th.Point2( + tensor=torch.tensor( + [float(fields[2]), float(fields[3])], dtype=torch.float64 + ).unsqueeze(0), + name=f"Feat{i}", + ) + observations.append( + Observation( + camera_index=int(fields[0]), + point_index=int(fields[1]), + image_feature_point=feat, ) + ) for i in range(num_cameras): params = [] - n_params = 6 if intrinsics is not None else 9 - for _ in range(n_params): + for _ in range(9): params.append(float(out.readline().rstrip())) - if intrinsics is not None: - params.extend(intrinsics) cameras.append(Camera.from_params(params, name=f"Cam{i}")) for i in range(num_points): @@ -298,9 +278,6 @@ def generate_synthetic( feat_random: float = 1.5, prob_feat_is_outlier: float = 0.02, outlier_feat_random: float = 70, - cam_pos_rand: float = 0.2, - cam_rot_rand: float = 0.1, - point_rand: float = 0.2, ): # add cameras @@ -313,10 +290,7 @@ def generate_synthetic( ) for i in range(num_cameras) ] - cameras = [ - cam.perturbed(rot_random=cam_rot_rand, pos_random=cam_pos_rand) - for cam in gt_cameras - ] + cameras = [cam.perturbed() for cam in gt_cameras] # add points gt_points = [ @@ -329,7 +303,7 @@ def generate_synthetic( ] points = [ th.Point3( - tensor=gt_points[i].tensor + (torch.rand((1, 3)) * 2 - 1) * point_rand, + 
tensor=gt_points[i].tensor + (torch.rand((1, 3)) * 2 - 1) * 0.2, name=gt_points[i].name + "_copy", ) for i in range(num_points) diff --git a/theseus/utils/examples/bundle_adjustment/util.py b/theseus/utils/examples/bundle_adjustment/util.py index e670e91a6..bd30d7447 100644 --- a/theseus/utils/examples/bundle_adjustment/util.py +++ b/theseus/utils/examples/bundle_adjustment/util.py @@ -31,29 +31,6 @@ def soft_loss_huber_like( return val, der -# For reprojection cost functions where the loss is 2 dimensional, -# x and y pixel loss, but the robust loss region is determined -# by the norm of the (x, y) pixel loss. -def soft_loss_huber_like_reprojection( - x: torch.Tensor, radius: torch.Tensor -) -> Tuple[torch.Tensor, torch.Tensor]: - x_norm = torch.norm(x, dim=1).unsqueeze(1) - val, der = soft_loss_huber_like(x_norm, radius) - scaling = val / x_norm - x_loss = x * scaling - - term1 = scaling[..., None] * torch.eye(2, dtype=x.dtype, device=x.device).reshape( - 1, 2, 2 - ).repeat(x.shape[0], 1, 1) - term2 = ( - torch.bmm(x.unsqueeze(2), x.unsqueeze(1)) - * ((der - scaling) / (x_norm**2))[..., None] - ) - der = term1 + term2 - - return x_loss, der - - # ------------------------------------------------------------ # # ----------------------------- RNG -------------------------- # # ------------------------------------------------------------ # From 4a8190120f885dd87982d8847870c149cc8dd93a Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Wed, 4 Jan 2023 18:05:32 +0000 Subject: [PATCH 53/64] flake8 on github not gitlab --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba0f3b271..dace5302b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,7 +4,7 @@ repos: hooks: - id: black - - repo: https://gitlab.com/pycqa/flake8 + - repo: https://github.com/pycqa/flake8 rev: 3.9.2 hooks: - id: flake8 From 4489141b698a9d4fff4d61ef63fb87bcb18cf4be Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Thu, 5 Jan 2023 08:58:15 +0000 Subject: [PATCH 54/64] Remove nesterov acceleration and timing --- theseus/optimizer/gbp.py | 124 +++++---------------------------------- 1 file changed, 15 insertions(+), 109 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 79fd2d558..a68dc4dc8 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -5,7 +5,8 @@ import abc import math -import time + +# import time import warnings from dataclasses import dataclass from enum import Enum @@ -38,41 +39,6 @@ """ -# as in https://blogs.princeton.edu/imabandit/2013/04/01/acceleratedgradientdescent/ -def next_nesterov_params(lam) -> Tuple[float, float]: - new_lambda = (1 + np.sqrt(4 * lam * lam + 1)) / 2.0 - new_gamma = (lam - 1) / new_lambda - return new_lambda, new_gamma - - -def apply_nesterov( - y_curr: th.Manifold, - y_last: th.Manifold, - nesterov_gamma: float, - normalize_method: bool = True, -) -> th.Manifold: - if normalize_method: - # apply to tensors and then project back to closest group element - nesterov_mean_tensor = ( - 1 + nesterov_gamma - ) * y_curr.tensor - nesterov_gamma * y_last.tensor - nesterov_mean_tensor = y_curr.__class__.normalize(nesterov_mean_tensor) - nesterov_mean = y_curr.__class__(tensor=nesterov_mean_tensor) - - else: - # apply nesterov damping in tanget plane. - # Cannot use new_belief or nesterov_y as the tangent plance, because tangent vector is 0. - # Use identity as tangent plane, may not be best choice as could be far from identity. 
- tp = y_curr.__class__(dtype=y_curr.dtype) - tp.to(y_curr.device) - tp_mean = (1 + nesterov_gamma) * tp.local(y_curr) - nesterov_gamma * tp.local( - y_last - ) - nesterov_mean = tp.retract(tp_mean) - - return nesterov_mean - - # Same of NonlinearOptimizerParams but without step size @dataclass class GBPOptimizerParams: @@ -621,13 +587,7 @@ def _merge_infos( GBP functions """ - def _pass_var_to_fac_messages_loop(self, update_belief=True, nesterov_gamma=None): - if nesterov_gamma is not None: - if nesterov_gamma == 0: # only on the first call - self.nesterov_ys = [ - belief.mean[0].copy(new_name="nesterov_y_" + belief.mean[0].name) - for belief in self.beliefs - ] + def _pass_var_to_fac_messages_loop(self, update_belief=True): for i, var in enumerate(self.ordering): # Collect all incoming messages in the tangent space at the current belief @@ -674,27 +634,9 @@ def _pass_var_to_fac_messages_loop(self, update_belief=True, nesterov_gamma=None tau = torch.matmul(inv_lam_tau, sum_taus.unsqueeze(-1)).squeeze(-1) new_belief = th.retract_gaussian(var, tau, lam_tau) - - # nesterov acceleration - if nesterov_gamma is not None: - nesterov_mean = apply_nesterov( - new_belief.mean[0], - self.nesterov_ys[i], - nesterov_gamma, - normalize_method=False, - ) - # belief mean as calculated by GBP step is the new nesterov y value at this step - self.nesterov_ys[i] = new_belief.mean[0].copy() - # use nesterov mean for new belief - new_belief.update( - mean=[nesterov_mean], precision=new_belief.precision - ) - self.beliefs[i].update(new_belief.mean, new_belief.precision) - def _pass_var_to_fac_messages_vectorized( - self, update_belief=True, nesterov_gamma=None - ): + def _pass_var_to_fac_messages_vectorized(self, update_belief=True): # Each (variable-type, dof) gets mapped to a tuple with: # - the variable that will hold the vectorized data # - all the variables of that type that will be vectorized together @@ -739,10 +681,6 @@ def _pass_var_to_fac_messages_vectorized( lam_tp_acc = lam_tp_acc.to(vectorized_data.device, vectorized_data.dtype) eta_lam.extend([eta_tp_acc, lam_tp_acc]) - if nesterov_gamma is not None: - if nesterov_gamma == 0: # only on the first call - self.nesterov_ys = [info[0].copy() for info in var_info.values()] - # add ftov messages to eta_tp and lam_tp accumulator tensors for factor in self.factors: for i, msg in enumerate(factor.ftov_msgs): @@ -817,7 +755,6 @@ def _pass_var_to_fac_messages_vectorized( msg.update(new_mess.mean, new_mess.precision) # compute the new belief for the vectorized variables - i = 0 for (vectorized_var, _, var_ixs, eta_lam) in var_info.values(): eta_tp_acc = eta_lam[0] lam_tau = eta_lam[1] @@ -832,22 +769,6 @@ def _pass_var_to_fac_messages_vectorized( new_belief = th.retract_gaussian(vectorized_var, tau, lam_tau) - # nesterov acceleration - if nesterov_gamma is not None: - nesterov_mean = apply_nesterov( - new_belief.mean[0], - self.nesterov_ys[i], - nesterov_gamma, - normalize_method=False, - ) - # belief mean as calculated by GBP step is the new nesterov y value at this step - self.nesterov_ys[i] = new_belief.mean[0].copy() - # use nesterov mean for new belief - new_belief.update( - mean=[nesterov_mean], precision=new_belief.precision - ) - i += 1 - # update non vectorized beliefs with slices start_idx = 0 for ix in var_ixs: @@ -964,7 +885,6 @@ def _optimize_loop( dropout: float, schedule: GBPSchedule, lin_system_damping: torch.Tensor, - nesterov: bool, clear_messages: bool = True, implicit_gbp_loop: bool = False, **kwargs, @@ -985,9 +905,6 @@ def _optimize_loop( 
if schedule == GBPSchedule.SYNCHRONOUS: ftov_schedule = synchronous_schedule(num_iter, self.n_edges) - if nesterov: - nest_lambda, nest_gamma = next_nesterov_params(0.0) - self.ftov_msgs_history = {} converged_indices = torch.zeros_like(info.last_err).bool() @@ -1011,29 +928,23 @@ def _optimize_loop( dropout_ixs = torch.rand(self.n_edges) < dropout ftov_schedule[it_, dropout_ixs] = False - t0 = time.time() + # t0 = time.time() relins = self._linearize_factors(relin_threshold) - t_relin = time.time() - t0 + # t_relin = time.time() - t0 - t1 = time.time() + # t1 = time.time() self._pass_fac_to_var_messages(ftov_schedule[it_], ftov_damping_arr) - t_ftov = time.time() - t1 - - t1 = time.time() - nest_gamma = None - if nesterov: - nest_lambda, nest_gamma = next_nesterov_params(nest_lambda) - print("nesterov gamma", nest_gamma) - self._pass_var_to_fac_messages( - update_belief=True, nesterov_gamma=nest_gamma - ) - t_vtof = time.time() - t1 + # t_ftov = time.time() - t1 + + # t1 = time.time() + self._pass_var_to_fac_messages(update_belief=True) + # t_vtof = time.time() - t1 - t_vec = 0.0 + # t_vec = 0.0 if self.objective.vectorized: - t1 = time.time() + # t1 = time.time() self.objective.update_vectorization_if_needed() - t_vec = time.time() - t1 + # t_vec = time.time() - t1 # if verbose: # t_tot = time.time() - t0 @@ -1080,7 +991,6 @@ def _optimize_impl( dropout: float = 0.0, schedule: GBPSchedule = GBPSchedule.SYNCHRONOUS, lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), - nesterov: bool = False, implicit_step_size: float = 1e-4, implicit_method: str = "gbp", **kwargs, @@ -1132,7 +1042,6 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, - nesterov=nesterov, **kwargs, ) @@ -1181,7 +1090,6 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, - nesterov=nesterov, **kwargs, ) @@ -1199,7 +1107,6 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, - nesterov=nesterov, clear_messages=False, **kwargs, ) @@ -1236,7 +1143,6 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, - nesterov=nesterov, clear_messages=False, implicit_gbp_loop=True, **kwargs, From 9d5d5f73dff2affa802b1c6c8533d9c2cbf60a64 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 6 Jan 2023 10:44:25 +0000 Subject: [PATCH 55/64] End of iter callback, updated mypy version --- .pre-commit-config.yaml | 2 +- theseus/optimizer/gbp.py | 25 ++++++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dace5302b..e902f6ad1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,7 +10,7 @@ repos: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.931 + rev: v0.991 hooks: - id: mypy additional_dependencies: [torch, tokenize-rt==3.2.0, types-PyYAML, types-mock] diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index a68dc4dc8..7fab88404 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -11,7 +11,17 @@ from dataclasses import dataclass from enum import Enum from itertools import count -from typing import Dict, List, Optional, Sequence, Tuple, Type, Union +from typing import ( + Callable, + Dict, + List, + NoReturn, + Optional, + Sequence, + Tuple, + Type, + Union, +) import numpy as np import torch @@ -38,6 +48,10 @@ Utitily functions """ +EndIterCallbackType = Callable[ + ["GaussianBeliefPropagation", 
NonlinearOptimizerInfo, None, int], NoReturn +] + # Same of NonlinearOptimizerParams but without step size @dataclass @@ -887,6 +901,7 @@ def _optimize_loop( lin_system_damping: torch.Tensor, clear_messages: bool = True, implicit_gbp_loop: bool = False, + end_iter_callback: Optional[EndIterCallbackType] = None, **kwargs, ): # we only create the factors and beliefs right before runnig GBP as they are @@ -971,6 +986,9 @@ def _optimize_loop( break # nothing else will happen at this point info.last_err = err + if end_iter_callback is not None: + end_iter_callback(self, info, None, it_) + info.status[ info.status == NonlinearOptimizerStatus.START ] = NonlinearOptimizerStatus.MAX_ITERATIONS @@ -993,6 +1011,7 @@ def _optimize_impl( lin_system_damping: torch.Tensor = torch.Tensor([1e-4]), implicit_step_size: float = 1e-4, implicit_method: str = "gbp", + end_iter_callback: Optional[EndIterCallbackType] = None, **kwargs, ) -> NonlinearOptimizerInfo: backward_mode = BackwardMode.resolve(backward_mode) @@ -1042,6 +1061,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + end_iter_callback=end_iter_callback, **kwargs, ) @@ -1090,6 +1110,7 @@ def _optimize_impl( dropout=dropout, schedule=schedule, lin_system_damping=lin_system_damping, + end_iter_callback=end_iter_callback, **kwargs, ) @@ -1108,6 +1129,7 @@ def _optimize_impl( schedule=schedule, lin_system_damping=lin_system_damping, clear_messages=False, + end_iter_callback=end_iter_callback, **kwargs, ) # Adds grad_loop_info results to original info @@ -1145,6 +1167,7 @@ def _optimize_impl( lin_system_damping=lin_system_damping, clear_messages=False, implicit_gbp_loop=True, + end_iter_callback=end_iter_callback, **kwargs, ) From 748fe38a2c27295124e6f9a29bd168461fddbc01 Mon Sep 17 00:00:00 2001 From: joeaortiz Date: Fri, 6 Jan 2023 15:13:52 +0000 Subject: [PATCH 56/64] First attempt at GBP linear solver test --- .../linear/test_gbp_linear_solver.py | 157 ++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 tests/optimizer/linear/test_gbp_linear_solver.py diff --git a/tests/optimizer/linear/test_gbp_linear_solver.py b/tests/optimizer/linear/test_gbp_linear_solver.py new file mode 100644 index 000000000..a6122634e --- /dev/null +++ b/tests/optimizer/linear/test_gbp_linear_solver.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +import theseus as th + + +""" +Build linear 1D surface estimation problem. +Solve using GBP and using matrix inversion and compare answers. +GBP exactly computes the marginal means on convergence. 
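Why comparing against a direct linear solve is a valid check: the objective is
quadratic, 0.5 * x^T * Lambda * x - eta^T * x + const, so the exact marginal
means are the entries of mu = Lambda^{-1} * eta, which is what the dense
Cholesky solver computes. A classical result for Gaussian BP (Weiss & Freeman,
2001) is that if the messages converge, even on a loopy graph, the belief means
equal these exact marginal means, although the belief covariances are not exact
in general.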
+ +Test cases: +- with / without vectorization +- with / without factor to variable message damping +- with / without factor linear system damping +""" + + +def _check_info(info, batch_size, max_iterations, initial_error, objective): + assert info.err_history.shape == (batch_size, max_iterations + 1) + assert info.err_history[:, 0].allclose(initial_error) + assert info.err_history.argmin(dim=1).allclose(info.best_iter + 1) + last_error = objective.error_squared_norm() / 2 + last_convergence_idx = info.converged_iter.max().item() + assert info.err_history[:, last_convergence_idx].allclose(last_error) + + +def run_gbp_linear_solver( + vectorize, + optimize_kwargs, +): + max_iterations = 20 + + n_variables = 100 + batch_size = 1 + + rng = torch.Generator() + rng.manual_seed(0) + + variables = [] + meas_vars = [] + for i in range(n_variables): + variables.append( + th.Vector(tensor=torch.rand(batch_size, 1, generator=rng), name=f"x_{i}") + ) + meas_vars.append( + th.Vector( + tensor=torch.rand(batch_size, 1, generator=rng), name=f"meas_x{i}" + ) + ) + + objective = th.Objective() + # measurement cost functions + meas_weight = th.ScaleCostWeight(1.0, name="meas_weight") + for var, meas in zip(variables, meas_vars): + objective.add(th.Difference(var, meas, meas_weight)) + # smoothness cost functions between adjacent variables and between random variables + smoothness_weight = th.ScaleCostWeight(4.0, name="smoothness_weight") + zero = th.Vector(tensor=torch.zeros(batch_size, 1), name="zero") + for i in range(n_variables - 1): + objective.add( + th.Between(variables[i], variables[i + 1], zero, smoothness_weight) + ) + for i in range(100): + ix1, ix2 = torch.randint(n_variables, (2,)) + # ix1, ix2 = 0, 2 + objective.add( + th.Between(variables[ix1], variables[ix2], zero, smoothness_weight) + ) + + # initial input tensors + measurements = torch.rand(batch_size, n_variables, generator=rng) + input_tensors = {} + for var in variables: + input_tensors[var.name] = var.tensor + for i in range(len(measurements[0])): + input_tensors[f"meas_x{i}"] = measurements[:, i][:, None] + + # GBP inference + # optimizer = th.GaussianBeliefPropagation( + # objective, max_iterations=max_iterations, vectorize=vectorize + # ) + # optimizer.set_params(max_iterations=max_iterations) + # objective.update(input_tensors) + # initial_error = objective.error_squared_norm() / 2 + + # callback_expected_iter = [0] + + # def callback(opt_, info_, None, it_): + # assert opt_ is optimizer + # assert isinstance(info_, th.optimizer.OptimizerInfo) + # assert it_ == callback_expected_iter[0] + # callback_expected_iter[0] += 1 + + # info = optimizer.optimize( + # track_best_solution=True, + # track_err_history=True, + # end_iter_callback=callback, + # verbose=True, + # **optimize_kwargs, + # ) + # gbp_solution = [var.tensor.clone() for var in variables] + + # Solve with Gauss-Newton + + def gn_callback(opt_, info_, _, it_): + out = list(opt_.objective.optim_vars.values()) + vec = torch.cat([v.tensor for v in out]) + print(vec.flatten()) + + objective.update(input_tensors) + gn_optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver) + gn_optimizer.set_params(max_iterations=max_iterations) + gn_optimizer.optimize(verbose=True, end_iter_callback=gn_callback) + gn_solution = [var.tensor.clone() for var in variables] + + # Solve with linear solver + objective.update(input_tensors) + linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) + linear_optimizer.optimize(verbose=True) + lin_solution = [var.tensor.clone() for 
var in variables] + lin_solve_err = objective.error_squared_norm() / 2 + print("linear solver error", lin_solve_err.item()) + + for x, x_target in zip(lin_solution, gn_solution): + print(x, x_target) + assert x.allclose(x_target) + + # print("comparing GBP") + # for x, x_target in zip(gbp_solution, gn_solution): + # print(x, x_target) + # assert x.allclose(x_target) + + # for x, x_target in zip(gbp_solution, lin_solution): + # print(x, x_target) + # assert x.allclose(x_target) + + # _check_info(info, batch_size, max_iterations, initial_error, objective) + + # # Visualise reconstructed surface + # soln_vec = torch.cat(gbp_solution, dim=1)[0] + # import matplotlib.pylab as plt + # plt.scatter(torch.arange(n_variables), soln_vec, label="solution") + # plt.scatter(torch.arange(n_variables), measurements[0], label="meas") + # plt.legend() + # plt.show() + + +def test_gbp_linear_solver(): + optimize_kwargs = {} + + # run_gbp_linear_solver(vectorize=True, optimize_kwargs=optimize_kwargs) + run_gbp_linear_solver(vectorize=False, optimize_kwargs=optimize_kwargs) From e59745117ab80d041fc6132af6357a86b5b20ad5 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 6 Jan 2023 17:33:32 +0000 Subject: [PATCH 57/64] Fixed poor conditioning problems with linear test --- .../linear/test_gbp_linear_solver.py | 174 ++++++++++-------- 1 file changed, 94 insertions(+), 80 deletions(-) diff --git a/tests/optimizer/linear/test_gbp_linear_solver.py b/tests/optimizer/linear/test_gbp_linear_solver.py index a6122634e..09e04c484 100644 --- a/tests/optimizer/linear/test_gbp_linear_solver.py +++ b/tests/optimizer/linear/test_gbp_linear_solver.py @@ -13,9 +13,10 @@ Solve using GBP and using matrix inversion and compare answers. GBP exactly computes the marginal means on convergence. 
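Background on the conditioning issue this patch works around: for a purely
linear objective, Gauss-Newton reaches the minimum in a single step in exact
arithmetic, so any gap between the Gauss-Newton solution and one dense linear
solve is floating-point error, roughly bounded by the condition number of the
information matrix times machine epsilon. This is why the test below checks the
GBP solution against both baselines and uses a relative tolerance (rtol=1e-3)
rather than exact equality.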
-Test cases: +All the following cases should not affect the converged solution: - with / without vectorization - with / without factor to variable message damping +- with / without dropout - with / without factor linear system damping """ @@ -30,128 +31,141 @@ def _check_info(info, batch_size, max_iterations, initial_error, objective): def run_gbp_linear_solver( - vectorize, - optimize_kwargs, + frac_loops, + vectorize=True, + ftov_damping=0.0, + dropout=0.0, + lin_system_damping=torch.tensor([1e-4]), ): - max_iterations = 20 + max_iterations = 200 n_variables = 100 batch_size = 1 - rng = torch.Generator() - rng.manual_seed(0) + torch.manual_seed(0) + + # initial input tensors + # measurements come from x = sin(t / 50) * t**2 / 250 + 1 with random noise added + ts = torch.arange(n_variables) + true_meas = torch.sin(ts / 10.0) * ts * ts / 250.0 + 1 + noisy_meas = true_meas[None, :].repeat(batch_size, 1) + noisy_meas += torch.normal(torch.zeros_like(noisy_meas), 1.0) variables = [] meas_vars = [] for i in range(n_variables): - variables.append( - th.Vector(tensor=torch.rand(batch_size, 1, generator=rng), name=f"x_{i}") - ) - meas_vars.append( - th.Vector( - tensor=torch.rand(batch_size, 1, generator=rng), name=f"meas_x{i}" - ) - ) + variables.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"x_{i}")) + meas_vars.append(th.Vector(tensor=torch.rand(batch_size, 1), name=f"meas_x{i}")) objective = th.Objective() + # measurement cost functions - meas_weight = th.ScaleCostWeight(1.0, name="meas_weight") + meas_weight = th.ScaleCostWeight(5.0, name="meas_weight") for var, meas in zip(variables, meas_vars): objective.add(th.Difference(var, meas, meas_weight)) - # smoothness cost functions between adjacent variables and between random variables - smoothness_weight = th.ScaleCostWeight(4.0, name="smoothness_weight") + + # smoothness cost functions between adjacent variables + smoothness_weight = th.ScaleCostWeight(2.0, name="smoothness_weight") zero = th.Vector(tensor=torch.zeros(batch_size, 1), name="zero") for i in range(n_variables - 1): objective.add( th.Between(variables[i], variables[i + 1], zero, smoothness_weight) ) - for i in range(100): + + # difference cost functions between non-adjacent variables to give + # off diagonal elements in information matrix + difference_weight = th.ScaleCostWeight(1.0, name="difference_weight") + for i in range(int(n_variables * frac_loops)): ix1, ix2 = torch.randint(n_variables, (2,)) - # ix1, ix2 = 0, 2 + diff = th.Vector( + tensor=torch.tensor([[true_meas[ix2] - true_meas[ix1]]]), name=f"diff{i}" + ) + diff.tensor += torch.normal(torch.zeros(1, 1), 0.2) objective.add( - th.Between(variables[ix1], variables[ix2], zero, smoothness_weight) + th.Between(variables[ix1], variables[ix2], diff, difference_weight) ) - # initial input tensors - measurements = torch.rand(batch_size, n_variables, generator=rng) input_tensors = {} for var in variables: input_tensors[var.name] = var.tensor - for i in range(len(measurements[0])): - input_tensors[f"meas_x{i}"] = measurements[:, i][:, None] - - # GBP inference - # optimizer = th.GaussianBeliefPropagation( - # objective, max_iterations=max_iterations, vectorize=vectorize - # ) - # optimizer.set_params(max_iterations=max_iterations) - # objective.update(input_tensors) - # initial_error = objective.error_squared_norm() / 2 - - # callback_expected_iter = [0] - - # def callback(opt_, info_, None, it_): - # assert opt_ is optimizer - # assert isinstance(info_, th.optimizer.OptimizerInfo) - # assert it_ == 
callback_expected_iter[0] - # callback_expected_iter[0] += 1 - - # info = optimizer.optimize( - # track_best_solution=True, - # track_err_history=True, - # end_iter_callback=callback, - # verbose=True, - # **optimize_kwargs, - # ) - # gbp_solution = [var.tensor.clone() for var in variables] - - # Solve with Gauss-Newton - - def gn_callback(opt_, info_, _, it_): - out = list(opt_.objective.optim_vars.values()) - vec = torch.cat([v.tensor for v in out]) - print(vec.flatten()) - + for i in range(len(noisy_meas[0])): + input_tensors[f"meas_x{i}"] = noisy_meas[:, i][:, None] + + # Solve with GBP + optimizer = th.GaussianBeliefPropagation( + objective, max_iterations=max_iterations, vectorize=vectorize + ) + optimizer.set_params(max_iterations=max_iterations) objective.update(input_tensors) - gn_optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver) - gn_optimizer.set_params(max_iterations=max_iterations) - gn_optimizer.optimize(verbose=True, end_iter_callback=gn_callback) - gn_solution = [var.tensor.clone() for var in variables] + initial_error = objective.error_squared_norm() / 2 + + callback_expected_iter = [0] + + def callback(opt_, info_, _, it_): + assert opt_ is optimizer + assert isinstance(info_, th.optimizer.OptimizerInfo) + assert it_ == callback_expected_iter[0] + callback_expected_iter[0] += 1 + + info = optimizer.optimize( + track_best_solution=True, + track_err_history=True, + end_iter_callback=callback, + ftov_msg_damping=ftov_damping, + dropout=dropout, + lin_system_damping=lin_system_damping, + verbose=True, + ) + gbp_solution = [var.tensor.clone() for var in variables] # Solve with linear solver objective.update(input_tensors) linear_optimizer = th.LinearOptimizer(objective, th.CholeskyDenseSolver) linear_optimizer.optimize(verbose=True) lin_solution = [var.tensor.clone() for var in variables] - lin_solve_err = objective.error_squared_norm() / 2 - print("linear solver error", lin_solve_err.item()) - - for x, x_target in zip(lin_solution, gn_solution): - print(x, x_target) - assert x.allclose(x_target) - # print("comparing GBP") - # for x, x_target in zip(gbp_solution, gn_solution): - # print(x, x_target) - # assert x.allclose(x_target) - - # for x, x_target in zip(gbp_solution, lin_solution): - # print(x, x_target) - # assert x.allclose(x_target) + # Solve with Gauss-Newton + # If problem is poorly conditioned solving with Gauss-Newton can yield + # a slightly different solution to one linear solve, so check both + objective.update(input_tensors) + gn_optimizer = th.GaussNewton(objective, th.CholeskyDenseSolver) + gn_optimizer.optimize(verbose=True) + gn_solution = [var.tensor.clone() for var in variables] - # _check_info(info, batch_size, max_iterations, initial_error, objective) + # checks + for x, x_target in zip(gbp_solution, lin_solution): + assert x.allclose(x_target, rtol=1e-3) + for x, x_target in zip(gbp_solution, gn_solution): + assert x.allclose(x_target, rtol=1e-3) + _check_info(info, batch_size, max_iterations, initial_error, objective) # # Visualise reconstructed surface # soln_vec = torch.cat(gbp_solution, dim=1)[0] # import matplotlib.pylab as plt # plt.scatter(torch.arange(n_variables), soln_vec, label="solution") - # plt.scatter(torch.arange(n_variables), measurements[0], label="meas") + # plt.scatter(torch.arange(n_variables), noisy_meas[0], label="meas") # plt.legend() # plt.show() def test_gbp_linear_solver(): - optimize_kwargs = {} - # run_gbp_linear_solver(vectorize=True, optimize_kwargs=optimize_kwargs) - run_gbp_linear_solver(vectorize=False, 
optimize_kwargs=optimize_kwargs) + # problems with increasing loopyness + # the loopier the fewer iterations to solve + frac_loops = [0.1, 0.2, 0.5] + for frac in frac_loops: + + run_gbp_linear_solver(frac_loops=frac) + + # with factor to variable message damping, may take too many steps to converge + # run_gbp_linear_solver(vectorize=vectorize, frac_loops=frac, ftov_damping=0.1) + # with dropout + run_gbp_linear_solver(frac_loops=frac, dropout=0.1) + + # test linear system damping + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([0.0])) + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-2])) + run_gbp_linear_solver(frac_loops=frac, lin_system_damping=torch.tensor([1e-6])) + + # test without vectorization once + run_gbp_linear_solver(frac_loops=0.5, vectorize=False) From c2cebbeba3233b5839907ba187cc141ef676d42f Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 6 Jan 2023 17:37:33 +0000 Subject: [PATCH 58/64] dropout starts later --- theseus/optimizer/gbp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 7fab88404..2433e5c93 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -939,7 +939,7 @@ def _optimize_loop( dtype=self.ordering[0].dtype, ) # dropout is implemented by changing the schedule - if dropout != 0.0 and it_ != 0: + if dropout != 0.0 and it_ > 1: dropout_ixs = torch.rand(self.n_edges) < dropout ftov_schedule[it_, dropout_ixs] = False From a94225f34106f18c9b20116d3396d5bcaa32f489 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 13 Jan 2023 12:45:44 +0000 Subject: [PATCH 59/64] Comments and references for understanding GBP code --- theseus/optimizer/gbp.py | 83 +++++++++++++++++++--------------------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 2433e5c93..282a5edae 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -39,8 +39,19 @@ """ TODO - - solving inverse problem to compute message mean - - factor inherits CF class +- replace generic nonlinear optimizer components. + +Summary. +This file contains the factor class used to wrap cost functions for GBP. +Factor to variable message passing functions are within the factor class. +Variable to factor message passing functions are within GBP optimizer class. 
+ + +References for understanding GBP: +- https://gaussianbp.github.io/ +- https://arxiv.org/abs/1910.14139 +Reference for GBP with non-Euclidean variables: +- https://arxiv.org/abs/2202.03314 """ @@ -85,7 +96,8 @@ def synchronous_schedule(max_iters, n_edges) -> torch.Tensor: # return schedule -# Initialises message precision to zero +# GBP message class, messages are Gaussian distributions +# Has additional fn to initialise messages with zero precision class Message(ManifoldGaussian): def __init__( self, @@ -120,11 +132,7 @@ def zero_message(self): self.update(mean=new_mean, precision=new_precision) -""" -GBP functions -""" - - +# Factor class, one is created for each cost function class Factor: _ids = count(0) @@ -151,6 +159,7 @@ def __init__( device = cf.optim_var_at(0).device dtype = cf.optim_var_at(0).dtype self._dof = sum([var.dof() for var in cf.optim_vars]) + # for storing factor linearization self.potential_eta = torch.zeros(self.batch_size, self.dof).to( dtype=dtype, device=device ) @@ -160,7 +169,6 @@ def __init__( self.lin_point: List[Manifold] = [ var.copy(new_name=f"{cf.name}_{var.name}_lp") for var in cf.optim_vars ] - self.steps_since_lin = torch.zeros( self.batch_size, device=device, dtype=torch.int ) @@ -172,6 +180,7 @@ def __init__( self.a = 2 self.b = 10 + # messages incoming and outgoing from the factor, they are updated in place self.vtof_msgs: List[Message] = [] self.ftov_msgs: List[Message] = [] for var in cf.optim_vars: @@ -189,11 +198,7 @@ def __init__( # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. - def linearize( - self, - relin_threshold: float = None, - lie=True, - ): + def linearize(self, relin_threshold: float = None, lie=True): self.steps_since_lin += 1 if relin_threshold is None: @@ -217,7 +222,9 @@ def linearize( J, error = self.cf.weighted_jacobians_error() J_stk = torch.cat(J, dim=-1) + # eqn 30 - https://arxiv.org/pdf/2202.03314.pdf lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) + # eqn 31 - https://arxiv.org/pdf/2202.03314.pdf eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) if lie is False: optim_vars_stk = torch.cat( @@ -226,7 +233,7 @@ def linearize( eta = eta + torch.matmul(lam, optim_vars_stk.unsqueeze(-1)) eta = eta.squeeze(-1) - # update damping parameter. This is non-differentiable + # update linear system damping parameter (this is non-differentiable) with torch.no_grad(): err = (self.cf.error() ** 2).sum(dim=1) if self.last_err is not None: @@ -254,11 +261,7 @@ def linearize( self.steps_since_lin[do_lin] = 0 # Compute all outgoing messages from the factor. - def comp_mess( - self, - msg_damping, - schedule, - ): + def comp_mess(self, msg_damping, schedule): num_optim_vars = self.cf.num_optim_vars() new_messages = [] @@ -270,6 +273,7 @@ def comp_mess( # Take product of factor with incoming messages. # Convert mesages to tangent space at linearisation point. 
+ # eqns 34-38 - https://arxiv.org/pdf/2202.03314.pdf start = 0 for i in range(num_optim_vars): var_dofs = self.cf.optim_var_at(i).dof() @@ -286,6 +290,7 @@ def comp_mess( dofs = self.cf.optim_var_at(v).dof() + # if no incoming messages then send out zero message if torch.allclose(lam_factor, lam_factor_copy) and num_optim_vars > 1: # print(self.cf.name, "---> not updating, incoming precision is zero") new_mess = Message([self.cf.optim_var_at(v).copy()]) @@ -293,7 +298,7 @@ def comp_mess( else: # print(self.cf.name, "---> sending message") - # Divide up parameters of distribution + # Divide up parameters of distribution to compute schur complement # *_out = parameters for receiver variable (outgoing message vars) # *_notout = parameters for other variables (not outgoing message vars) eta_out = eta_factor[:, sdim : sdim + dofs] @@ -336,6 +341,7 @@ def comp_mess( dim=1, ) + # Schur complement computation new_mess_lam = ( lam_out_out - lam_out_notout @@ -384,17 +390,10 @@ def comp_mess( new_mess_eta[no_update] = prev_mess_eta[no_update] new_mess_lam[no_update] = prev_mess_lam[no_update] - # new_mess_lam = th.DenseSolver._apply_damping( - # new_mess_lam, - # self.lin_system_damping, - # ellipsoidal=True, - # eps=1e-8, - # ) - + # eqns 39-41 - https://arxiv.org/pdf/2202.03314.pdf new_mess_mean = th.LUDenseSolver._solve_sytem( new_mess_eta[..., None], new_mess_lam ) - new_mess = th.retract_gaussian( self.lin_point[v], new_mess_mean, new_mess_lam ) @@ -435,6 +434,7 @@ def __init__( """ Copied and slightly modified from nonlinear optimizer class + GBP class should inherit these functions. """ def set_params(self, **kwargs): @@ -601,10 +601,11 @@ def _merge_infos( GBP functions """ + # eqns in sec 3.1 - https://arxiv.org/pdf/2202.03314.pdf def _pass_var_to_fac_messages_loop(self, update_belief=True): for i, var in enumerate(self.ordering): - # Collect all incoming messages in the tangent space at the current belief + # eqns 4-7 - https://arxiv.org/pdf/2202.03314.pdf etas_tp = [] # message etas lams_tp = [] # message lams for factor in self.factors: @@ -620,6 +621,7 @@ def _pass_var_to_fac_messages_loop(self, update_belief=True): lam_tau = lams_tp.sum(dim=0) # Compute outgoing messages + # eqns 8-10 - https://arxiv.org/pdf/2202.03314.pdf ix = 0 # index for msg in list of msgs to variables for factor in self.factors: for j, msg in enumerate(factor.vtof_msgs): @@ -650,6 +652,7 @@ def _pass_var_to_fac_messages_loop(self, update_belief=True): new_belief = th.retract_gaussian(var, tau, lam_tau) self.beliefs[i].update(new_belief.mean, new_belief.precision) + # Similar to above fn but vectorization require tracking extra indices def _pass_var_to_fac_messages_vectorized(self, update_belief=True): # Each (variable-type, dof) gets mapped to a tuple with: # - the variable that will hold the vectorized data @@ -660,10 +663,7 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): Tuple[Type[Manifold], int], Tuple[Manifold, List[Manifold], List[int], List[torch.Tensor]], ] = {} - batch_size = -1 - # Create var info by looping variables in the given order - # All variables of the same type get grouped together for ix, var in enumerate(self.ordering): if batch_size == -1: batch_size = var.shape[0] @@ -676,6 +676,7 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): var_info[var_type][1].append(var) var_info[var_type][2].append(ix) + # For each variable-type, create tensors to accumulate incoming messages for var_type, (vectorized_var, var_list, _, eta_lam) in var_info.items(): 
n_vars, dof = len(var_list), vectorized_var.dof() @@ -696,9 +697,9 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): eta_lam.extend([eta_tp_acc, lam_tp_acc]) # add ftov messages to eta_tp and lam_tp accumulator tensors + # eqns 4-7 - https://arxiv.org/pdf/2202.03314.pdf for factor in self.factors: for i, msg in enumerate(factor.ftov_msgs): - # transform messages to tangent plane at the current variable value eta_tp, lam_tp = th.local_gaussian( factor.cf.optim_var_at(i), msg, return_mean=False @@ -738,9 +739,9 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): lam_tp_acc.index_add_(0, factor.vectorized_var_ixs[i], lam_tp) # compute variable to factor messages, now all incoming messages are accumulated + # eqns 8-10 - https://arxiv.org/pdf/2202.03314.pdf for factor in self.factors: for i, msg in enumerate(factor.vtof_msgs): - # transform messages to tangent plane at the current variable value ftov_msg = factor.ftov_msgs[i] eta_tp, lam_tp = th.local_gaussian( @@ -769,7 +770,8 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): msg.update(new_mess.mean, new_mess.precision) # compute the new belief for the vectorized variables - for (vectorized_var, _, var_ixs, eta_lam) in var_info.values(): + # eqns 42-45 - https://arxiv.org/pdf/2202.03314.pdf + for vectorized_var, _, var_ixs, eta_lam in var_info.values(): eta_tp_acc = eta_lam[0] lam_tau = eta_lam[1] @@ -834,7 +836,6 @@ def _create_factors_beliefs(self, lin_system_damping): # compute factor potentials for the first time unary_factor = False for i, cost_function in enumerate(cf_iterator): - if self.objective.vectorized: # create array for indexing the messages base_cf_names = self.objective.vectorized_cf_names[i] @@ -887,13 +888,11 @@ def _create_factors_beliefs(self, lin_system_damping): Optimization loop functions """ - # loop for the iterative optimizer def _optimize_loop( self, num_iter: int, info: NonlinearOptimizerInfo, verbose: bool, - truncated_grad_loop: bool, relin_threshold: float, ftov_msg_damping: float, dropout: float, @@ -975,7 +974,7 @@ def _optimize_loop( self._update_info(info, it_, err, converged_indices) if verbose: print( - f"GBP. Iteration: {it_+1}. Error: {err.mean().item():.4f}. " + f"GBP. Iteration: {it_+1}. Error: {err.mean().item()}. 
" f"Relins: {relins} / {self.n_individual_factors}" ) converged_indices = self._check_convergence(err, info.last_err) @@ -1055,7 +1054,6 @@ def _optimize_impl( num_iter=self.params.max_iterations, info=info, verbose=verbose, - truncated_grad_loop=False, relin_threshold=relin_threshold, ftov_msg_damping=ftov_msg_damping, dropout=dropout, @@ -1104,7 +1102,6 @@ def _optimize_impl( num_iter=num_no_grad_iter, info=info, verbose=verbose, - truncated_grad_loop=False, relin_threshold=relin_threshold, ftov_msg_damping=ftov_msg_damping, dropout=dropout, @@ -1122,7 +1119,6 @@ def _optimize_impl( num_iter=backward_num_iterations, info=grad_loop_info, verbose=verbose, - truncated_grad_loop=True, relin_threshold=relin_threshold, ftov_msg_damping=ftov_msg_damping, dropout=dropout, @@ -1159,7 +1155,6 @@ def _optimize_impl( num_iter=max_lin_solve_iters, info=grad_loop_info, verbose=verbose, - truncated_grad_loop=True, relin_threshold=1e10, ftov_msg_damping=ftov_msg_damping, dropout=dropout, From b3551faee2d93c008956fb718dbb98e81b6795ea Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Fri, 13 Jan 2023 14:26:01 +0000 Subject: [PATCH 60/64] Reduced atol threshold for symmetric precision matrix --- theseus/optimizer/manifold_gaussian.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/theseus/optimizer/manifold_gaussian.py b/theseus/optimizer/manifold_gaussian.py index 2dafc0501..8bce4ee50 100644 --- a/theseus/optimizer/manifold_gaussian.py +++ b/theseus/optimizer/manifold_gaussian.py @@ -101,8 +101,8 @@ def update( f"Tried to update using tensor on device: {precision.dtype} but precision " f"is on device: {self.device}." ) - # if not torch.allclose(precision, precision.transpose(1, 2)): - # raise ValueError("Tried to update precision with non-symmetric matrix.") + if not torch.allclose(precision, precision.transpose(1, 2), atol=1e-5): + raise ValueError("Tried to update precision with non-symmetric matrix.") self.precision = precision From aade3793c514f48149a02fdb0f4eb51d57be9e53 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Mon, 16 Jan 2023 11:17:09 +0000 Subject: [PATCH 61/64] Detach hessian in implicit GBP backward mode --- theseus/optimizer/gbp.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 282a5edae..60c17a09d 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -40,6 +40,8 @@ """ TODO - replace generic nonlinear optimizer components. +- Remove implicit backward mode with Gauss-Newton, or at least modify it +to make sure it detaches the hessian. Summary. This file contains the factor class used to wrap cost functions for GBP. @@ -198,7 +200,12 @@ def __init__( # Linearizes factors at current belief if beliefs have deviated # from the linearization point by more than the threshold. 
- def linearize(self, relin_threshold: float = None, lie=True): + def linearize( + self, + relin_threshold: float = None, + detach_hessian: bool = False, + lie=True, + ): self.steps_since_lin += 1 if relin_threshold is None: @@ -223,7 +230,11 @@ def linearize(self, relin_threshold: float = None, lie=True): J_stk = torch.cat(J, dim=-1) # eqn 30 - https://arxiv.org/pdf/2202.03314.pdf - lam = torch.bmm(J_stk.transpose(-2, -1), J_stk) + lam = ( + torch.bmm(J_stk.transpose(-2, -1), J_stk).detach() + if detach_hessian + else torch.bmm(J_stk.transpose(-2, -1), J_stk) + ) # eqn 31 - https://arxiv.org/pdf/2202.03314.pdf eta = -torch.matmul(J_stk.transpose(-2, -1), error.unsqueeze(-1)) if lie is False: @@ -797,10 +808,14 @@ def _pass_var_to_fac_messages_vectorized(self, update_belief=True): self.beliefs[ix].update([belief_mean_slice], belief_precision_slice) start_idx += batch_size - def _linearize_factors(self, relin_threshold: float = None): + def _linearize_factors( + self, relin_threshold: float = None, detach_hessian: bool = False + ): relins = 0 for factor in self.factors: - factor.linearize(relin_threshold=relin_threshold) + factor.linearize( + relin_threshold=relin_threshold, detach_hessian=detach_hessian + ) relins += int((factor.steps_since_lin == 0).sum().item()) return relins @@ -914,7 +929,7 @@ def _optimize_loop( relin_threshold = 1e10 # no relinearisation if self.objective.vectorized: self.objective.update_vectorization_if_needed() - self._linearize_factors() + self._linearize_factors(detach_hessian=True) if schedule == GBPSchedule.SYNCHRONOUS: ftov_schedule = synchronous_schedule(num_iter, self.n_edges) From 4c905b93b7a006357aaf3b70493f4d3e76efb79d Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Mon, 16 Jan 2023 13:06:57 +0000 Subject: [PATCH 62/64] Fix linearization for truncated, exception for DLM --- theseus/optimizer/gbp.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 60c17a09d..3d6db99b9 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -924,6 +924,8 @@ def _optimize_loop( self._create_factors_beliefs(lin_system_damping) else: self.objective.update_vectorization_if_needed() + if not implicit_gbp_loop: + self._linearize_factors() if implicit_gbp_loop: relin_threshold = 1e10 # no relinearisation @@ -1029,6 +1031,10 @@ def _optimize_impl( **kwargs, ) -> NonlinearOptimizerInfo: backward_mode = BackwardMode.resolve(backward_mode) + if backward_mode == BackwardMode.DLM: + raise ValueError( + "DLM backward mode not supported for Gaussian Belief Propagation optimizer." + ) with torch.no_grad(): info = self._init_info( @@ -1147,8 +1153,10 @@ def _optimize_impl( self._merge_infos( grad_loop_info, no_grad_iters_done, grad_iters_done, info ) + + # Compute the approximate the implicit derivative using a gauss-newton step. + # It is first order approximate, as a full Newton step would be exact. elif implicit_method == "gauss_newton": - # use Gauss-Newton update to compute implicit gradient self.implicit_step_size = implicit_step_size gauss_newton_optimizer = th.GaussNewton(self.objective) gauss_newton_optimizer.linear_solver.linearization.linearize() @@ -1163,8 +1171,8 @@ def _optimize_impl( print( f"Nonlinear optimizer implcit step. 
Error: {err.mean().item()}" ) + # solve normal equation with GBP elif implicit_method == "gbp": - # solve normal equation in a distributed way max_lin_solve_iters = 1000 grad_iters_done = self._optimize_loop( num_iter=max_lin_solve_iters, From 4331f71000e39f4ee83fa0c330ff795d311567dc Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Tue, 17 Jan 2023 15:34:57 +0000 Subject: [PATCH 63/64] Fixed bug in linear system damping with vectorization --- theseus/optimizer/gbp.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index 3d6db99b9..f2b0e77dc 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -178,7 +178,7 @@ def __init__( self.lm_damping = lin_system_damping.repeat(self.batch_size).to(device) self.min_damping = torch.Tensor([1e-4]).to(dtype=dtype, device=device) self.max_damping = torch.Tensor([1e2]).to(dtype=dtype, device=device) - self.last_err: torch.Tensor = None + self.last_err: torch.Tensor = (self.cf.error() ** 2).sum(dim=1) self.a = 2 self.b = 10 @@ -247,15 +247,15 @@ def linearize( # update linear system damping parameter (this is non-differentiable) with torch.no_grad(): err = (self.cf.error() ** 2).sum(dim=1) - if self.last_err is not None: - decreased_ixs = err < self.last_err - self.lm_damping[decreased_ixs] = torch.max( - self.lm_damping[decreased_ixs] / self.a, self.min_damping - ) - self.lm_damping[~decreased_ixs] = torch.min( - self.lm_damping[~decreased_ixs] * self.b, self.max_damping - ) - self.last_err = err + damp_ixs = torch.logical_and(err > self.last_err, do_lin) + undamp_ixs = torch.logical_and(err < self.last_err, do_lin) + self.lm_damping[undamp_ixs] = torch.max( + self.lm_damping[undamp_ixs] / self.a, self.min_damping + ) + self.lm_damping[damp_ixs] = torch.min( + self.lm_damping[damp_ixs] * self.b, self.max_damping + ) + self.last_err[do_lin] = err[do_lin] # damp precision matrix damped_D = self.lm_damping[:, None, None] * torch.eye( From 633f67083b15f085db2cf56c57b5088c64363158 Mon Sep 17 00:00:00 2001 From: Joseph Ortiz Date: Thu, 19 Jan 2023 13:22:33 +0000 Subject: [PATCH 64/64] Zero messages correctly when using vectorization --- theseus/optimizer/gbp.py | 21 +++++++++++++++++---- theseus/optimizer/manifold_gaussian.py | 11 +++++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/theseus/optimizer/gbp.py b/theseus/optimizer/gbp.py index f2b0e77dc..5f3565fc6 100644 --- a/theseus/optimizer/gbp.py +++ b/theseus/optimizer/gbp.py @@ -115,7 +115,7 @@ def __init__( super(Message, self).__init__(mean, precision=precision, name=name) # sets mean to the group identity and zero precision matrix - def zero_message(self): + def zero_message(self, batch_ignore_mask: Optional[torch.Tensor] = None): new_mean = [] batch_size = self.mean[0].shape[0] for var in self.mean: @@ -131,7 +131,9 @@ def zero_message(self): new_precision = torch.zeros(batch_size, self.dof, self.dof).to( dtype=self.dtype, device=self.device ) - self.update(mean=new_mean, precision=new_precision) + self.update( + mean=new_mean, precision=new_precision, batch_ignore_mask=batch_ignore_mask + ) # Factor class, one is created for each cost function @@ -301,8 +303,11 @@ def comp_mess(self, msg_damping, schedule): dofs = self.cf.optim_var_at(v).dof() + inc_messages = ( + ~torch.isclose(lam_factor, lam_factor_copy).all(dim=1).all(dim=1) + ) # if no incoming messages then send out zero message - if torch.allclose(lam_factor, lam_factor_copy) and num_optim_vars > 1: + if not 
inc_messages.any() and num_optim_vars > 1: # print(self.cf.name, "---> not updating, incoming precision is zero") new_mess = Message([self.cf.optim_var_at(v).copy()]) new_mess.zero_message() @@ -405,9 +410,17 @@ def comp_mess(self, msg_damping, schedule): new_mess_mean = th.LUDenseSolver._solve_sytem( new_mess_eta[..., None], new_mess_lam ) - new_mess = th.retract_gaussian( + new_mess_params = th.retract_gaussian( self.lin_point[v], new_mess_mean, new_mess_lam ) + new_mess = Message( + mean=new_mess_params.mean, precision=new_mess_params.precision + ) + + # set zero msg for factors with no incoming messages + if not inc_messages.all() and num_optim_vars > 1: + # i.e. ignore (don't zero) if there are incoming messages + new_mess.zero_message(batch_ignore_mask=inc_messages) new_messages.append(new_mess) diff --git a/theseus/optimizer/manifold_gaussian.py b/theseus/optimizer/manifold_gaussian.py index 8bce4ee50..fd37484d7 100644 --- a/theseus/optimizer/manifold_gaussian.py +++ b/theseus/optimizer/manifold_gaussian.py @@ -74,6 +74,7 @@ def update( self, mean: Sequence[Manifold], precision: torch.Tensor, + batch_ignore_mask: Optional[torch.Tensor] = None, ): if len(mean) != len(self.mean): raise ValueError( @@ -82,7 +83,7 @@ def update( f"Expected: {len(self.mean)}" ) for i in range(len(self.mean)): - self.mean[i].update(mean[i]) + self.mean[i].update(mean[i], batch_ignore_mask=batch_ignore_mask) expected_shape = torch.Size([mean[0].shape[0], self.dof, self.dof]) if precision.shape != expected_shape: @@ -104,7 +105,13 @@ def update( if not torch.allclose(precision, precision.transpose(1, 2), atol=1e-5): raise ValueError("Tried to update precision with non-symmetric matrix.") - self.precision = precision + if batch_ignore_mask is not None and batch_ignore_mask.any(): + mask_shape = (-1,) + (1,) * (precision.ndim - 1) + self.precision = torch.where( + batch_ignore_mask.view(mask_shape), self.precision, precision + ) + else: + self.precision = precision # Projects the gaussian (ManifoldGaussian object) into the tangent plane at
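Illustrative sketch (not part of the patch series above): the factor-to-variable messages computed in Factor.comp_mess reduce to an information-form marginalization via a Schur complement, per the eqn 34-41 references in the code above (https://arxiv.org/pdf/2202.03314.pdf). The snippet below shows only that step for a single unbatched factor; the names (factor_to_var_message, eta, lam, start, dofs) are hypothetical, and the real implementation additionally folds the incoming variable-to-factor messages into the non-receiver blocks, applies damping and dropout schedules, and retracts the result onto the manifold at the linearization point.

import torch


def factor_to_var_message(eta: torch.Tensor, lam: torch.Tensor, start: int, dofs: int):
    # Split the factor's information vector/matrix into the receiver block (a)
    # and the remaining blocks (b).
    keep = torch.arange(start, start + dofs)
    drop = torch.cat([torch.arange(0, start), torch.arange(start + dofs, eta.shape[0])])
    eta_a, eta_b = eta[keep], eta[drop]
    lam_aa = lam[keep][:, keep]
    lam_ab = lam[keep][:, drop]
    lam_bb = lam[drop][:, drop]
    # Marginalize out the b blocks with the Schur complement.
    lam_bb_inv = torch.linalg.inv(lam_bb)
    lam_msg = lam_aa - lam_ab @ lam_bb_inv @ lam_ab.T
    eta_msg = eta_a - lam_ab @ lam_bb_inv @ eta_b
    return eta_msg, lam_msg


# Toy usage: a between-style factor on two 1-dof variables with Jacobian
# J = [-1, 1] and residual r, so lam = J^T J (plus a small diagonal term so
# this toy example is invertible) and eta = -J^T r.
J = torch.tensor([[-1.0, 1.0]])
r = torch.tensor([0.3])
lam = J.T @ J + 1e-3 * torch.eye(2)
eta = -(J.T @ r)
eta_msg, lam_msg = factor_to_var_message(eta, lam, start=0, dofs=1)
mean_msg = torch.linalg.solve(lam_msg, eta_msg)  # message mean, as in the eqn 39-41 solve step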