Skip to content

Commit a770ec7

Browse files
committed
refactor: numpy to torch
1 parent 552dbae commit a770ec7

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

pytorch_optimizer/optimizer/shampoo_utils.py

Lines changed: 19 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -121,51 +121,51 @@ class BlockPartitioner:
121121
"""
122122

123123
def __init__(self, var: torch.Tensor, rank: int, block_size: int, pre_conditioner_type: int):
124-
self.shape: List[int] = var.shape
124+
self.shape: torch.Size = var.shape
125125

126-
self.splits: List[Tuple[int, np.ndarray]] = []
127-
self.split_sizes: List[Tuple[int, np.ndarray]] = []
126+
self.splits: List[Tuple[int, torch.Tensor]] = []
127+
self.split_sizes: List[Tuple[int, torch.Tensor]] = []
128128

129-
split_sizes: List[np.ndarray] = []
129+
split_sizes: List[torch.Tensor] = []
130130

131131
# We split var into smaller blocks. Here we store the metadata to make that split.
132132
for i, d in enumerate(self.shape):
133133
if block_size <= 0 or block_size >= d:
134-
split_sizes.append(np.array([d], dtype=np.int32))
134+
split_sizes.append(torch.tensor([d], dtype=torch.int32))
135135
continue
136136

137137
# d - 1, otherwise split appends a 0-size array.
138138
num_split: int = (d - 1) // block_size
139-
indices = (np.arange(num_split, dtype=np.int32) + 1) * block_size
139+
indices = (torch.arange(num_split, dtype=torch.int32) + 1) * block_size
140140

141-
sizes: np.ndarray = np.ones(num_split + 1, dtype=np.int32) * block_size
141+
sizes: torch.Tensor = torch.full((num_split + 1,), block_size, dtype=torch.int32)
142142
sizes[-1] = d - indices[-1]
143143

144144
self.splits.append((i, indices))
145145
self.split_sizes.append((i, sizes))
146146
split_sizes.append(sizes)
147147

148148
self.num_splits: int = len(split_sizes)
149-
self.pre_conditioner_shapes: List[List[int]] = self.build_pre_conditioner_shapes(
149+
self.pre_conditioner_shapes: List[List[torch.Tensor]] = self.build_pre_conditioner_shapes(
150150
split_sizes, pre_conditioner_type, rank
151151
)
152152

153153
@staticmethod
154154
def build_pre_conditioner_shapes(
155-
split_sizes: List[np.ndarray], pre_conditioner_type: int, rank: int
156-
) -> List[List[int]]:
155+
split_sizes: List[torch.Tensor], pre_conditioner_type: int, rank: int
156+
) -> List[List[torch.Tensor]]:
157157
r"""Build pre-conditioner shapes."""
158-
pre_conditioner_shapes: List[List[int]] = []
158+
pre_conditioner_shapes: List[List[torch.Tensor]] = []
159159
for t in itertools.product(*split_sizes):
160-
t_shape: List[Optional[List[int]]] = [[d, d] for d in t]
160+
t_shape: List[Optional[List[torch.Tensor]]] = [[d, d] for d in t]
161161
if pre_conditioner_type == PreConditionerType.INPUT:
162-
t_shape = t_shape[:-1] + [None]
163-
if pre_conditioner_type == PreConditionerType.OUTPUT:
162+
t_shape[-1] = None
163+
elif pre_conditioner_type == PreConditionerType.OUTPUT:
164164
t_shape = [None] * (rank - 1) + t_shape[-1:]
165165
pre_conditioner_shapes.extend(t_shape)
166166
return pre_conditioner_shapes
167167

168-
def shapes_for_pre_conditioners(self) -> List[List[int]]:
168+
def shapes_for_pre_conditioners(self) -> List[List[torch.Tensor]]:
169169
r"""Get shapes of pre-conditioner."""
170170
return self.pre_conditioner_shapes
171171

@@ -244,7 +244,7 @@ def __init__(
244244

245245
self.w2: float = 1.0 if self.beta2 == 1.0 else (1.0 - self.beta2)
246246

247-
self.original_shape: List[int] = var.shape
247+
self.original_shape: torch.Size = var.shape
248248
self.transformed_shape: List[int] = (
249249
merge_small_dims(self.original_shape, block_size) if shape_interpretation else var.shape
250250
)
@@ -267,7 +267,7 @@ def __init__(
267267
pre_conditioner_type=self.pre_conditioner_type,
268268
)
269269

270-
shapes: List[Optional[List[int]]] = self.partitioner.shapes_for_pre_conditioners()
270+
shapes: List[Optional[List[torch.Tensor]]] = self.partitioner.shapes_for_pre_conditioners()
271271
self.statistics = [self.matrix_eps * torch.eye(shape[0], device=var.device) for shape in shapes if shape]
272272
self.pre_conditioners = [torch.eye(shape[0], device=var.device) for shape in shapes if shape]
273273
self.is_same_shapes = None not in shapes and len(np.unique(shapes)) == 1
@@ -504,14 +504,14 @@ def compute_power_svd(matrix: torch.Tensor, power: float) -> torch.Tensor:
504504
return u @ (s.diag() if len(matrix.shape) == 2 else s.diag_embed()) @ vh
505505

506506

507-
def merge_small_dims(shape_to_merge: List[int], max_dim: int) -> List[int]:
507+
def merge_small_dims(shape_to_merge: Union[List[int], torch.Size], max_dim: int) -> List[int]:
508508
r"""Merge small dimensions.
509509
510510
If there are some small dimensions, we collapse them
511511
e.g. [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
512512
[1, 2, 768, 1, 2048] --> [2, 768, 2048].
513513
514-
:param shape_to_merge: List[int]. Shape to merge small dimensions.
514+
:param shape_to_merge: Union[List[int], torch.Size]. Shape to merge small dimensions.
515515
:param max_dim: int. Maximal dimension of output shape used in merging.
516516
"""
517517
merged_shape: List[int] = []

0 commit comments

Comments
 (0)