@@ -270,7 +270,7 @@ def __init__(
270270 shapes : List [Optional [List [torch .Tensor ]]] = self .partitioner .shapes_for_pre_conditioners ()
271271 self .statistics = [self .matrix_eps * torch .eye (shape [0 ], device = var .device ) for shape in shapes if shape ]
272272 self .pre_conditioners = [torch .eye (shape [0 ], device = var .device ) for shape in shapes if shape ]
273- self .is_same_shapes = None not in shapes and len (np .unique (shapes )) == 1
273+ self .is_same_shapes = None not in shapes and len (torch .unique (shapes )) == 1
274274
275275 if self .is_same_shapes :
276276 self .statistics = torch .stack (self .statistics , dim = 0 )
@@ -302,8 +302,7 @@ def add_statistics(self, grad: torch.Tensor) -> None:
302302 reshaped_grad : torch .Tensor = torch .reshape (grad , self .transformed_shape )
303303 partitioned_grads : List [torch .Tensor ] = self .partitioner .partition (reshaped_grad )
304304
305- for j in range (len (partitioned_grads )):
306- partitioned_grad : torch .Tensor = partitioned_grads [j ]
305+ for j , partitioned_grad in enumerate (partitioned_grads ):
307306 for i in range (self .rank ):
308307 axes : List [int ] = [ax for ax in range (partitioned_grad .ndim ) if ax != i ]
309308 stat : torch .Tensor = torch .tensordot (partitioned_grad , partitioned_grad , dims = [axes , axes ])
@@ -341,7 +340,7 @@ def precondition_block(
341340 We keep all axes in the same cyclic order they were originally.
342341 """
343342 rank : int = len (partitioned_grad .shape )
344- roll : Tuple [int , ...] = (* tuple ( range (1 , rank ) ), 0 )
343+ roll : Tuple [int , ...] = (* range (1 , rank ), 0 )
345344
346345 i : int = 0
347346 for should_precondition_dim in should_preconditioned_dims :
@@ -376,7 +375,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
376375
377376 merged_grad = self .partitioner .merge_partitions (pre_cond_partitioned_grads )
378377
379- return torch .reshape (merged_grad , self .original_shape )
378+ return merged_grad .reshape (self .original_shape )
380379
381380
382381def build_graft (p : torch .Tensor , graft_type : int , diagonal_eps : float = 1e-10 ):
0 commit comments