@@ -121,51 +121,51 @@ class BlockPartitioner:
121121 """
122122
123123 def __init__ (self , var : torch .Tensor , rank : int , block_size : int , pre_conditioner_type : int ):
124- self .shape : List [ int ] = var .shape
124+ self .shape : torch . Size = var .shape
125125
126- self .splits : List [Tuple [int , np . ndarray ]] = []
127- self .split_sizes : List [Tuple [int , np . ndarray ]] = []
126+ self .splits : List [Tuple [int , torch . Tensor ]] = []
127+ self .split_sizes : List [Tuple [int , torch . Tensor ]] = []
128128
129- split_sizes : List [np . ndarray ] = []
129+ split_sizes : List [torch . Tensor ] = []
130130
131131 # We split var into smaller blocks. Here we store the metadata to make that split.
132132 for i , d in enumerate (self .shape ):
133133 if block_size <= 0 or block_size >= d :
134- split_sizes .append (np . array ([d ], dtype = np .int32 ))
134+ split_sizes .append (torch . tensor ([d ], dtype = torch .int32 ))
135135 continue
136136
137137 # Use d - 1 so that a dimension that is an exact multiple of block_size does not get a trailing 0-size split.
138138 num_split : int = (d - 1 ) // block_size
139- indices = (np .arange (num_split , dtype = np .int32 ) + 1 ) * block_size
139+ indices = (torch .arange (num_split , dtype = torch .int32 ) + 1 ) * block_size
140140
141- sizes : np . ndarray = np . ones ( num_split + 1 , dtype = np .int32 ) * block_size
141+ sizes : torch . Tensor = torch . full (( num_split + 1 ,), block_size , dtype = torch .int32 )
142142 sizes [- 1 ] = d - indices [- 1 ]
143143
144144 self .splits .append ((i , indices ))
145145 self .split_sizes .append ((i , sizes ))
146146 split_sizes .append (sizes )
147147
148148 self .num_splits : int = len (split_sizes )
149- self .pre_conditioner_shapes : List [List [int ]] = self .build_pre_conditioner_shapes (
149+ self .pre_conditioner_shapes : List [List [torch . Tensor ]] = self .build_pre_conditioner_shapes (
150150 split_sizes , pre_conditioner_type , rank
151151 )
152152
153153 @staticmethod
154154 def build_pre_conditioner_shapes (
155- split_sizes : List [np . ndarray ], pre_conditioner_type : int , rank : int
156- ) -> List [List [int ]]:
155+ split_sizes : List [torch . Tensor ], pre_conditioner_type : int , rank : int
156+ ) -> List [List [torch . Tensor ]]:
157157 r"""Build pre-conditioner shapes."""
158- pre_conditioner_shapes : List [List [int ]] = []
158+ pre_conditioner_shapes : List [List [torch . Tensor ]] = []
159159 for t in itertools .product (* split_sizes ):
160- t_shape : List [Optional [List [int ]]] = [[d , d ] for d in t ]
160+ t_shape : List [Optional [List [torch . Tensor ]]] = [[d , d ] for d in t ]
161161 if pre_conditioner_type == PreConditionerType .INPUT :
162- t_shape = t_shape [: - 1 ] + [ None ]
163- if pre_conditioner_type == PreConditionerType .OUTPUT :
162+ t_shape [ - 1 ] = None
163+ elif pre_conditioner_type == PreConditionerType .OUTPUT :
164164 t_shape = [None ] * (rank - 1 ) + t_shape [- 1 :]
165165 pre_conditioner_shapes .extend (t_shape )
166166 return pre_conditioner_shapes
167167
168- def shapes_for_pre_conditioners (self ) -> List [List [int ]]:
168+ def shapes_for_pre_conditioners (self ) -> List [List [torch . Tensor ]]:
169169 r"""Get shapes of pre-conditioner."""
170170 return self .pre_conditioner_shapes
171171
@@ -244,7 +244,7 @@ def __init__(
244244
245245 self .w2 : float = 1.0 if self .beta2 == 1.0 else (1.0 - self .beta2 )
246246
247- self .original_shape : List [ int ] = var .shape
247+ self .original_shape : torch . Size = var .shape
248248 self .transformed_shape : List [int ] = (
249249 merge_small_dims (self .original_shape , block_size ) if shape_interpretation else var .shape
250250 )
@@ -267,7 +267,7 @@ def __init__(
267267 pre_conditioner_type = self .pre_conditioner_type ,
268268 )
269269
270- shapes : List [Optional [List [int ]]] = self .partitioner .shapes_for_pre_conditioners ()
270+ shapes : List [Optional [List [torch . Tensor ]]] = self .partitioner .shapes_for_pre_conditioners ()
271271 self .statistics = [self .matrix_eps * torch .eye (shape [0 ], device = var .device ) for shape in shapes if shape ]
272272 self .pre_conditioners = [torch .eye (shape [0 ], device = var .device ) for shape in shapes if shape ]
273273 self .is_same_shapes = None not in shapes and len (np .unique (shapes )) == 1
@@ -504,14 +504,14 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
504504 return u @ (s .diag () if len (matrix .shape ) == 2 else s .diag_embed ()) @ vh
505505
506506
507- def merge_small_dims (shape_to_merge : List [int ], max_dim : int ) -> List [int ]:
507+ def merge_small_dims (shape_to_merge : Union [ List [int ], torch . Size ], max_dim : int ) -> List [int ]:
508508 r"""Merge small dimensions.
509509
510510 If there are some small dimensions, we collapse them
511511 e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
512512 [1, 2, 768, 1, 2048] --> [2, 768, 2048].
513513
514- :param shape_to_merge: List[int]. Shape to merge small dimensions.
514+ :param shape_to_merge: Union[ List[int], torch.Size ]. Shape to merge small dimensions.
515515 :param max_dim: int. Maximal dimension of output shape used in merging.
516516 """
517517 merged_shape : List [int ] = []
0 commit comments