@@ -29,7 +29,7 @@ class Graft:
     def __init__(self, *args):
         pass
 
-    def add_statistics(self, grad: torch.Tensor, unused_beta2: float):
+    def add_statistics(self, grad: torch.Tensor, unused_beta2: float) -> None:
         r"""Add the statistics."""
         pass
 
@@ -47,7 +47,7 @@ class SGDGraft(Graft):
 
     def __init__(self, var: torch.Tensor):
         super().__init__(var)
-        self.momentum: torch.Tensor = torch.zeros_like(var, device=var.device)
+        self.momentum: torch.Tensor = torch.zeros_like(var)
 
     def update_momentum(self, update: torch.Tensor, beta1: float) -> torch.Tensor:
         r"""Update momentum."""
@@ -78,13 +78,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
         self.diagonal_eps = diagonal_eps
         self.statistics: torch.Tensor = torch.zeros_like(var)
 
-    def add_statistics(self, grad: torch.Tensor, _):
+    def add_statistics(self, grad: torch.Tensor, _) -> None:
         r"""Add the statistics."""
         self.statistics.add_(grad.pow(2))
 
     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
         r"""Get preconditioned gradient."""
-        return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
+        return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))
 
 
 class RMSPropGraft(SGDGraft):
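
Side note (my illustration, not part of the patch): the fluent rewrite of precondition_gradient is numerically identical to the old expression, and the in-place add_ only touches the temporary tensor returned by sqrt(), so self.statistics is never modified:

import torch

statistics = torch.rand(4) + 1.0   # stand-in for accumulated squared gradients
grad = torch.randn(4)
eps = 1e-10

old = grad / (torch.sqrt(statistics) + eps)
new = grad.div(statistics.sqrt().add_(eps))  # add_ mutates only the fresh tensor from sqrt()
assert torch.allclose(old, new)
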
@@ -99,13 +99,13 @@ def __init__(self, var: torch.Tensor, diagonal_eps: float):
         self.diagonal_eps = diagonal_eps
         self.statistics: torch.Tensor = torch.zeros_like(var)
 
-    def add_statistics(self, grad: torch.Tensor, beta2: float):
+    def add_statistics(self, grad: torch.Tensor, beta2: float) -> None:
         r"""Add the statistics."""
         self.statistics.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
 
     def precondition_gradient(self, grad: torch.Tensor) -> torch.Tensor:
         r"""Get preconditioned gradient."""
-        return grad / (torch.sqrt(self.statistics) + self.diagonal_eps)
+        return grad.div(self.statistics.sqrt().add_(self.diagonal_eps))
 
 
 class BlockPartitioner:
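
Side note (illustrative, not from the patch): the chained in-place call in add_statistics is the standard RMSProp second-moment EMA, s <- beta2 * s + (1 - beta2) * g * g:

import torch

s = torch.zeros(3)
g = torch.randn(3)
beta2 = 0.99

expected = beta2 * s + (1.0 - beta2) * g * g
s.mul_(beta2).addcmul_(g, g, value=1.0 - beta2)  # addcmul_ adds value * g * g in place
assert torch.allclose(s, expected)
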
@@ -121,51 +121,51 @@ class BlockPartitioner:
     """
 
     def __init__(self, var: torch.Tensor, rank: int, block_size: int, pre_conditioner_type: int):
-        self.shape: List[int] = var.shape
+        self.shape: torch.Size = var.shape
 
-        self.splits: List[Tuple[int, np.ndarray]] = []
-        self.split_sizes: List[Tuple[int, np.ndarray]] = []
+        self.splits: List[Tuple[int, torch.Tensor]] = []
+        self.split_sizes: List[Tuple[int, torch.Tensor]] = []
 
-        split_sizes: List[np.ndarray] = []
+        split_sizes: List[torch.Tensor] = []
 
         # We split var into smaller blocks. Here we store the metadata to make that split.
         for i, d in enumerate(self.shape):
             if block_size <= 0 or block_size >= d:
-                split_sizes.append(np.array([d], dtype=np.int32))
+                split_sizes.append(torch.tensor([d], dtype=torch.int32))
                 continue
 
             # d - 1, otherwise split appends a 0-size array.
             num_split: int = (d - 1) // block_size
-            indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size
+            indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size
 
-            sizes: np.ndarray = np.ones(num_split + 1, dtype=np.int32) * block_size
+            sizes: torch.Tensor = torch.full((num_split + 1,), block_size, dtype=torch.int32)
             sizes[-1] = d - indices[-1]
 
             self.splits.append((i, indices))
             self.split_sizes.append((i, sizes))
             split_sizes.append(sizes)
 
         self.num_splits: int = len(split_sizes)
-        self.pre_conditioner_shapes: List[List[int]] = self.build_pre_conditioner_shapes(
+        self.pre_conditioner_shapes: List[List[torch.Tensor]] = self.build_pre_conditioner_shapes(
             split_sizes, pre_conditioner_type, rank
         )
 
     @staticmethod
     def build_pre_conditioner_shapes(
-        split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
-    ) -> List[List[int]]:
+        split_sizes: List[torch.Tensor], pre_conditioner_type: int, rank: int
+    ) -> List[List[torch.Tensor]]:
         r"""Build pre-conditioner shapes."""
-        pre_conditioner_shapes: List[List[int]] = []
+        pre_conditioner_shapes: List[List[torch.Tensor]] = []
         for t in itertools.product(*split_sizes):
-            t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
+            t_shape: List[Optional[List[torch.Tensor]]] = [[d, d] for d in t]
             if pre_conditioner_type == PreConditionerType.INPUT:
-                t_shape = t_shape[:-1] + [None]
-            if pre_conditioner_type == PreConditionerType.OUTPUT:
+                t_shape[-1] = None
+            elif pre_conditioner_type == PreConditionerType.OUTPUT:
                 t_shape = [None] * (rank - 1) + t_shape[-1:]
             pre_conditioner_shapes.extend(t_shape)
         return pre_conditioner_shapes
 
-    def shapes_for_pre_conditioners(self) -> List[List[int]]:
+    def shapes_for_pre_conditioners(self) -> List[List[torch.Tensor]]:
         r"""Get shapes of pre-conditioner."""
         return self.pre_conditioner_shapes
 
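
For a concrete picture of the numpy-to-torch swap above, here is what the split metadata works out to for one dimension (my sketch, not part of the commit); the values assume d=7 and block_size=3:

import torch

d, block_size = 7, 3
num_split = (d - 1) // block_size                                    # 2 split points
indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size
sizes = torch.full((num_split + 1,), block_size, dtype=torch.int32)
sizes[-1] = d - indices[-1]

print(indices)  # tensor([3, 6], dtype=torch.int32) -> where the dimension is cut
print(sizes)    # tensor([3, 3, 1], dtype=torch.int32) -> resulting block lengths
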
@@ -244,7 +244,7 @@ def __init__(
 
         self.w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)
 
-        self.original_shape: List[int] = var.shape
+        self.original_shape: torch.Size = var.shape
         self.transformed_shape: List[int] = (
             merge_small_dims(self.original_shape, block_size) if shape_interpretation else var.shape
         )
@@ -267,7 +267,7 @@ def __init__(
                 pre_conditioner_type=self.pre_conditioner_type,
             )
 
-            shapes: List[Optional[List[int]]] = self.partitioner.shapes_for_pre_conditioners()
+            shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
             self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
             self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
             self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
@@ -291,7 +291,7 @@ def skip_precondition(self, x: torch.Tensor) -> bool:
             dim > self.no_preconditioning_for_layers_with_dim_gt for dim in x.shape
         )
 
-    def add_statistics(self, grad: torch.Tensor):
+    def add_statistics(self, grad: torch.Tensor) -> None:
         r"""Compute statistics from gradients and add to the correct state entries.
 
         :param grad: torch.Tensor. gradient to compute statistics from.
@@ -302,14 +302,13 @@ def add_statistics(self, grad: torch.Tensor):
         reshaped_grad: torch.Tensor = torch.reshape(grad, self.transformed_shape)
         partitioned_grads: List[torch.Tensor] = self.partitioner.partition(reshaped_grad)
 
-        for j in range(len(partitioned_grads)):
-            partitioned_grad: torch.Tensor = partitioned_grads[j]
+        for j, partitioned_grad in enumerate(partitioned_grads):
             for i in range(self.rank):
                 axes: List[int] = [ax for ax in range(partitioned_grad.ndim) if ax != i]
                 stat: torch.Tensor = torch.tensordot(partitioned_grad, partitioned_grad, dims=[axes, axes])
                 self.statistics[j * self.rank + i].mul_(self.beta2).add_(stat, alpha=self.w2)
 
-    def compute_pre_conditioners(self):
+    def compute_pre_conditioners(self) -> None:
         r"""Compute L^{-1/exp} for each stats matrix L.
 
         If `self.use_svd` is enabled and where all shapes of statistics & pre-conditioners are same, perform batch SVD.
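
Side note (illustrative sketch, not part of the patch): for a 2-D block, the tensordot contraction in the loop above reduces to the familiar Shampoo factors G @ G.T and G.T @ G:

import torch

g = torch.randn(3, 5)
stat_0 = torch.tensordot(g, g, dims=[[1], [1]])  # contract every axis except 0
stat_1 = torch.tensordot(g, g, dims=[[0], [0]])  # contract every axis except 1

assert torch.allclose(stat_0, g @ g.t())
assert torch.allclose(stat_1, g.t() @ g)
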
@@ -333,15 +332,15 @@ def compute_pre_conditioners(self):
     def precondition_block(
         partitioned_grad: torch.Tensor,
         should_preconditioned_dims: List[bool],
-        pre_conditioners_for_grad: List[torch.Tensor],
+        pre_conditioners_for_grad: Union[List[torch.Tensor], torch.Tensor],
     ) -> torch.Tensor:
         r"""Perform a preconditioning operation on a single gradient block.
 
         Loop invariant: the dimension to be preconditioned is first
         We keep all axes in the same cyclic order they were originally.
         """
         rank: int = len(partitioned_grad.shape)
-        roll: Tuple[int, ...] = (*tuple(range(1, rank)), 0)
+        roll: Tuple[int, ...] = (*range(1, rank), 0)
 
         i: int = 0
         for should_precondition_dim in should_preconditioned_dims:
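
Minor note (illustrative, not from the patch): dropping the tuple() call is purely cosmetic; both spellings build the same cyclic permutation, which cycles the leading axis to the back so the next dimension comes first:

import torch

rank = 3
roll = (*range(1, rank), 0)
assert roll == (*tuple(range(1, rank)), 0) == (1, 2, 0)

x = torch.randn(2, 3, 4)
assert x.permute(roll).shape == torch.Size([3, 4, 2])
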
@@ -376,7 +375,7 @@ def preconditioned_grad(self, grad: torch.Tensor) -> torch.Tensor:
 
         merged_grad = self.partitioner.merge_partitions(pre_cond_partitioned_grads)
 
-        return torch.reshape(merged_grad, self.original_shape)
+        return merged_grad.reshape(self.original_shape)
 
 
 def build_graft(p: torch.Tensor, graft_type: int, diagonal_eps: float = 1e-10):
@@ -407,7 +406,8 @@ def power_iteration(mat_g: torch.Tensor, num_iters: int = 100) -> torch.Tensor:
 
     for _ in range(num_iters):
         torch.mv(mat_g, v, out=mat_v)
-        v = mat_v.div(torch.linalg.norm(mat_v))
+        v.copy_(mat_v)
+        v.div_(torch.linalg.norm(v))
 
     return (v.t() @ mat_g @ v).clamp_min_(1e-16)
 
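
A standalone sketch of the patched loop (my assumptions: mat_g is a symmetric positive semi-definite matrix; buffer names mirror the diff). The in-place copy_/div_ reuses the two pre-allocated buffers instead of binding v to a freshly allocated tensor on every iteration:

import torch

mat_g = torch.randn(5, 5)
mat_g = mat_g @ mat_g.t()          # symmetric PSD test matrix

v = torch.rand(5)
mat_v = torch.empty_like(v)

for _ in range(100):
    torch.mv(mat_g, v, out=mat_v)  # mat_v = mat_g @ v, written into the buffer
    v.copy_(mat_v)
    v.div_(torch.linalg.norm(v))

# Power iteration converges to the dominant eigenvalue; compare against eigvalsh.
print((v @ mat_g @ v).item(), torch.linalg.eigvalsh(mat_g).max().item())
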
@@ -490,7 +490,7 @@ def compute_power_schur_newton(
 
 @torch.no_grad()
 def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
-    r"""Compute G^{-1/p} using a SVD.
+    r"""Compute G^{-1/p} using SVD.
 
     Calculate SVD on the GPU. Sometimes, SVD on the CPU is faster than GPU, but based on the several experiments,
     CUDA seems much faster than on CPU.
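
For reference, a self-contained sketch of the G^{-1/p}-via-SVD idea the docstring describes (not the library's exact body; the helper name is mine), shown for a symmetric positive-definite 2-D matrix:

import torch

def inverse_pth_root_via_svd(g: torch.Tensor, power: float) -> torch.Tensor:
    u, s, vh = torch.linalg.svd(g)
    return u @ torch.diag(s.pow(-1.0 / power)) @ vh

g = torch.randn(4, 4)
g = g @ g.t() + 1e-3 * torch.eye(4)            # symmetric positive definite
root = inverse_pth_root_via_svd(g, power=4.0)

# (G^{-1/4})^4 @ G should be close to the identity.
print((torch.linalg.matrix_power(root, 4) @ g - torch.eye(4)).abs().max())
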
@@ -503,14 +503,14 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
     return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh
 
 
-def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
+def merge_small_dims(shape_to_merge: Union[List[int], torch.Size], max_dim: int) -> List[int]:
     r"""Merge small dimensions.
 
     If there are some small dimensions, we collapse them
     e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
         [1, 2, 768, 1, 2048] --> [2, 768, 2048].
 
-    :param shape_to_merge: List[int]. Shape to merge small dimensions.
+    :param shape_to_merge: Union[List[int], torch.Size]. Shape to merge small dimensions.
     :param max_dim: int. Maximal dimension of output shape used in merging.
     """
     merged_shape: List[int] = []