@@ -80,7 +80,7 @@ def apply_jagged(self, coords: fvdb.JaggedTensor) -> fvdb.JaggedTensor:
8080 """Transform coordinates from source frame to target frame.
8181
8282 Args:
83- coords: JaggedTensor of shape [N , 3] containing 3D coordinates.
83+ coords: JaggedTensor of shape [B, Njagged , 3] containing 3D coordinates.
8484
8585 Returns:
8686 Transformed coordinates as a JaggedTensor with the same structure.
@@ -101,13 +101,13 @@ def apply(self, coords: torch.Tensor | fvdb.JaggedTensor) -> torch.Tensor | fvdb
101101 Args:
102102 coords: 3D coordinates to transform.
103103 - torch.Tensor: shape [N, 3] where N is the number of points.
104- - fvdb.JaggedTensor: shape [B][N, 3] where B is batch size and N
104+ - fvdb.JaggedTensor: shape [B, Njagged, 3] where B is batch size and Njagged
105105 varies per batch element.
106106
107107 Returns:
108108 Transformed coordinates with the same type and shape as input.
109109 - torch.Tensor input -> torch.Tensor output, shape [N, 3]
110- - fvdb.JaggedTensor input -> fvdb.JaggedTensor output, shape [B][N , 3]
110+ - fvdb.JaggedTensor input -> fvdb.JaggedTensor output, shape [B, Njagged , 3]
111111 """
112112 if isinstance (coords , torch .Tensor ):
113113 return self .apply_tensor (coords )
@@ -288,7 +288,7 @@ def compose(self, other: "CoordXform") -> "CoordXform":
288288 Returns:
289289 A new CoordXform representing the composition.
290290 """
291- return ComposedXform (first = other , second = self )
291+ return ComposedXform ([ other , self ] )
292292
293293 @overload
294294 def __matmul__ (self , other : "CoordXform" ) -> "CoordXform" : ...
@@ -346,40 +346,54 @@ def __matmul__(
346346
347347@dataclass (frozen = True )
348348class ComposedXform (CoordXform ):
349- """Composition of two transformations applied in sequence.
349+ """Composition of multiple transformations applied in sequence.
350+ The 0th transformation is applied first, followed by the 1st, etc.
350351
351- Represents the transformation: coords_out = second(first(coords_in))
352352 """
353353
354- first : CoordXform
355- second : CoordXform
354+ xforms : list [CoordXform ]
356355
357356 def apply_tensor (self , coords : torch .Tensor ) -> torch .Tensor :
358- """Apply first, then second transformation."""
359- return self .second .apply_tensor (self .first .apply_tensor (coords ))
357+ """Apply all transformations in to tensor coords insequence."""
358+ for xform in self .xforms :
359+ coords = xform .apply_tensor (coords )
360+ return coords
360361
361362 def apply_jagged (self , coords : fvdb .JaggedTensor ) -> fvdb .JaggedTensor :
362- """Apply first, then second transformation."""
363- return self .second .apply_jagged (self .first .apply_jagged (coords ))
363+ """Apply all transformations to jagged tensor coords in sequence."""
364+ for xform in self .xforms :
365+ coords = xform .apply_jagged (coords )
366+ return coords
364367
365368 def apply_bounds_tensor (self , bounds : torch .Tensor ) -> torch .Tensor :
366- """Apply first, then second bounds transformation."""
367- return self .second .apply_bounds_tensor (self .first .apply_bounds_tensor (bounds ))
369+ """Apply all transformations to bounds tensor in sequence."""
370+ for xform in self .xforms :
371+ bounds = xform .apply_bounds_tensor (bounds )
372+ return bounds
368373
369374 def apply_bounds_jagged (self , bounds : fvdb .JaggedTensor ) -> fvdb .JaggedTensor :
370- """Apply first, then second bounds transformation."""
371- return self .second .apply_bounds_jagged (self .first .apply_bounds_jagged (bounds ))
375+ """Apply all transformations to bounds jagged tensor in sequence."""
376+ for xform in self .xforms :
377+ bounds = xform .apply_bounds_jagged (bounds )
378+ return bounds
372379
373380 def inverse (self ) -> CoordXform :
374- """Return inverse: (second o first )^-1 = first^-1 o second^-1 ."""
381+ """Return inverse: (all xforms )^-1 = inverse(all xforms) in reverse order ."""
375382 if not self .invertible :
376383 raise NotImplementedError ("Cannot invert: one or both transforms not invertible." )
377- return ComposedXform (self .second .inverse (), self .first .inverse ())
384+ return ComposedXform ([xform .inverse () for xform in reversed (self .xforms )])
385+
386+ def compose (self , other : "CoordXform" ) -> "CoordXform" :
387+ """Compose this transformation with another. Other is applied first, then self."""
388+ if isinstance (other , ComposedXform ):
389+ return ComposedXform (other .xforms + self .xforms )
390+ else :
391+ return ComposedXform ([other ] + self .xforms )
378392
379393 @property
380394 def invertible (self ) -> bool :
381395 """True if both first and second are invertible."""
382- return self . first . invertible and self .second . invertible
396+ return all ( xform . invertible for xform in self .xforms )
383397
384398
385399@dataclass (frozen = True )
@@ -476,7 +490,8 @@ def compose(self, other: "CoordXform") -> "CoordXform":
476490 scale = other .scale * self .scale ,
477491 translation = other .translation * self .scale + self .translation ,
478492 )
479- return ComposedXform (first = other , second = self )
493+ else :
494+ return super ().compose (other )
480495
481496 @property
482497 def invertible (self ) -> bool :
0 commit comments