Commit 7ab2c53

fix get_lora_target_names function (#2167)
* fix target_names function
* fix in backbone
* update the get lora target names
* add setter
* address comment
* update docstring
1 parent 7f1f011 commit 7ab2c53

File tree

4 files changed: +18 -15 lines changed

keras_hub/src/models/backbone.py

Lines changed: 13 additions & 10 deletions
@@ -189,23 +189,26 @@ def save_to_preset(self, preset_dir, max_shard_size=10):
         saver = get_preset_saver(preset_dir)
         saver.save_backbone(self, max_shard_size=max_shard_size)

-    def get_lora_target_names(self):
-        """Returns list of layer names which are to be LoRA-fied.
-
-        Subclasses can override this method if the names of layers to be
-        LoRa-fied are different.
-        """
+    def default_lora_layer_names(self):
+        """Returns list of layer names which are to be LoRA-fied."""
         return ["query_dense", "value_dense", "query", "value"]

-    def enable_lora(self, rank, target_names=None):
+    def enable_lora(self, rank, target_layer_names=None):
         """Enable Lora on the backbone.

         Calling this method will freeze all weights on the backbone,
         while enabling Lora on the query & value `EinsumDense` layers
         of the attention layers.
+
+        Args:
+            rank: The rank of the LoRA factorization.
+            target_layer_names: A list of strings, the names of the layers to
+                apply LoRA to. If `None`, this will be populated with the
+                default LoRA layer names as returned by
+                `backbone.default_lora_layer_names()`.
         """
-        if target_names is None:
-            target_names = self.get_lora_target_names()
+        if target_layer_names is None:
+            target_layer_names = self.default_lora_layer_names()
         self.trainable = True
         self._lora_enabled_layers = []
         self._lora_rank = rank
@@ -214,7 +217,7 @@ def enable_lora(self, rank, target_names=None):
         all_layers = self._flatten_layers(include_self=False)
         all_layers = [lyr for lyr in all_layers if lyr.weights]
         for i, layer in enumerate(all_layers):
-            for name in target_names:
+            for name in target_layer_names:
                 if layer.name == name:
                     if hasattr(layer, "enable_lora"):
                         layer.trainable = True
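In short, `get_lora_target_names()` becomes `default_lora_layer_names()` and the `target_names` argument of `enable_lora()` becomes `target_layer_names`. A minimal usage sketch of the renamed API (the backbone construction and `init_kwargs` are illustrative, not part of this commit):

# Illustrative only: any keras_hub backbone instance behaves the same way.
backbone = GemmaBackbone(**init_kwargs)

# Uses backbone.default_lora_layer_names() internally, i.e. the base-class
# defaults ["query_dense", "value_dense", "query", "value"].
backbone.enable_lora(rank=4)

# Alternatively, restrict LoRA to an explicit subset of layer names.
backbone.enable_lora(rank=4, target_layer_names=["query"])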

keras_hub/src/models/gemma/gemma_lora_test.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def test_lora_fine_tuning(self):
     def test_lora_fine_tuning_target_names(self):
         # Set up backbone and preprocessor.
         backbone = GemmaBackbone(**self._init_kwargs)
-        backbone.enable_lora(4, target_names=["query"])
+        backbone.enable_lora(4, target_layer_names=["query"])
         # 4 layers, 2 weights per layer
         self.assertLen(backbone.trainable_weights, 2 * 2)
         self.assertLen(backbone.non_trainable_weights, 20)
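For context on the assertion counts: in Keras 3, each LoRA-fied layer keeps its original kernel frozen and adds a pair of low-rank trainable weights (`lora_kernel_a` and `lora_kernel_b`), so restricting LoRA to the `query` layers leaves only those pairs trainable. A hedged inspection sketch, assuming `backbone` is a backbone instance like the one in this test:

backbone.enable_lora(4, target_layer_names=["query"])
for weight in backbone.trainable_weights:
    # Expect lora_kernel_a / lora_kernel_b pairs, one pair per matched layer.
    print(weight.path)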

keras_hub/src/models/gemma3/gemma3_backbone.py

Lines changed: 2 additions & 2 deletions
@@ -402,8 +402,8 @@ def get_config(self):
         )
         return config

-    def get_lora_target_names(self):
-        target_names = super().get_lora_target_names()
+    def default_lora_layer_names(self):
+        target_names = super().default_lora_layer_names()

         # Add these for `Gemma3VITAttention`.
         if not self.text_only_model:

keras_hub/src/models/pali_gemma/pali_gemma_backbone.py

Lines changed: 2 additions & 2 deletions
@@ -274,8 +274,8 @@ def __init__(
         # Keep the image_sequence_length as a backbone property for easy access.
         self.image_sequence_length = self.vit_encoder.image_sequence_length

-    def get_lora_target_names(self):
-        target_names = super().get_lora_target_names()
+    def default_lora_layer_names(self):
+        target_names = super().default_lora_layer_names()

         # Add these for `PaliGemmaVITAttention`.
         target_names += ["query_proj", "value_proj"]
