
Commit c5b7a0e

martinsbruveris and Kevin Keraudren (kevin-keraudren) authored

LoRA layers (#96)

Big PR adding all the LoRA related work. See README for details.

Co-authored-by: Kevin Keraudren <kevin.keraudren@onfido.com>
Co-authored-by: Kevin Keraudren <kevin.keraudren@googlemail.com>

1 parent: 71d2dad

File tree

11 files changed: +1376 −8 lines changed

docs/source/content/lora.rst

Lines changed: 199 additions & 0 deletions
LoRA
====

.. py:module:: tfimm.architectures.lora

*Low-Rank Adaptation (LoRA)* is a parameter-efficient fine-tuning method developed
originally for large language models, but adapted here for vision models. This module
contains TensorFlow code for LoRA layers and their integration with ``tfimm`` models.
For details on LoRA see the paper:

**LoRA: Low-Rank Adaptation of Large Language Models.**
*Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang,
Lu Wang, Weizhu Chen.*
Paper: `[arXiv:2106.09685] <https://arxiv.org/abs/2106.09685>`_
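Before diving into the API, the arithmetic behind LoRA is worth a quick sketch. The NumPy snippet below is illustrative only (not tfimm code): a frozen dense weight receives a trainable low-rank update, which shrinks the trainable parameter count dramatically.

```python
import numpy as np

# LoRA replaces a frozen dense weight W (d_in x d_out) with W + A @ B,
# where only the factors A (d_in x r) and B (r x d_out) are trainable.
d_in, d_out, rank = 512, 512, 2

rng = np.random.default_rng(0)
W = rng.standard_normal((d_in, d_out))  # frozen pre-trained weight
A = rng.standard_normal((d_in, rank))   # trainable low-rank factor
B = np.zeros((rank, d_out))             # one factor starts at zero, as in the
                                        # paper, so training starts at W exactly

x = rng.standard_normal((1, d_in))
y = x @ W + x @ A @ B  # LoRA forward pass: base output plus low-rank update

# Parameter savings: only the factors are trained.
full_params = d_in * d_out           # 262,144 weights in the dense kernel
lora_params = rank * (d_in + d_out)  # 2,048 trainable LoRA weights
```

With ``B`` initialised to zero the update vanishes, so the LoRA model initially reproduces the pre-trained model exactly.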
Usage
-----

For supported architectures we can use :func:`tfimm.architectures.lora.create_model`
instead of :func:`tfimm.create_model` to create a LoRA model.

.. code-block:: pycon

    >>> from tfimm.architectures import lora
    >>> model = lora.create_model(
    ...     model_name="convnext_tiny", pretrained=True, lora_rank=2
    ... )

When we look at the model summary, we see that most model parameters are non-trainable
and only the low-rank weight updates are trainable.

.. code-block:: pycon

    >>> model.summary()
    ...
    =================================================================
    Total params: 28,721,608
    Trainable params: 132,480
    Non-trainable params: 28,589,128
    _________________________________________________________________

LoRA models can be converted back to regular models.

.. code-block:: pycon

    >>> type(model)
    <class 'tfimm.architectures.lora.convnext.LoRAConvNeXt'>
    >>> regular_model = lora.convert_to_regular_model(model)
    >>> type(regular_model)
    <class 'tfimm.architectures.convnext.ConvNeXt'>
Supported architectures
-----------------------

Currently we support the following architectures:

* ConvNeXt

and the following layers:

* Dense
Implementation
--------------

In order to perform LoRA training, the first task is to convert a regular model to its
LoRA version. For ``tfimm`` architectures we do this by subclassing and modifying layers
in ``__init__``. E.g., ``LoRAConvNeXt`` is subclassed from ``ConvNeXt`` and we replace
the dense layers in each MLP block by their LoRA counterparts.

We use a registry system to track model classes and their LoRA counterparts. A
simplified example:

.. code-block:: python

    from dataclasses import dataclass

    import tensorflow as tf

    from tfimm.architectures import lora

    @dataclass
    class ResNetConfig:
        nb_blocks: tuple = (3, 4, 6, 3)

    class ResNet(tf.keras.Model):
        cfg_class = ResNetConfig

        def __init__(self, cfg, **kwargs):
            ...

    @dataclass
    class LoRAResNetConfig(ResNetConfig):
        lora_rank: int = 2

    @lora.register_lora_architecture
    class LoRAResNet(ResNet):
        cfg_class = LoRAResNetConfig

        def __init__(self, cfg, **kwargs):
            super().__init__(cfg, **kwargs)  # Create the original model
            ...  # Then replace layers with LoRA versions
We make the following assumptions:

* Model parameters are specified via a configuration dataclass and the configuration
  class of each model is defined via the ``cfg_class`` class attribute.
* The configuration of the LoRA model is a superset of the configuration of the base
  model.

Under these assumptions we can use the :func:`register_lora_architecture` decorator to
associate ``LoRAResNet`` as the LoRA variant of the ``ResNet`` class.
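Conceptually, the registry is a mapping from base model classes to their LoRA counterparts. The following is a minimal pure-Python sketch of such a decorator-based registry; all names are illustrative and this does not reproduce the actual tfimm implementation (which also tracks configuration classes):

```python
# Illustrative registry sketch; names and details are hypothetical,
# not the actual tfimm implementation.
_lora_registry = {}

def register_lora_architecture(lora_cls):
    """Register ``lora_cls`` as the LoRA variant of its direct base class."""
    base_cls = lora_cls.__bases__[0]
    _lora_registry[base_cls] = lora_cls
    return lora_cls  # decorator returns the class unchanged

class ResNet:
    """Stands in for a regular model class."""

@register_lora_architecture
class LoRAResNet(ResNet):
    """Stands in for the LoRA variant."""

# Given a base class, look up its LoRA counterpart.
lora_variant = _lora_registry[ResNet]
```

Because the LoRA class subclasses the base class, the registry can be populated entirely as a side effect of class definition.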
Now, given an instance of ``ResNet``, we can use :func:`convert_to_lora_model` to
convert it to a ``LoRAResNet`` instance *and* transfer all weights.

.. code-block:: python

    model = ResNet(cfg=ResNetConfig())
    ...  # Build model or load pre-trained weights

    lora_model = lora.convert_to_lora_model(model, lora_rank=2)
The ``lora_model.trainable_weights`` property correctly returns only the LoRA trainable
weights, i.e., the low-rank updates. We additionally have the option to train the
biases as well, either only for LoRA layers or for all layers. This can be specified
by passing one of the values ``"none"``, ``"lora_only"`` or ``"all"`` for the
``lora_train_bias`` parameter.

.. code-block:: python

    lora_model = lora.convert_to_lora_model(
        model, lora_rank=2, lora_train_bias="lora_only"
    )
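The effect of ``lora_train_bias`` can be pictured as a filter over the model's weights. The sketch below uses made-up weight names and string matching purely for illustration; it is not how tfimm actually identifies LoRA weights:

```python
# Conceptual sketch of the three ``lora_train_bias`` modes.
# Weight names below are hypothetical, not real tfimm variable names.
def select_trainable(weight_names, lora_train_bias="none"):
    trainable = []
    for name in weight_names:
        is_lora_update = "lora_a" in name or "lora_b" in name
        is_bias = name.endswith("bias")
        in_lora_layer = "lora_dense" in name
        if is_lora_update:
            trainable.append(name)  # low-rank factors: always trained
        elif is_bias and lora_train_bias == "all":
            trainable.append(name)  # every bias in the model
        elif is_bias and in_lora_layer and lora_train_bias == "lora_only":
            trainable.append(name)  # only biases of LoRA layers
    return trainable

weights = [
    "stem/conv/kernel",
    "stem/conv/bias",
    "block/lora_dense/kernel",
    "block/lora_dense/bias",
    "block/lora_dense/lora_a",
    "block/lora_dense/lora_b",
]
selected = select_trainable(weights, lora_train_bias="lora_only")
```

With ``"none"`` only the two low-rank factors are selected; ``"lora_only"`` adds the LoRA layer's bias; ``"all"`` adds every bias in the model.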
Sequential and functional models
--------------------------------

The current implementation focuses on models created by subclassing, which is the case
for all ``tfimm`` models. In particular, the registry system works only for subclassed
models. However, some of the functionality also works for functional models.

* LoRA layers are the basic building blocks for both subclassed as well as
  functional models.
* Transferring weights works for all models, regardless of type, provided the regular
  model and the LoRA variant have the same architecture with the exception of LoRA
  layers. Use the :func:`tfimm.models.transfer_weights` function to transfer weights
  to the LoRA model.

  .. code-block:: python

      from tfimm.architectures import lora
      from tfimm.models import transfer_weights

      # Transfer weights into the LoRA model
      transfer_weights(
          regular_model, lora_model, weights_to_ignore=lora.LORA_WEIGHT_NAMES
      )
157+
* After training, we need to manually merge weights and then transfer them back to the
158+
regular model.
159+
160+
.. code-block:: python
161+
162+
lora.merge_weights(lora_model)
163+
transfer_weights(lora_model, regular_model)
164+
165+
* The functions
166+
:func:`lora_trainable_weights` and
167+
:func:`lora_non_trainable_weights` work for all
168+
models, regardless of type and return a list of weights to be used for LoRA training
169+
(or all other weights).
170+
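Merging folds the low-rank update into the frozen weight, so the merged layer computes the same function without the extra matrix multiplications at inference time. A NumPy sketch of the arithmetic (illustrative only; it assumes a scaling factor of 1):

```python
import numpy as np

rng = np.random.default_rng(42)
d_in, d_out, rank = 64, 32, 2

W = rng.standard_normal((d_in, d_out))  # frozen base weight
A = rng.standard_normal((d_in, rank))   # trained low-rank factors
B = rng.standard_normal((rank, d_out))

x = rng.standard_normal((5, d_in))
y_lora = x @ W + x @ A @ B  # forward pass before merging

W_merged = W + A @ B        # merging folds the update into the weight
y_merged = x @ W_merged     # identical output from a single matmul
```

This is why a merged LoRA model can be transferred back into a regular model: after merging, the LoRA layer is functionally equivalent to a plain dense layer.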
Interface
---------

All functions are available under ``tfimm.architectures.lora``.

Factory
~~~~~~~

.. autofunction:: convert_to_lora_model
.. autofunction:: convert_to_regular_model
.. autofunction:: create_model
.. autofunction:: merge_lora_weights
.. autofunction:: lora_non_trainable_weights
.. autofunction:: lora_trainable_weights

Layers
~~~~~~

.. autoclass:: LoRAConv2D
.. autoclass:: LoRADense
.. autofunction:: convert_to_lora_layer

Registry
~~~~~~~~

.. autofunction:: lora_architecture
.. autofunction:: lora_base_architecture
.. autofunction:: lora_config
.. autofunction:: register_lora_architecture

docs/source/index.rst

Lines changed: 7 additions & 1 deletion

The toctree in ``index.rst`` gains an "Applications" section pointing at the new page:

.. code-block:: rst

    .. toctree::
       :maxdepth: 2
       :caption: Training

       content/trainer

    .. toctree::
       :maxdepth: 2
       :caption: Applications

       content/lora
