@@ -112,7 +112,7 @@ def call(self, target, condition, **kwargs):
112112
113113 # Handle 3D case for a set-flow and repeat condition over
114114 # the second `time` or `n_observations` axis of `target`
115- if target . ndim == 3 and condition . ndim == 2 :
115+ if tf . rank ( target ) == 3 and tf . rank ( condition ) == 2 :
116116 shape = tf .shape (target )
117117 condition = tf .expand_dims (condition , 1 )
118118 condition = tf .tile (condition , [1 , shape [1 ], 1 ])
@@ -228,7 +228,7 @@ def _forward(self, target):
228228 """Performs a learnable generalized permutation over the last axis."""
229229
230230 shape = tf .shape (target )
231- rank = target . ndim
231+ rank = tf . rank ( target )
232232 log_det = tf .math .log (tf .math .abs (tf .linalg .det (self .W )))
233233 if rank == 2 :
234234 z = tf .linalg .matmul (target , self .W )
@@ -241,7 +241,7 @@ def _inverse(self, z):
241241 """Un-does the learnable permutation over the last axis."""
242242
243243 W_inv = tf .linalg .inv (self .W )
244- rank = z . ndim
244+ rank = tf . rank ( z )
245245 if rank == 2 :
246246 return tf .linalg .matmul (z , W_inv )
247247 return tf .tensordot (z , W_inv , [[rank - 1 ], [0 ]])
@@ -402,11 +402,11 @@ def _initalize_parameters_data_dependent(self, init_data):
402402 """
403403
404404 # 2D Tensor case, assume first batch dimension
405- if init_data . ndim == 2 :
405+ if tf . rank ( init_data ) == 2 :
406406 mean = tf .math .reduce_mean (init_data , axis = 0 )
407407 std = tf .math .reduce_std (init_data , axis = 0 )
408408 # 3D Tensor case, assume first batch dimension, second number of observations dimension
409- elif init_data . ndim == 3 :
409+ elif tf . rank ( init_data ) == 3 :
410410 mean = tf .math .reduce_mean (init_data , axis = (0 , 1 ))
411411 std = tf .math .reduce_std (init_data , axis = (0 , 1 ))
412412 # Raise other cases
@@ -527,7 +527,7 @@ def call(self, x, **kwargs):
527527 # Example: Output dim is (batch_size, inv_dim) - > (batch_size, N, inv_dim)
528528 out_inv = self .invariant_module (x , ** kwargs )
529529 out_inv = tf .expand_dims (out_inv , - 2 )
530- tiler = [1 ] * x . ndim
530+ tiler = [1 ] * tf . rank ( x )
531531 tiler [- 2 ] = shape [- 2 ]
532532 out_inv_rep = tf .tile (out_inv , tiler )
533533
0 commit comments