Commit e6d72ed

Update adafactor comments / attrib
1 parent d73e8e7 commit e6d72ed

File tree

timm/optim/adafactor.py
timm/optim/adafactor_bv.py

2 files changed: +15, -2 lines changed


timm/optim/adafactor.py

Lines changed: 3 additions & 2 deletions
@@ -2,8 +2,9 @@
 
 Lifted from https://github.com/pytorch/fairseq/blob/master/fairseq/optim/adafactor.py
 
-Original header/copyright below.
+Modified by Ross Wightman to fix some issues with factorization dims for non nn.Linear layers
 
+Original header/copyright below.
 """
 # Copyright (c) Facebook, Inc. and its affiliates.
 #
@@ -96,7 +97,7 @@ def _get_options(param_group, param_shape, min_size_to_factor=32):
         # nD convs in torch are ND + 2 dim weights with leading in/out chs
         factored = 0, 1
     elif ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
-        # if the criteria above didn't match, check trailing dims
+        # if the criteria above didn't match, test trailing dims for eligibility
         factored = ndim - 2, ndim - 1
 
     return factored, use_first_moment
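For context on the second hunk: the reworded comments annotate the heuristic that decides which two dimensions of a parameter get factored second-moment running averages. Leading in/out channel dims are chosen for nD conv weights; otherwise the trailing two dims are used when both exceed min_size_to_factor, and small or 1D params are left unfactored. A minimal standalone sketch of just that heuristic (the helper name pick_factored_dims and the example shapes are illustrative, not from the diff, which also tracks a use_first_moment flag omitted here):

from typing import Optional, Tuple

def pick_factored_dims(param_shape: Tuple[int, ...], min_size_to_factor: int = 32) -> Optional[Tuple[int, int]]:
    # Returns the pair of dims whose second moments get factored row/col
    # running averages, or None when the param is left unfactored.
    ndim = len(param_shape)
    if ndim > 2 and param_shape[0] > min_size_to_factor and param_shape[1] > min_size_to_factor:
        # nD convs in torch are ND + 2 dim weights with leading in/out chs
        return 0, 1
    if ndim >= 2 and param_shape[-2] > min_size_to_factor and param_shape[-1] > min_size_to_factor:
        # fall back to the trailing dims when they're both large enough
        return ndim - 2, ndim - 1
    return None  # e.g. biases, norm weights, small dims: no factorization

# a 3x3 conv weight (out=256, in=128, kh=3, kw=3) factors the channel dims:
assert pick_factored_dims((256, 128, 3, 3)) == (0, 1)
# an nn.Linear weight (out=768, in=3072) factors its two dims:
assert pick_factored_dims((768, 3072)) == (0, 1)
# a bias vector is not factored:
assert pick_factored_dims((768,)) is None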

timm/optim/adafactor_bv.py

Lines changed: 12 additions & 0 deletions
@@ -1,3 +1,12 @@
+""" Adafactor (Big Vision variant) for PyTorch
+
+Adapted from the implementation in big vision: https://github.com/google-research/big_vision
+
+Described in 'Scaling Vision Transformers': https://arxiv.org/abs/2106.04560
+
+Adaptation and PyTorch modifications by Ross Wightman
+"""
+
 from typing import List, Optional, Tuple, Union
 
 import torch
@@ -39,6 +48,8 @@ def _factored_dims(
 class AdafactorBigVision(Optimizer):
     """
     PyTorch implementation of BigVision's Adafactor variant with both single and multi tensor implementations.
+
+    Adapted from https://github.com/google-research/big_vision by Ross Wightman
     """
 
     def __init__(
@@ -292,4 +303,5 @@ def _multi_tensor_adafactor(
         clipping_threshold: Optional[float],
         unscaled_wd: bool,
 ):
+    # FIXME TODO
     assert False, 'multi-tensor fn (foreach=True) not implemented yet'
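The final hunk only flags the existing stub: the multi-tensor path still hard-fails, so the single-tensor step is the only working mode for now. A hedged usage sketch follows (the import path and constructor defaults are assumptions about the surrounding timm code, not shown in this diff; only the class name and the foreach=True assert message come from the source):

import torch
import torch.nn as nn

from timm.optim.adafactor_bv import AdafactorBigVision  # assumed module path

model = nn.Linear(3072, 768)

# Default construction uses the single-tensor step; per the assert above,
# passing foreach=True would hit the unimplemented multi-tensor stub.
opt = AdafactorBigVision(model.parameters())

loss = model(torch.randn(8, 3072)).sum()
loss.backward()
opt.step()
opt.zero_grad()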
