Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-graph-jacobian.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:

- name: Install dependencies
run: |
python -m pip install --break-system-packages --no-cache-dir --target /tmp/pydeps pytest warp-lang
python -m pip install --break-system-packages --no-cache-dir --target /tmp/pydeps pytest warp-lang matplotlib
python -m pip install --break-system-packages --no-cache-dir --target /tmp/pydeps --no-deps --no-build-isolation "git+https://github.com/pypose/pypose.git"
env:
PIP_CACHE_DIR: /tmp/pip-cache
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ task.md
*.png
.DS_Store
tmp/*
examples/module/pgo/data/*
3 changes: 3 additions & 0 deletions bae/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .utils.pypose_ambient_grad import maybe_install_pypose_ambient_grad_monkeypatch

# Import-time side effect: conditionally install the pypose ambient-gradient
# monkeypatch for the whole package. The "maybe_" prefix suggests the helper
# is a no-op unless some condition holds — confirm in
# bae/utils/pypose_ambient_grad.py.
maybe_install_pypose_ambient_grad_monkeypatch()
60 changes: 50 additions & 10 deletions bae/autograd/graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

from typing import Optional
import warnings

import pypose as pp
import torch
Expand Down Expand Up @@ -172,13 +173,18 @@ def update_from_trace(bsrt: torch.Tensor, arg, new_col: Optional[torch.Tensor]=N
)
return jac_trace

def backward(output_):
def backward(output_, is_root=False):
# For non-root recursion, no incoming trace means no contribution to
# propagate. This avoids re-initializing identity traces on revisits.
if (not is_root) and (not hasattr(output_, 'jactrace')):
return

if output_.optrace[id(output_)][0] == 'map':
func = output_.optrace[id(output_)][1]
args = output_.optrace[id(output_)][2]
argnums = tuple(idx for idx, arg in enumerate(args) if hasattr(arg, 'optrace') or isinstance(arg, torch.nn.Parameter))
if len(argnums) == 0:
warning("No upstream parameters to compute jacobian")
warnings.warn("No upstream parameters to compute jacobian", stacklevel=2)
return
with pp.retain_ltype():
jac_blocks = torch.vmap(jacrev(func, argnums=argnums))(*args)
Expand Down Expand Up @@ -207,9 +213,22 @@ def backward(output_):
elif isinstance(output_.jactrace, torch.Tensor) and output_.jactrace.layout == torch.sparse_bsr:
jac_trace = update_from_trace(output_.jactrace, arg, new_val=jac_block)
amend_trace(arg, jac_trace)
# Recurse once per unique upstream tensor after all local contributions
# have been accumulated.
seen = set()
for argidx in argnums:
if isinstance(args[argidx], torch.Tensor) and hasattr(args[argidx], 'optrace'):
backward(args[argidx])
arg = args[argidx]
if isinstance(arg, torch.Tensor) and hasattr(arg, 'optrace'):
arg_id = id(arg)
if arg_id in seen:
continue
seen.add(arg_id)
backward(arg, is_root=False)

# Consume intermediate trace to avoid re-propagating it when this node is
# reached again from another downstream branch (e.g. two index ops on one cat).
if hasattr(output_, 'jactrace'):
delattr(output_, 'jactrace')


elif output_.optrace[id(output_)][0] == 'index':
Expand All @@ -220,6 +239,8 @@ def backward(output_):
# populate Jacobian values. In this case, the Jacobian block values are
# identity matrices placed at the indexed columns.
if not hasattr(output_, 'jactrace'):
if not is_root:
return
if output_.ndim == 1:
eye_blocks = torch.ones((output_.shape[0], 1, 1), device=output_.device, dtype=output_.dtype)
else:
Expand All @@ -242,7 +263,10 @@ def backward(output_):

amend_trace(arg, jac_trace)
if isinstance(arg, torch.Tensor) and hasattr(arg, 'optrace'):
backward(arg)
backward(arg, is_root=False)

if hasattr(output_, 'jactrace'):
delattr(output_, 'jactrace')

elif output_.optrace[id(output_)][0] == 'cat':
dim = output_.optrace[id(output_)][1]
Expand All @@ -251,6 +275,8 @@ def backward(output_):
raise NotImplementedError("Only torch.cat(..., dim=0) is supported")

if not hasattr(output_, 'jactrace'):
if not is_root:
return
if output_.ndim == 1:
eye_blocks = torch.ones((output_.shape[0], 1, 1), device=output_.device, dtype=output_.dtype)
else:
Expand All @@ -265,6 +291,12 @@ def backward(output_):
n = arg.shape[0]
start, end = offset, offset + n

# Fixed/non-optimizable tensors (e.g., gauge-fixed first pose) do
# not need Jacobian traces.
if not (hasattr(arg, 'optrace') or isinstance(arg, torch.nn.Parameter)):
offset = end
continue

if type(upstream) is tuple:
jac_trace = _slice_upstream_tuple_columns(
upstream[0], upstream[1], start, end, out_cols_blocks=n
Expand All @@ -278,23 +310,31 @@ def backward(output_):

amend_trace(arg, jac_trace)
if isinstance(arg, torch.Tensor) and hasattr(arg, 'optrace'):
backward(arg)
backward(arg, is_root=False)
offset = end

if hasattr(output_, 'jactrace'):
delattr(output_, 'jactrace')


def jacobian(output, params):
assert output.optrace[id(output)][0] in ('map', 'index', 'cat'), "Unsupported last operation in compute graph"
_clear_jactrace(output, params)
try:
backward(output)
backward(output, is_root=True)
res = []
for param in params:
if hasattr(param, 'jactrace'):
if isinstance(param.jactrace, tuple):
values = trim_parameter_jacobian_values(param, param.jactrace[1])
param.jactrace = (param.jactrace[0], values)
indices, values = param.jactrace
values = trim_parameter_jacobian_values(param, values, block_indices=indices)
param.jactrace = (indices, values)
elif isinstance(param.jactrace, torch.Tensor) and param.jactrace.layout == torch.sparse_bsr:
values = trim_parameter_jacobian_values(param, param.jactrace.values())
values = trim_parameter_jacobian_values(
param,
param.jactrace.values(),
block_indices=param.jactrace.col_indices(),
)
if values.shape != param.jactrace.values().shape:
param.jactrace = torch.sparse_bsr_tensor(
col_indices=param.jactrace.col_indices(),
Expand Down
5 changes: 4 additions & 1 deletion bae/optim/optimizer.py
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer needed because mm supports all float types now

Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def step(self, input, target=None, weight=None):
self.reject_count = 0
J_T = J_T.to_sparse_csr()
J = J.to_sparse_csr()
A = self.mm(J_T, J)
if J.dtype == torch.float64:
A = self.mm(J_T, J)
else:
A = J_T @ J

diagonal_op_(A, op=partial(torch.clamp_, min=pg['min'], max=pg['max']))

Expand Down
26 changes: 22 additions & 4 deletions bae/sparse/py_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def diagonal_op_(input, offset: int=0, op: Optional[Callable]=None):

#simple case(block is square and offset is 0)
if dm == dn and offset == 0:
indices = None
if not USE_TRITON:
dummy_val = torch.zeros(bsr_values.shape[0], device='cpu')
dummy = torch.sparse_csr_tensor(crow_indices=crow_indices.to('cpu'),
Expand All @@ -134,17 +135,35 @@ def diagonal_op_(input, offset: int=0, op: Optional[Callable]=None):
n_diag_blocks = sm if sm < sn else sn
if diag_indices.shape[-1] == n_diag_blocks:
results = values
diag_rows = None
else:
# Triton path only computes a mask over existing entries. If some diagonal
# blocks are structurally missing, we need row indices to scatter into a
# dense diagonal buffer.
if indices is None:
dummy_val = torch.zeros(bsr_values.shape[0], device='cpu')
dummy = torch.sparse_csr_tensor(crow_indices=crow_indices.to('cpu'),
col_indices=col_indices.to('cpu'),
values=dummy_val)
indices = dummy.to_sparse(layout=torch.sparse_coo).coalesce().indices().to(input.device)
results_shape = (n_diag_blocks, dm)
results = torch.zeros(results_shape, dtype=values.dtype, device=values.device)
results[indices[0, diag_indices]] = values
assert op is None, "op is not supported for diagonal that has empty values."
diag_rows = indices[0, diag_indices]
if values.ndim == 1:
results[diag_rows, 0] = values
else:
results[diag_rows] = values
if bsr_values.ndim > 1:
results = torch.flatten(results, start_dim=-2, end_dim=-1)
# apply the inplace op
if op is not None:
results = op(results)
block_diags[diag_indices] = results.view(n_diag_blocks, dm) if bsr_values.ndim > 1 else results
if diag_rows is None:
block_diags[diag_indices] = results.view(n_diag_blocks, dm) if bsr_values.ndim > 1 else results
else:
if diag_indices.numel() > 0:
dense_results = results.view(n_diag_blocks, dm) if bsr_values.ndim > 1 else results
block_diags[diag_indices] = dense_results[diag_rows]
return results
else:
raise NotImplementedError('Only square block and offset 0 is supported.')
Expand Down Expand Up @@ -231,4 +250,3 @@ def bsr2bsc(J):
sparse_lib = Library('aten', 'IMPL')
sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCPU')
sparse_lib.impl('diagonal', diagonal_op_, 'SparseCsrCUDA')

59 changes: 59 additions & 0 deletions bae/utils/ceres_pose.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import torch


def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
    """Build the 3x3 cross-product ("hat") matrix for each 3-vector in *v*.

    For a trailing-dimension vector ``a = (x, y, z)`` the result is
    ``[[0, -z, y], [z, 0, -x], [-y, x, 0]]``; leading batch dimensions are
    preserved, so the output has shape ``(*v.shape[:-1], 3, 3)``.
    """
    vx, vy, vz = v[..., 0], v[..., 1], v[..., 2]
    o = torch.zeros_like(vx)
    # Stack all nine entries along one axis, then fold into a 3x3 block.
    flat = torch.stack(
        (o, -vz, vy,
         vz, o, -vx,
         -vy, vx, o),
        dim=-1,
    )
    return flat.reshape(*v.shape[:-1], 3, 3)


def quat_mul_xyzw(q1: torch.Tensor, q2: torch.Tensor) -> torch.Tensor:
    """Hamilton product ``q1 * q2`` for quaternions stored as (x, y, z, w).

    Both inputs must have a trailing dimension of 4 in xyzw order; batch
    dimensions broadcast as usual and the result keeps the xyzw layout.
    """
    x, y, z, w = q1.unbind(dim=-1)
    a, b, c, d = q2.unbind(dim=-1)  # a=x2, b=y2, c=z2, d=w2
    out_x = w * a + x * d + y * c - z * b
    out_y = w * b - x * c + y * d + z * a
    out_z = w * c + x * b - y * a + z * d
    out_w = w * d - x * a - y * b - z * c
    return torch.stack((out_x, out_y, out_z, out_w), dim=-1)


def quat_conj_xyzw(q: torch.Tensor) -> torch.Tensor:
    """Quaternion conjugate in xyzw layout: negate the vector part, keep w."""
    vec, scalar = q[..., :3], q[..., 3:4]
    return torch.cat((vec.neg(), scalar), dim=-1)


def quat_inv_xyzw(q: torch.Tensor) -> torch.Tensor:
    """Quaternion inverse in xyzw layout: conjugate divided by ||q||^2.

    For unit quaternions this equals the conjugate. The conjugate is built
    inline (negated vector part, unchanged scalar part).
    """
    conj = torch.cat((-q[..., :3], q[..., 3:4]), dim=-1)
    norm_sq = (q * q).sum(dim=-1, keepdim=True)
    return conj / norm_sq


def quat_rotate_xyzw(q: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Rotate 3-vectors *v* by quaternions *q* via ``q * (v, 0) * q^{-1}``.

    *q* is in xyzw layout. The vector is promoted to a pure quaternion
    (zero scalar part), conjugated by *q*, and the vector part of the
    result is returned.
    """
    pure = torch.cat((v, torch.zeros_like(v[..., :1])), dim=-1)
    rotated = quat_mul_xyzw(quat_mul_xyzw(q, pure), quat_inv_xyzw(q))
    return rotated[..., :3]


def quaternion_plus_jacobian_xyzw(quat: torch.Tensor) -> torch.Tensor:
    """4x3 Jacobian of the quaternion plus-operation at a zero increment.

    *quat* is in xyzw layout; the result has shape ``(*quat.shape[:-1], 4, 3)``
    and maps a 3-dof tangent increment into the 4-dof ambient quaternion
    coordinates, scaled by 1/2.
    """
    x, y, z, w = quat.unbind(dim=-1)
    jac = torch.stack(
        (
            torch.stack((w, z, -y), dim=-1),
            torch.stack((-z, w, x), dim=-1),
            torch.stack((y, -x, w), dim=-1),
            torch.stack((-x, -y, -z), dim=-1),
        ),
        dim=-2,
    )
    return jac * 0.5


def se3_pose_plus_jacobian_xyzw(pose: torch.Tensor) -> torch.Tensor:
    """7x6 Jacobian of the SE3 plus-operation at a zero increment.

    *pose* holds (tx, ty, tz, qx, qy, qz, qw) blocks in its trailing
    dimension. The result has shape ``(*pose.shape[:-1], 7, 6)``: the first
    three tangent columns act on translation (identity block), the last
    three carry ``-skew(t)`` into the translation rows and the quaternion
    plus-Jacobian into the quaternion rows.

    Raises:
        ValueError: if the trailing dimension of *pose* is not 7.
    """
    if pose.shape[-1] != 7:
        raise ValueError(f"Expected 7D pose blocks, got {pose.shape[-1]}.")

    # new_zeros inherits dtype and device from *pose*.
    jac = pose.new_zeros(*pose.shape[:-1], 7, 6)
    jac[..., :3, :3] = torch.eye(3, dtype=pose.dtype, device=pose.device)
    jac[..., :3, 3:6] = -skew_symmetric(pose[..., :3])
    jac[..., 3:7, 3:6] = quaternion_plus_jacobian_xyzw(pose[..., 3:7])
    return jac
25 changes: 23 additions & 2 deletions bae/utils/parameter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pypose as pp
import torch

from .ceres_pose import quaternion_plus_jacobian_xyzw, se3_pose_plus_jacobian_xyzw
from .pypose_ambient_grad import pypose_ambient_grad_enabled


def parameter_update_shape(param: torch.Tensor) -> torch.Size:
if param.ndim == 0:
Expand All @@ -12,12 +15,30 @@ def parameter_update_shape(param: torch.Tensor) -> torch.Size:
return param.shape


def trim_parameter_jacobian_values(param: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
def trim_parameter_jacobian_values(
param: torch.Tensor,
values: torch.Tensor,
block_indices: torch.Tensor | None = None,
) -> torch.Tensor:
if param.ndim == 0 or values.shape[-1] != param.shape[-1]:
return values
if getattr(param, 'trim_SE3_grad', False):
return torch.cat([values[..., :6], values[..., 7:]], dim=-1)
pose = param[..., :7].detach()
if block_indices is not None:
pose = pose[block_indices.to(torch.long)]
pose_values = values[..., :7] @ se3_pose_plus_jacobian_xyzw(pose)
if param.shape[-1] == 7:
return pose_values
return torch.cat([pose_values, values[..., 7:]], dim=-1)
if isinstance(param, pp.LieTensor):
if pypose_ambient_grad_enabled():
lie_param = param.detach()
if block_indices is not None:
lie_param = lie_param[block_indices.to(torch.long)]
if param.ltype == pp.SO3_type:
return values @ quaternion_plus_jacobian_xyzw(lie_param)
if param.ltype == pp.SE3_type:
return values @ se3_pose_plus_jacobian_xyzw(lie_param)
step_dim = int(param.ltype.manifold[0])
if step_dim != param.shape[-1]:
return values[..., :step_dim]
Expand Down
Loading
Loading