Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 38 additions & 17 deletions src/sharp/utils/gaussians.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,37 +133,58 @@ def apply_transform(gaussians: Gaussians3D, transform: torch.Tensor) -> Gaussian

def decompose_covariance_matrices(
covariance_matrices: torch.Tensor,
use_gpu: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Decompose 3D covariance matrices into quaternions and singular values.

Args:
covariance_matrices: The covariance matrices to decompose.
use_gpu: If True, perform SVD on GPU and use GPU quaternion conversion for
~10x faster decomposition. If False, use CPU with float64 for maximum
numerical precision (original behavior).

Returns:
Quaternion and singular values corresponding to the orientation and scales of
the diagonalized matrix.

Note: This operation is not differentiable.
Note: This operation is not differentiable. GPU mode provides significant speedup
for large batches but may have slightly reduced numerical precision.
"""
device = covariance_matrices.device
dtype = covariance_matrices.dtype

# We convert to fp64 to avoid numerical errors.
covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)

# NOTE: in SVD, it is possible that U and VT are both reflections.
# We need to correct them.
batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0)
num_reflections = len(gaussian_idx)
if num_reflections > 0:
LOGGER.warning(
"Received %d reflection matrices from SVD. Flipping them to rotations.",
num_reflections,
)
# Flip the last column of reflection and make it a rotation.
rotations[batch_idx, gaussian_idx, :, -1] *= -1
quaternions = linalg.quaternions_from_rotation_matrices(rotations)
if use_gpu and covariance_matrices.is_cuda:
# GPU path: faster but slightly less precise
try:
rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)
except RuntimeError:
# Fallback to CPU float64 for numerical stability if GPU SVD fails
LOGGER.warning("GPU SVD failed, falling back to CPU float64.")
covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)

# Vectorized reflection correction (faster than indexed assignment)
det_sign = torch.linalg.det(rotations).sign()[..., None, None]
rotations = rotations.clone()
rotations[..., :, 2:3] = rotations[..., :, 2:3] * det_sign
else:
# CPU path: original behavior with float64 for numerical stability
covariance_matrices = covariance_matrices.detach().cpu().to(torch.float64)
rotations, singular_values_2, _ = torch.linalg.svd(covariance_matrices)

# NOTE: in SVD, it is possible that U and VT are both reflections.
# We need to correct them.
batch_idx, gaussian_idx = torch.where(torch.linalg.det(rotations) < 0)
num_reflections = len(gaussian_idx)
if num_reflections > 0:
LOGGER.warning(
"Received %d reflection matrices from SVD. Flipping them to rotations.",
num_reflections,
)
# Flip the last column of reflection and make it a rotation.
rotations[batch_idx, gaussian_idx, :, -1] *= -1

quaternions = linalg.quaternions_from_rotation_matrices(rotations, use_gpu=use_gpu)
quaternions = quaternions.to(dtype=dtype, device=device)
singular_values = singular_values_2.sqrt().to(dtype=dtype, device=device)
return quaternions, singular_values
Expand Down
82 changes: 79 additions & 3 deletions src/sharp/utils/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,34 @@ def rotation_matrices_from_quaternions(quaternions: torch.Tensor) -> torch.Tenso
return matrix_outer + matrix_diag + matrix_cross_1 + matrix_cross_2


def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor:
def quaternions_from_rotation_matrices(
matrices: torch.Tensor,
use_gpu: bool = True,
) -> torch.Tensor:
"""Convert batch of rotation matrices to quaternions.

Args:
matrices: The matrices to convert to quaternions.
use_gpu: If True and matrices are on CUDA, use pure PyTorch GPU implementation
for ~300x faster conversion. If False, use scipy on CPU (original behavior).

Returns:
The quaternions corresponding to the rotation matrices.
The quaternions corresponding to the rotation matrices (w, x, y, z convention).

Note: this operation is not differentiable and will be performed on the CPU.
Note: The GPU implementation is not differentiable but provides significant speedup
for large batches (e.g., 2M+ gaussians). Set use_gpu=False for maximum numerical
precision or when working with small batches on CPU.
"""
if not matrices.shape[-2:] == (3, 3):
raise ValueError(f"matrices have invalid shape {matrices.shape}")

if use_gpu and matrices.is_cuda:
return _quaternions_from_rotation_matrices_gpu(matrices)
return _quaternions_from_rotation_matrices_cpu(matrices)


def _quaternions_from_rotation_matrices_cpu(matrices: torch.Tensor) -> torch.Tensor:
"""CPU implementation using scipy (original behavior)."""
matrices_np = matrices.detach().cpu().numpy()
quaternions_np = Rotation.from_matrix(matrices_np.reshape(-1, 3, 3)).as_quat()
# We use a convention where the w component is at the start of the quaternion.
Expand All @@ -59,6 +74,67 @@ def quaternions_from_rotation_matrices(matrices: torch.Tensor) -> torch.Tensor:
return torch.as_tensor(quaternions_np, device=matrices.device, dtype=matrices.dtype)


def _quaternions_from_rotation_matrices_gpu(matrices: torch.Tensor) -> torch.Tensor:
"""Pure PyTorch GPU implementation using Shepperd's method.

Reference:
https://www.euclideanspace.com/maths/geometry/rotations/conversions/matrixToQuaternion/
"""
# Flatten batch dimensions
original_shape = matrices.shape[:-2]
matrices = matrices.reshape(-1, 3, 3)
batch_size = matrices.shape[0]

# Allocate output
quaternions = torch.zeros(batch_size, 4, device=matrices.device, dtype=matrices.dtype)

# Extract matrix elements
m00, m01, m02 = matrices[:, 0, 0], matrices[:, 0, 1], matrices[:, 0, 2]
m10, m11, m12 = matrices[:, 1, 0], matrices[:, 1, 1], matrices[:, 1, 2]
m20, m21, m22 = matrices[:, 2, 0], matrices[:, 2, 1], matrices[:, 2, 2]

# Compute trace
trace = m00 + m11 + m22

# Case 1: trace > 0
mask1 = trace > 0
s1 = torch.sqrt(trace[mask1] + 1.0) * 2 # s = 4 * w
quaternions[mask1, 0] = 0.25 * s1 # w
quaternions[mask1, 1] = (m21[mask1] - m12[mask1]) / s1 # x
quaternions[mask1, 2] = (m02[mask1] - m20[mask1]) / s1 # y
quaternions[mask1, 3] = (m10[mask1] - m01[mask1]) / s1 # z

# Case 2: m00 > m11 and m00 > m22
mask2 = (~mask1) & (m00 > m11) & (m00 > m22)
s2 = torch.sqrt(1.0 + m00[mask2] - m11[mask2] - m22[mask2]) * 2 # s = 4 * x
quaternions[mask2, 0] = (m21[mask2] - m12[mask2]) / s2 # w
quaternions[mask2, 1] = 0.25 * s2 # x
quaternions[mask2, 2] = (m01[mask2] + m10[mask2]) / s2 # y
quaternions[mask2, 3] = (m02[mask2] + m20[mask2]) / s2 # z

# Case 3: m11 > m22
mask3 = (~mask1) & (~mask2) & (m11 > m22)
s3 = torch.sqrt(1.0 + m11[mask3] - m00[mask3] - m22[mask3]) * 2 # s = 4 * y
quaternions[mask3, 0] = (m02[mask3] - m20[mask3]) / s3 # w
quaternions[mask3, 1] = (m01[mask3] + m10[mask3]) / s3 # x
quaternions[mask3, 2] = 0.25 * s3 # y
quaternions[mask3, 3] = (m12[mask3] + m21[mask3]) / s3 # z

# Case 4: else (m22 is largest)
mask4 = (~mask1) & (~mask2) & (~mask3)
s4 = torch.sqrt(1.0 + m22[mask4] - m00[mask4] - m11[mask4]) * 2 # s = 4 * z
quaternions[mask4, 0] = (m10[mask4] - m01[mask4]) / s4 # w
quaternions[mask4, 1] = (m02[mask4] + m20[mask4]) / s4 # x
quaternions[mask4, 2] = (m12[mask4] + m21[mask4]) / s4 # y
quaternions[mask4, 3] = 0.25 * s4 # z

# Normalize to ensure unit quaternions
quaternions = F.normalize(quaternions, dim=-1)

# Reshape back to original batch shape
return quaternions.reshape(original_shape + (4,))


def get_cross_product_matrix(vectors: torch.Tensor) -> torch.Tensor:
"""Generate cross product matrix for vector exterior product."""
if not vectors.shape[-1] == 3:
Expand Down