24 changes: 16 additions & 8 deletions keras/src/layers/convolutional/base_conv.py
@@ -230,11 +230,17 @@ def kernel(self):
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
+        kernel = self._kernel
         if self.lora_enabled:
-            return self._kernel + (
-                self.lora_alpha / self.lora_rank
-            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)
-        return self._kernel
+            kernel = ops.cast(
+                ops.add(
+                    kernel,
+                    (self.lora_alpha / self.lora_rank)
+                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                ),
+                dtype=self.compute_dtype,
+            )
+        return kernel
 
     def convolution_op(self, inputs, kernel):
         return ops.conv(
@@ -247,10 +253,7 @@ def convolution_op(self, inputs, kernel):
         )
 
     def call(self, inputs):
-        outputs = self.convolution_op(
-            inputs,
-            self.kernel,
-        )
+        outputs = self.convolution_op(inputs, self.kernel)
         if self.use_bias:
             if self.data_format == "channels_last":
                 bias_shape = (1,) * (self.rank + 1) + (self.filters,)
@@ -296,16 +299,21 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+
+        # LoRA weights should be float32 to avoid the risk of underflow or
+        # overflow during fine-tuning.
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
             shape=self._kernel.shape[:-1] + (rank,),
             initializer=initializers.get(a_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self.lora_kernel_b = self.add_weight(
             name="lora_kernel_b",
             shape=(rank, self.filters),
             initializer=initializers.get(b_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self._kernel.trainable = False
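For orientation (not part of the diff): the merged-kernel pattern introduced above can be reproduced standalone with `keras.ops`. The shapes, `rank`, and `alpha` below are illustrative assumptions; the point is that the LoRA delta stays in float32, so the sum must be cast back to the layer's compute dtype before the convolution uses it.

import numpy as np
from keras import ops

# Illustrative settings (assumed, not taken from the PR).
compute_dtype = "float16"
rank, alpha = 4, 8.0

kernel = ops.cast(np.random.randn(3, 3, 8, 16), compute_dtype)  # frozen conv kernel
lora_a = ops.cast(np.random.randn(3, 3, 8, rank), "float32")    # kernel.shape[:-1] + (rank,)
lora_b = ops.cast(np.random.randn(rank, 16), "float32")         # (rank, filters)

# The float32 delta promotes the add to float32; the explicit cast restores
# the compute dtype, mirroring the new `ops.cast(ops.add(...))` in `kernel`.
merged = ops.cast(
    ops.add(kernel, (alpha / rank) * ops.matmul(lora_a, lora_b)),
    dtype=compute_dtype,
)
print(merged.dtype)  # float16 (exact repr depends on the backend)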
3 changes: 3 additions & 0 deletions keras/src/layers/convolutional/conv_test.py
@@ -655,6 +655,9 @@ def test_enable_lora(
         self.assertLen(layer.non_trainable_weights, 1)
         if backend.backend() == "torch":
             self.assertLen(layer.torch_params, 4)
+        self.assertDType(layer.lora_kernel_a, "float32")
+        self.assertDType(layer.lora_kernel_b, "float32")
+
         # Try eager call
         x = np.random.random((64,) + input_shape[1:])
         y = np.random.random((64,) + output_shape[1:])
16 changes: 13 additions & 3 deletions keras/src/layers/core/dense.py
@@ -207,8 +207,13 @@ def kernel(self):
 
         # Apply LoRA once at the end.
         if self.lora_enabled:
-            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
-                self.lora_kernel_a, self.lora_kernel_b
+            kernel = ops.cast(
+                ops.add(
+                    kernel,
+                    (self.lora_alpha / self.lora_rank)
+                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                ),
+                dtype=self.compute_dtype,
             )
 
         return kernel
@@ -265,16 +270,20 @@ def enable_lora(
         else:
             input_dim_for_lora = self.kernel.shape[0]
 
+        # LoRA weights should be float32 to avoid the risk of underflow or
+        # overflow during fine-tuning.
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
             shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self.lora_kernel_b = self.add_weight(
             name="lora_kernel_b",
             shape=(rank, self.kernel.shape[1]),
             initializer=initializers.get(b_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self._kernel.trainable = False
@@ -810,6 +819,7 @@ def grad_fn(*args, upstream=None):
             lora_x = ops.matmul(inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+        x = ops.cast(x, self.compute_dtype)
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
@@ -918,7 +928,7 @@ def grad_fn(*args, upstream=None):
             lora_x = ops.matmul(inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
-
+        x = ops.cast(x, self.compute_dtype)
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
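As a hedged usage sketch (not a test from this PR): with these changes, enabling LoRA on a Dense layer under a mixed_float16 policy keeps the LoRA factors in float32 while the merged `kernel` property still returns a tensor in the layer's compute dtype. The layer sizes and rank below are arbitrary.

import numpy as np
import keras
from keras import layers

keras.config.set_dtype_policy("mixed_float16")  # compute dtype float16, variable dtype float32

layer = layers.Dense(16)
layer.build((None, 8))
layer.enable_lora(rank=4)

print(layer.lora_kernel_a.dtype)  # float32
print(layer.lora_kernel_b.dtype)  # float32
print(layer.kernel.dtype)         # compute dtype (float16) after the cast above; repr is backend-specific

_ = layer(np.random.random((2, 8)))  # forward pass still runs in the compute dtype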
3 changes: 3 additions & 0 deletions keras/src/layers/core/dense_test.py
@@ -290,6 +290,9 @@ def test_enable_lora(self):
         self.assertLen(layer.non_trainable_weights, 1)
         if backend.backend() == "torch":
             self.assertLen(layer.torch_params, 4)
+        self.assertDType(layer.lora_kernel_a, "float32")
+        self.assertDType(layer.lora_kernel_b, "float32")
+
         # Try eager call
         x = np.random.random((64, 8))
         y = np.random.random((64, 16))
16 changes: 13 additions & 3 deletions keras/src/layers/core/einsum_dense.py
@@ -264,8 +264,13 @@ def kernel(self):
 
         # Apply LoRA if enabled
         if self.lora_enabled:
-            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
-                self.lora_kernel_a, self.lora_kernel_b
+            kernel = ops.cast(
+                ops.add(
+                    kernel,
+                    (self.lora_alpha / self.lora_rank)
+                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                ),
+                dtype=self.compute_dtype,
             )
 
         return kernel
@@ -317,16 +322,20 @@ def enable_lora(
         else:
             kernel_shape_for_lora = self.kernel.shape
 
+        # LoRA weights should be float32 to avoid the risk of underflow or
+        # overflow during fine-tuning.
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
             shape=(kernel_shape_for_lora[:-1] + (rank,)),
             initializer=initializers.get(a_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self.lora_kernel_b = self.add_weight(
             name="lora_kernel_b",
             shape=(rank, kernel_shape_for_lora[-1]),
             initializer=initializers.get(b_initializer),
+            dtype="float32",
             regularizer=self.kernel_regularizer,
         )
         self._kernel.trainable = False
@@ -980,6 +989,7 @@ def grad_fn(*args, upstream=None):
             lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+        x = ops.cast(x, dtype=self.compute_dtype)
         if self.bias is not None:
             x = ops.add(x, self.bias)
         if self.activation is not None:
@@ -1121,7 +1131,7 @@ def grad_fn(*args, upstream=None):
             lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)
             lora_x = ops.matmul(lora_x, self.lora_kernel_b)
             x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
-
+        x = ops.cast(x, dtype=self.compute_dtype)
         # Bias & activation
         if self.bias is not None:
             x = ops.add(x, self.bias)
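The two `grad_fn` hunks above sit in the quantized call paths, where the LoRA correction is applied to the matmul/einsum output rather than to a merged kernel. A rough standalone sketch of that correction follows; the shapes are assumed, and the inputs are explicitly cast to float32 here only to keep the sketch portable across backends.

import numpy as np
from keras import ops

compute_dtype = "float16"
rank, alpha = 4, 8.0

inputs = ops.cast(np.random.random((2, 8)), compute_dtype)
x = ops.cast(np.random.random((2, 16)), compute_dtype)  # stand-in for the dequantized matmul output
lora_a = ops.cast(np.random.randn(8, rank), "float32")
lora_b = ops.cast(np.random.randn(rank, 16), "float32")

# LoRA correction computed in float32, then cast back to the compute dtype
# before bias and activation, as in the added `ops.cast` lines above.
lora_x = ops.matmul(ops.cast(inputs, "float32"), lora_a)
lora_x = ops.matmul(lora_x, lora_b)
x = ops.add(x, (alpha / rank) * lora_x)
x = ops.cast(x, dtype=compute_dtype)
print(x.dtype)  # float16 (exact repr depends on the backend)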
3 changes: 3 additions & 0 deletions keras/src/layers/core/einsum_dense_test.py
@@ -374,6 +374,9 @@ def test_enable_lora(self):
         self.assertLen(layer.non_trainable_weights, 1)
         if backend.backend() == "torch":
             self.assertLen(layer.torch_params, 3)
+        self.assertDType(layer.lora_kernel_a, "float32")
+        self.assertDType(layer.lora_kernel_b, "float32")
+
         # Try eager call
         x = np.random.random((64, 3))
         y = np.random.random((64, 8, 32))
18 changes: 16 additions & 2 deletions keras/src/layers/core/embedding.py
@@ -158,8 +158,15 @@ def embeddings(self):
                 embeddings, self._orig_output_dim, axis=-1
             )
         if self.lora_enabled:
-            return embeddings + (self.lora_alpha / self.lora_rank) * ops.matmul(
-                self.lora_embeddings_a, self.lora_embeddings_b
+            embeddings = ops.cast(
+                ops.add(
+                    embeddings,
+                    (self.lora_alpha / self.lora_rank)
+                    * ops.matmul(
+                        self.lora_embeddings_a, self.lora_embeddings_b
+                    ),
+                ),
+                dtype=self.compute_dtype,
             )
         return embeddings

@@ -206,16 +213,21 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+
+        # LoRA weights should be float32 to avoid the risk of underflow or
+        # overflow during fine-tuning.
         self.lora_embeddings_a = self.add_weight(
             name="lora_embeddings_a",
             shape=(self.input_dim, rank),
             initializer=initializers.get(a_initializer),
+            dtype="float32",
             regularizer=self.embeddings_regularizer,
         )
         self.lora_embeddings_b = self.add_weight(
             name="lora_embeddings_b",
             shape=(rank, self.output_dim),
             initializer=initializers.get(b_initializer),
+            dtype="float32",
             regularizer=self.embeddings_regularizer,
         )
         self.embeddings.trainable = False
@@ -478,6 +490,7 @@ def _int8_call(self, inputs, training=None):
             outputs = ops.add(
                 outputs, (self.lora_alpha / self.lora_rank) * lora_outputs
             )
+        outputs = ops.cast(outputs, dtype=self.compute_dtype)
         return outputs
 
     def _int4_call(self, inputs, training=None):
@@ -519,6 +532,7 @@ def _int4_call(self, inputs, training=None):
             outputs = ops.add(
                 outputs, (self.lora_alpha / self.lora_rank) * lora_outputs
             )
+        outputs = ops.cast(outputs, dtype=self.compute_dtype)
         return outputs
 
     def quantize(self, mode=None, type_check=True, config=None):
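A brief usage sketch for the Embedding change (layer sizes are illustrative, not from the PR): after `enable_lora`, the only trainable weights are the two float32 LoRA matrices, and the base embeddings variable is frozen, which is exactly what the test below asserts.

from keras import layers

layer = layers.Embedding(input_dim=100, output_dim=16)
layer.build(None)
layer.enable_lora(rank=4)

print([w.dtype for w in layer.trainable_weights])  # ['float32', 'float32']
print(len(layer.non_trainable_weights))            # 1 (the frozen embeddings)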
3 changes: 3 additions & 0 deletions keras/src/layers/core/embedding_test.py
@@ -197,6 +197,9 @@ def test_enable_lora(self):
         self.assertLen(layer.non_trainable_weights, 1)
         if backend.backend() == "torch":
             self.assertLen(layer.torch_params, 3)
+        self.assertDType(layer.lora_embeddings_a, "float32")
+        self.assertDType(layer.lora_embeddings_b, "float32")
+
         # Try eager call
         x = np.random.randint(0, 9, size=(64, 3))
         y = np.random.random((64, 3, 16))