Commit 92c2e10

Made composed xforms be a sequence. Changed voxel status eval to use inject
Signed-off-by: Christopher Horvath <[email protected]>
1 parent c802771 commit 92c2e10

4 files changed, +325 -40 lines

surface_reconstruction/nksr/nksr_fvdb/coord_xform.py

Lines changed: 35 additions & 20 deletions
@@ -80,7 +80,7 @@ def apply_jagged(self, coords: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
         """Transform coordinates from source frame to target frame.
 
         Args:
-            coords: JaggedTensor of shape [N, 3] containing 3D coordinates.
+            coords: JaggedTensor of shape [B, Njagged, 3] containing 3D coordinates.
 
         Returns:
             Transformed coordinates as a JaggedTensor with the same structure.
@@ -101,13 +101,13 @@ def apply(self, coords: torch.Tensor | fvdb.JaggedTensor) -> torch.Tensor | fvdb
         Args:
             coords: 3D coordinates to transform.
                 - torch.Tensor: shape [N, 3] where N is the number of points.
-                - fvdb.JaggedTensor: shape [B][N, 3] where B is batch size and N
+                - fvdb.JaggedTensor: shape [B, Njagged, 3] where B is batch size and Njagged
                   varies per batch element.
 
         Returns:
             Transformed coordinates with the same type and shape as input.
                 - torch.Tensor input -> torch.Tensor output, shape [N, 3]
-                - fvdb.JaggedTensor input -> fvdb.JaggedTensor output, shape [B][N, 3]
+                - fvdb.JaggedTensor input -> fvdb.JaggedTensor output, shape [B, Njagged, 3]
         """
         if isinstance(coords, torch.Tensor):
             return self.apply_tensor(coords)
@@ -288,7 +288,7 @@ def compose(self, other: "CoordXform") -> "CoordXform":
         Returns:
             A new CoordXform representing the composition.
         """
-        return ComposedXform(first=other, second=self)
+        return ComposedXform([other, self])
 
     @overload
     def __matmul__(self, other: "CoordXform") -> "CoordXform": ...
@@ -346,40 +346,54 @@ def __matmul__(
 
 @dataclass(frozen=True)
 class ComposedXform(CoordXform):
-    """Composition of two transformations applied in sequence.
+    """Composition of multiple transformations applied in sequence.
+    The 0th transformation is applied first, followed by the 1st, etc.
 
-    Represents the transformation: coords_out = second(first(coords_in))
     """
 
-    first: CoordXform
-    second: CoordXform
+    xforms: list[CoordXform]
 
     def apply_tensor(self, coords: torch.Tensor) -> torch.Tensor:
-        """Apply first, then second transformation."""
-        return self.second.apply_tensor(self.first.apply_tensor(coords))
+        """Apply all transformations to tensor coords in sequence."""
+        for xform in self.xforms:
+            coords = xform.apply_tensor(coords)
+        return coords
 
     def apply_jagged(self, coords: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
-        """Apply first, then second transformation."""
-        return self.second.apply_jagged(self.first.apply_jagged(coords))
+        """Apply all transformations to jagged tensor coords in sequence."""
+        for xform in self.xforms:
+            coords = xform.apply_jagged(coords)
+        return coords
 
     def apply_bounds_tensor(self, bounds: torch.Tensor) -> torch.Tensor:
-        """Apply first, then second bounds transformation."""
-        return self.second.apply_bounds_tensor(self.first.apply_bounds_tensor(bounds))
+        """Apply all transformations to bounds tensor in sequence."""
+        for xform in self.xforms:
+            bounds = xform.apply_bounds_tensor(bounds)
+        return bounds
 
     def apply_bounds_jagged(self, bounds: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
-        """Apply first, then second bounds transformation."""
-        return self.second.apply_bounds_jagged(self.first.apply_bounds_jagged(bounds))
+        """Apply all transformations to bounds jagged tensor in sequence."""
+        for xform in self.xforms:
+            bounds = xform.apply_bounds_jagged(bounds)
+        return bounds
 
     def inverse(self) -> CoordXform:
-        """Return inverse: (second o first)^-1 = first^-1 o second^-1."""
+        """Return inverse: (all xforms)^-1 = inverse(all xforms) in reverse order."""
         if not self.invertible:
             raise NotImplementedError("Cannot invert: one or both transforms not invertible.")
-        return ComposedXform(self.second.inverse(), self.first.inverse())
+        return ComposedXform([xform.inverse() for xform in reversed(self.xforms)])
+
+    def compose(self, other: "CoordXform") -> "CoordXform":
+        """Compose this transformation with another. Other is applied first, then self."""
+        if isinstance(other, ComposedXform):
+            return ComposedXform(other.xforms + self.xforms)
+        else:
+            return ComposedXform([other] + self.xforms)
 
     @property
     def invertible(self) -> bool:
         """True if both first and second are invertible."""
-        return self.first.invertible and self.second.invertible
+        return all(xform.invertible for xform in self.xforms)
 
 
 @dataclass(frozen=True)
@@ -476,7 +490,8 @@ def compose(self, other: "CoordXform") -> "CoordXform":
                 scale=other.scale * self.scale,
                 translation=other.translation * self.scale + self.translation,
             )
-        return ComposedXform(first=other, second=self)
+        else:
+            return super().compose(other)
 
     @property
     def invertible(self) -> bool:
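
For orientation, here is a minimal sketch of how the new sequence-based composition behaves. The _Shift class, the import path, and the assumption that CoordXform's abstract surface is just the methods visible in this diff are illustrative guesses, not part of the commit.

    from dataclasses import dataclass

    import fvdb
    import torch

    from nksr_fvdb.coord_xform import CoordXform  # import path assumed


    @dataclass(frozen=True)
    class _Shift(CoordXform):
        """Hypothetical toy xform: adds a constant offset to every coordinate."""

        offset: float

        def apply_tensor(self, coords: torch.Tensor) -> torch.Tensor:
            return coords + self.offset

        def apply_jagged(self, coords: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
            raise NotImplementedError("jagged path not exercised in this sketch")

        def apply_bounds_tensor(self, bounds: torch.Tensor) -> torch.Tensor:
            return bounds + self.offset

        def apply_bounds_jagged(self, bounds: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
            raise NotImplementedError("jagged path not exercised in this sketch")

        def inverse(self) -> CoordXform:
            return _Shift(-self.offset)

        @property
        def invertible(self) -> bool:
            return True


    a, b, c = _Shift(1.0), _Shift(10.0), _Shift(100.0)

    bc = c.compose(b)    # base CoordXform.compose -> ComposedXform([b, c]): b first, then c
    abc = bc.compose(a)  # ComposedXform.compose flattens -> xforms == [a, b, c]

    pts = torch.zeros(4, 3)
    out = abc.apply_tensor(pts)   # c(b(a(pts))): every coordinate becomes 111.0
    inv = abc.inverse()           # inverses in reverse order: shifts of -100, -10, -1
    back = inv.apply_tensor(out)  # recovers pts (all zeros)

Note that flattening only happens when the receiver of compose is already a ComposedXform; composing onto a plain xform via the base-class compose still wraps the pair, which is behaviorally equivalent because the nested composition is applied in full.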

surface_reconstruction/nksr/nksr_fvdb/sparse_feature_hierarchy.py

Lines changed: 22 additions & 16 deletions
@@ -25,7 +25,7 @@ class VoxelStatus(Enum):
     VS_EXIST_CONTINUE = 2
 
 
-def evaluate_voxel_status(target_grid: GridBatch, coarse_grid: GridBatch, fine_grid: GridBatch | None) -> torch.Tensor:
+def evaluate_voxel_status(target_grid: GridBatch, coarse_grid: GridBatch, fine_grid: GridBatch | None) -> JaggedTensor:
     """Compute per-voxel status labels for structure prediction training.
 
     Classifies each voxel in target_grid by comparing against coarse/fine hierarchy levels:
@@ -49,26 +49,32 @@ def evaluate_voxel_status(target_grid: GridBatch, coarse_grid: GridBatch, fine_g
     if coarse_grid.device != device:
         raise ValueError(f"Device not match {device} vs {coarse_grid.device}.")
 
-    if fine_grid is not None and fine_grid.device != device:
-        raise ValueError(f"Device not match {device} vs {fine_grid.device}.")
+    # create a coarse feature jagged tensor, by broadcasting. The coarse grid is filled with
+    # VoxelStatus.VS_EXIST_STOP.
+    coarse_feature = coarse_grid.jagged_like(
+        torch.tensor(VoxelStatus.VS_EXIST_STOP.value, dtype=torch.uint8, device=device).expand(coarse_grid.total_voxels)
+    )
 
-    # Start with a flat tensor of all voxels in the target grid, initialized to NON_EXIST.
-    status = torch.full((target_grid.total_voxels,), VoxelStatus.VS_NON_EXIST.value, dtype=torch.uint8, device=device)
-
-    # Set the intersection of the coarse grid and the target grid to EXIST_STOP.
-    coarse_exist_idx = target_grid.ijk_to_index(coarse_grid.ijk, cumulative=True)
-    coarse_exist_mask = coarse_exist_idx.jdata != -1
-    status[coarse_exist_mask] = VoxelStatus.VS_EXIST_STOP.value
+    # Inject the coarse feature into the target
+    target_feature = coarse_grid.inject_to(
+        dst_grid=target_grid, src=coarse_feature, dst=None, default_value=VoxelStatus.VS_NON_EXIST.value
+    )
 
     # If there is a fine grid, set the intersection of the fine grid and the target grid to EXIST_CONTINUE.
-    # This can overwrite the coarse grid's EXIST_STOP status.
+    # This can overwrite the coarse grid's EXIST_STOP status. It modifies the target_feature in place,
+    # and doesn't overwrite except where the fine coarsened grid exists.
     if fine_grid is not None:
+        if fine_grid.device != device:
+            raise ValueError(f"Device not match {device} vs {fine_grid.device}.")
        fine_coarsened = fine_grid.coarsened_grid(coarsening_factor=2)
-        fine_exist_idx = target_grid.ijk_to_index(fine_coarsened.ijk, cumulative=True)
-        fine_exist_mask = fine_exist_idx.jdata != -1
-        status[fine_exist_mask] = VoxelStatus.VS_EXIST_CONTINUE.value
+        fine_coarsened_feature = fine_coarsened.jagged_like(
+            torch.tensor(VoxelStatus.VS_EXIST_CONTINUE.value, dtype=torch.uint8, device=device).expand(
+                fine_coarsened.total_voxels
+            )
+        )
+        fine_coarsened.inject_to(dst_grid=target_grid, src=fine_coarsened_feature, dst=target_feature)
 
-    return status
+    return target_feature
 
 
 @dataclass(frozen=True)
@@ -270,7 +276,7 @@ def from_point_splatting(
 
         return cls(levels=levels)
 
-    def evaluate_voxel_status(self, target_grid: GridBatch, coarse_depth: int | None = None) -> torch.Tensor:
+    def evaluate_voxel_status(self, target_grid: GridBatch, coarse_depth: int | None = None) -> JaggedTensor:
         """Evaluate voxel status at a given hierarchy depth.
 
         Convenience wrapper around the module-level evaluate_voxel_status function.
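
A brief, hedged usage sketch of the new JaggedTensor return type: the three GridBatch inputs are assumed to be built elsewhere (e.g. the levels of a SparseFeatureHierarchy), and the import path is a guess.

    from nksr_fvdb.sparse_feature_hierarchy import VoxelStatus, evaluate_voxel_status  # path assumed

    # target_grid, coarse_grid, fine_grid: fvdb.GridBatch objects from an existing hierarchy.
    status = evaluate_voxel_status(target_grid, coarse_grid, fine_grid)

    # status is a JaggedTensor over target_grid's voxels; status.jdata is the flat uint8
    # view holding one VoxelStatus value per voxel (per the construction above).
    n_stop = int((status.jdata == VoxelStatus.VS_EXIST_STOP.value).sum())
    n_cont = int((status.jdata == VoxelStatus.VS_EXIST_CONTINUE.value).sum())
    n_none = int((status.jdata == VoxelStatus.VS_NON_EXIST.value).sum())
    assert n_stop + n_cont + n_none == target_grid.total_voxels

The practical difference from the previous masked-index version is that the result stays batch-aware: a JaggedTensor keeps per-batch boundaries, whereas the old flat status tensor did not.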
