@@ -189,23 +189,26 @@ def save_to_preset(self, preset_dir, max_shard_size=10):
         saver = get_preset_saver(preset_dir)
         saver.save_backbone(self, max_shard_size=max_shard_size)
 
-    def get_lora_target_names(self):
-        """Returns list of layer names which are to be LoRA-fied.
-
-        Subclasses can override this method if the names of layers to be
-        LoRa-fied are different.
-        """
+    def default_lora_layer_names(self):
+        """Returns list of layer names which are to be LoRA-fied."""
         return ["query_dense", "value_dense", "query", "value"]
 
-    def enable_lora(self, rank, target_names=None):
+    def enable_lora(self, rank, target_layer_names=None):
         """Enable Lora on the backbone.
 
         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
+
+        Args:
+            rank: The rank of the LoRA factorization.
+            target_layer_names: A list of strings, the names of the layers to
+                apply LoRA to. If `None`, this will be populated with the
+                default LoRA layer names as returned by
+                `backbone.default_lora_layer_names()`.
         """
-        if target_names is None:
-            target_names = self.get_lora_target_names()
+        if target_layer_names is None:
+            target_layer_names = self.default_lora_layer_names()
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
@@ -214,7 +217,7 @@ def enable_lora(self, rank, target_names=None):
         all_layers = self._flatten_layers(include_self=False)
         all_layers = [lyr for lyr in all_layers if lyr.weights]
         for i, layer in enumerate(all_layers):
-            for name in target_names:
+            for name in target_layer_names:
                 if layer.name == name:
                     if hasattr(layer, "enable_lora"):
                         layer.trainable = True
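For reference, a minimal usage sketch of the API after this rename (the preset name below is only illustrative; any KerasNLP backbone loaded through `from_preset` works the same way):

import keras_nlp

# Illustrative preset; substitute whichever backbone you are fine-tuning.
backbone = keras_nlp.models.Backbone.from_preset("gemma_2b_en")

# Freezes every backbone weight, then enables LoRA with the given rank on the
# layers named by backbone.default_lora_layer_names(), i.e.
# ["query_dense", "value_dense", "query", "value"].
backbone.enable_lora(rank=4)

To LoRA-fy a different subset of layers, either pass it explicitly, e.g. `backbone.enable_lora(rank=4, target_layer_names=["query_dense"])`, or override `default_lora_layer_names()` in the backbone subclass, as the removed docstring previously suggested for `get_lora_target_names()`.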