LoRA
====

.. py:module:: tfimm.architectures.lora

*Low-Rank Adaptation (LoRA)* is a parameter-efficient fine-tuning method developed
originally for large language models, but adapted here for vision models. This module
contains TensorFlow code for LoRA layers and their integration with ``tfimm`` models.
For details on LoRA, see the paper

**LoRA: Low-Rank Adaptation of Large Language Models.**
*Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang,
Lu Wang, Weizhu Chen.*
Paper: `[arXiv:2106.09685] <https://arxiv.org/abs/2106.09685>`_

Usage
-----

For supported architectures we can use :func:`tfimm.architectures.lora.create_model`
instead of :func:`tfimm.create_model` to create a LoRA model.

.. code-block:: pycon

    >>> from tfimm.architectures import lora
    >>> model = lora.create_model(
    ...     model_name="convnext_tiny", pretrained=True, lora_rank=2
    ... )

When we look at the model summary, we see that most model parameters are non-trainable
and only the low-rank weight updates are trainable.

.. code-block:: pycon

    >>> model.summary()
    ...
    =================================================================
    Total params: 28,721,608
    Trainable params: 132,480
    Non-trainable params: 28,589,128
    _________________________________________________________________
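The trainable-parameter count can be checked by hand, assuming LoRA updates are added
only to the two dense layers in each ConvNeXt MLP block (the stage widths and depths
below are those of ``convnext_tiny``):

```python
# A rank-r LoRA update for a dense layer with d_in inputs and d_out outputs
# adds r * (d_in + d_out) parameters.  Each ConvNeXt block has an MLP with
# two dense layers, d -> 4d and 4d -> d.
rank = 2
widths = (96, 192, 384, 768)  # convnext_tiny stage widths
depths = (3, 3, 9, 3)  # convnext_tiny blocks per stage

total = 0
for d, n in zip(widths, depths):
    per_block = rank * (d + 4 * d) + rank * (4 * d + d)  # fc1 + fc2
    total += n * per_block

print(total)  # 132480, matching the summary above
```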

LoRA models can be converted back to regular models.

.. code-block:: pycon

    >>> type(model)
    <class 'tfimm.architectures.lora.convnext.LoRAConvNeXt'>
    >>> regular_model = lora.convert_to_regular_model(model)
    >>> type(regular_model)
    <class 'tfimm.architectures.convnext.ConvNeXt'>

Supported architectures
-----------------------

Currently, we support the following architectures:

* ConvNeXt

and the following layers:

* Dense

Implementation
--------------

In order to perform LoRA training, the first task is to convert a regular model to its
LoRA version. For ``tfimm`` architectures we do this by subclassing and modifying layers
in ``__init__``. E.g., ``LoRAConvNeXt`` is subclassed from ``ConvNeXt`` and we replace
the dense layers in each MLP block by their LoRA counterparts.

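The reparameterization behind a LoRA dense layer can be sketched in NumPy (the shapes
here are illustrative, and the actual layers may additionally scale the low-rank path):

```python
import numpy as np

rng = np.random.default_rng(0)

# Frozen pretrained kernel of a dense layer (illustrative shapes).
d_in, d_out, rank = 8, 4, 2
W = rng.normal(size=(d_in, d_out))

# Low-rank update: only A and B are trained.  B starts at zero, so the
# LoRA model initially computes exactly the same function as the base model.
A = rng.normal(size=(d_in, rank))
B = rng.normal(size=(rank, d_out))

x = rng.normal(size=(1, d_in))

# Forward pass of a LoRA dense layer: frozen path plus low-rank path.
y_lora = x @ W + x @ A @ B

# Merging folds the update into the kernel; this is what converting a LoRA
# model back to a regular model does with the weights.
W_merged = W + A @ B
y_merged = x @ W_merged

assert np.allclose(y_lora, y_merged)
```

Because the merged kernel has the same shape as the original one, the converted model
is architecturally identical to the base model and incurs no inference overhead.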
We use a registry system to track model classes and their LoRA counterparts. A
simplified example:

.. code-block:: python

    from dataclasses import dataclass

    import tensorflow as tf

    from tfimm.architectures import lora


    @dataclass
    class ResNetConfig:
        nb_blocks: tuple = (3, 4, 6, 3)


    class ResNet(tf.keras.Model):
        cfg_class = ResNetConfig

        def __init__(self, cfg, **kwargs):
            ...


    @dataclass
    class LoRAResNetConfig(ResNetConfig):
        lora_rank: int = 2


    @lora.register_lora_architecture
    class LoRAResNet(ResNet):
        cfg_class = LoRAResNetConfig

        def __init__(self, cfg, **kwargs):
            super().__init__(cfg, **kwargs)  # Create the original model
            ...  # Then replace layers with LoRA versions
|
| 100 | +We make the following assumptions |
| 101 | + |
| 102 | +* Model parameters are specified via a configuration dataclass and the configuration |
| 103 | + class of each model is defined via the ``cfg_class`` class attribute. |
| 104 | +* The configuration of the LoRA model is a superset of the configuration of the base |
| 105 | + model. |
| 106 | + |
| 107 | +Under these assumptions we can use the |
| 108 | +:func:`register_lora_architecture` decorator to |
| 109 | +associate ``LoRAResNet`` as the LoRA variant of the ``ResNet`` class. |
| 110 | + |
Now, given an instance of ``ResNet``, we can use :func:`convert_to_lora_model` to
convert it to a ``LoRAResNet`` instance *and* transfer all weights.

.. code-block:: python

    model = ResNet(cfg=ResNetConfig())
    ...  # Build model or load pre-trained weights

    lora_model = lora.convert_to_lora_model(model, lora_rank=2)

The ``lora_model.trainable_weights`` property correctly returns only the LoRA trainable
weights, i.e., the low-rank updates. We additionally have the option to train the
biases as well, either only for LoRA layers or for all layers. This can be specified
by passing one of the values ``"none"``, ``"lora_only"`` or ``"all"`` for the
``lora_train_bias`` parameter.

.. code-block:: python

    lora_model = lora.convert_to_lora_model(
        model, lora_rank=2, lora_train_bias="lora_only"
    )
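How the three ``lora_train_bias`` settings partition the weights can be sketched as
follows. The ``(name, in_lora_layer)`` bookkeeping and the ``"lora_a"``/``"lora_b"``
naming are illustrative assumptions for this sketch, not the actual ``tfimm`` weight
names.

```python
def select_trainable(weights, lora_train_bias="none"):
    """weights: list of (name, in_lora_layer) pairs (toy format)."""
    selected = []
    for name, in_lora_layer in weights:
        is_update = "lora_a" in name or "lora_b" in name
        is_bias = name.endswith("bias")
        if is_update:
            # Low-rank update matrices are always trainable.
            selected.append(name)
        elif is_bias and (
            lora_train_bias == "all"
            or (lora_train_bias == "lora_only" and in_lora_layer)
        ):
            selected.append(name)
    return selected


weights = [
    ("stem/conv/kernel", False),
    ("stem/conv/bias", False),
    ("mlp/fc1/lora_a", True),
    ("mlp/fc1/lora_b", True),
    ("mlp/fc1/bias", True),
]

assert select_trainable(weights, "none") == ["mlp/fc1/lora_a", "mlp/fc1/lora_b"]
assert select_trainable(weights, "lora_only") == [
    "mlp/fc1/lora_a", "mlp/fc1/lora_b", "mlp/fc1/bias"
]
assert len(select_trainable(weights, "all")) == 4
```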

Sequential and functional models
--------------------------------

The current implementation focuses on models created by subclassing, which is the case
for all ``tfimm`` models. In particular, the registry system works only for subclassed
models. However, some of the functionality also works for functional models.

* LoRA layers are the basic building blocks for both subclassed as well as
  functional models.
* Transferring weights works for all models, regardless of type, provided the regular
  model and LoRA variant have the same architecture with the exception of LoRA layers.
  Use the :func:`tfimm.models.transfer_weights` function to transfer weights to LoRA.

  .. code-block:: python

      from tfimm.architectures import lora
      from tfimm.models import transfer_weights

      # Transfer weights into the LoRA model
      transfer_weights(
          regular_model, lora_model, weights_to_ignore=lora.LORA_WEIGHT_NAMES
      )

* After training, we need to manually merge weights and then transfer them back to the
  regular model.

  .. code-block:: python

      lora.merge_lora_weights(lora_model)
      transfer_weights(lora_model, regular_model)

* The functions :func:`lora_trainable_weights` and :func:`lora_non_trainable_weights`
  work for all models, regardless of type, and return the list of weights to be used
  for LoRA training (or all other weights, respectively).

Interface
---------

All functions are available under ``tfimm.architectures.lora``.

Factory
~~~~~~~

.. autofunction:: convert_to_lora_model
.. autofunction:: convert_to_regular_model
.. autofunction:: create_model
.. autofunction:: merge_lora_weights
.. autofunction:: lora_non_trainable_weights
.. autofunction:: lora_trainable_weights

Layers
~~~~~~

.. autoclass:: LoRAConv2D
.. autoclass:: LoRADense
.. autofunction:: convert_to_lora_layer

Registry
~~~~~~~~

.. autofunction:: lora_architecture
.. autofunction:: lora_base_architecture
.. autofunction:: lora_config
.. autofunction:: register_lora_architecture