Skip to content

Commit 12e3e99

Browse files
authored
Add include rescaling to the pali gemma backbone (#1650)
To disable this option, pass ``` keras_nlp.models.PaliGemma.from_preset( "pali_gemma_3b_224", include_rescaling=False, ) ``` Allow inputs to be the more standard range 0, 255
1 parent d817656 commit 12e3e99

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

keras_nlp/src/models/pali_gemma/pali_gemma_backbone.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ class PaliGemmaBackbone(Backbone):
7474
vit_classifier_activation: activation function. The activation that
7575
is used for final output classification in the vision transformer.
7676
vit_name: string. The name used for vision transformer layers.
77+
include_rescaling: bool. If true, the image input will be rescaled from
78+
the range `[0, 255]`, to the range `[0, 1]`.
7779
layer_norm_epsilon: float. The epsilon value user for every layer norm
7880
in all transformer blocks.
7981
dropout: float. Dropout probability for the Transformer decoder blocks.
@@ -132,6 +134,7 @@ def __init__(
132134
vit_pooling=None,
133135
vit_classifier_activation=None,
134136
vit_name=None,
137+
include_rescaling=True,
135138
layer_norm_epsilon=1e-6,
136139
dropout=0,
137140
dtype=None,
@@ -163,6 +166,7 @@ def __init__(
163166
vit_intermediate_dim = vit_intermediate_dim or 4304
164167
self.vit_encoder = PaliGemmaVit(
165168
image_size=image_size,
169+
include_rescaling=include_rescaling,
166170
patch_size=vit_patch_size,
167171
num_heads=vit_num_heads,
168172
hidden_dim=vit_hidden_dim,
@@ -232,6 +236,7 @@ def __init__(
232236
# === Config ===
233237
self.vocabulary_size = vocabulary_size
234238
self.image_size = image_size
239+
self.include_rescaling = include_rescaling
235240
self.num_layers = num_layers
236241
self.num_query_heads = num_query_heads
237242
self.num_key_value_heads = num_key_value_heads
@@ -258,6 +263,7 @@ def get_config(self):
258263
{
259264
"vocabulary_size": self.vocabulary_size,
260265
"image_size": self.image_size,
266+
"include_rescaling": self.include_rescaling,
261267
"num_layers": self.num_layers,
262268
"num_query_heads": self.num_query_heads,
263269
"num_key_value_heads": self.num_key_value_heads,

keras_nlp/src/models/pali_gemma/pali_gemma_vit.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@ class PaliGemmaVit(keras.Model):
423423
Args:
424424
image_size: int. The height/width of the image. Both height and width is
425425
expected to be the same.
426+
include_rescaling: bool. If true, the image input will be rescaled from
427+
the range `[0, 255]`, to the range `[0, 1]`.
426428
patch_size: int. The size of each square patch in the input image.
427429
num_heads: int. The number of attention heads for the vision(image)
428430
transformer encoder.
@@ -463,6 +465,7 @@ def __init__(
463465
num_layers,
464466
intermediate_dim,
465467
num_classes,
468+
include_rescaling=True,
466469
pooling=None,
467470
classifier_activation=None,
468471
dtype=None,
@@ -472,7 +475,13 @@ def __init__(
472475
image_input = keras.Input(
473476
shape=(image_size, image_size, 3), name="images"
474477
)
475-
encoded = PaliGemmaVitEncoder(
478+
x = image_input # Intermediate result.
479+
if include_rescaling:
480+
rescaling = keras.layers.Rescaling(
481+
scale=1.0 / 127.5, offset=-1.0, name="rescaling"
482+
)
483+
x = rescaling(image_input)
484+
x = PaliGemmaVitEncoder(
476485
hidden_dim=hidden_dim,
477486
num_layers=num_layers,
478487
num_heads=num_heads,
@@ -481,20 +490,20 @@ def __init__(
481490
image_size=image_size,
482491
dtype=dtype,
483492
name="image_encoder",
484-
)(image_input)
493+
)(x)
485494
if pooling == "map":
486-
pooled = MultiHeadAttentionPooling(
495+
x = MultiHeadAttentionPooling(
487496
num_heads=num_heads,
488497
hidden_dim=hidden_dim,
489498
dtype=dtype,
490499
name="pooling",
491-
)(encoded)
500+
)(x)
492501
elif pooling == "gap":
493-
pooled = ops.mean(encoded, axis=1)
502+
x = ops.mean(x, axis=1)
494503
elif pooling == "zero":
495-
pooled = encoded[:, 0]
504+
x = x[:, 0]
496505
elif pooling is None:
497-
pooled = encoded
506+
x = x
498507
else:
499508
raise ValueError(
500509
"Invalid value for argument `pooling`. "
@@ -506,7 +515,7 @@ def __init__(
506515
activation=classifier_activation,
507516
dtype=dtype,
508517
name="image_classifier",
509-
)(pooled)
518+
)(x)
510519
super().__init__(
511520
inputs=image_input,
512521
outputs=outputs,
@@ -521,6 +530,7 @@ def __init__(
521530
self.pooling = pooling
522531
self.num_classes = num_classes
523532
self.image_size = image_size
533+
self.include_rescaling = include_rescaling
524534
self.patch_size = patch_size
525535
self.classifier_activation = keras.activations.get(
526536
classifier_activation
@@ -541,6 +551,7 @@ def get_config(self):
541551
self.classifier_activation
542552
),
543553
"image_size": self.image_size,
554+
"include_rescaling": self.include_rescaling,
544555
"patch_size": self.patch_size,
545556
}
546557
)

keras_nlp/src/models/pali_gemma/pali_gemma_vit_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,23 @@ def test_vit_encoder(self):
4545
output.shape, (batch_size, intermediate_dim, hidden_dim)
4646
)
4747

48+
def test_vit_rescaling(self):
49+
vit_encoder = PaliGemmaVit(
50+
image_size=16,
51+
patch_size=4,
52+
hidden_dim=8,
53+
num_layers=2,
54+
num_heads=2,
55+
intermediate_dim=16,
56+
num_classes=32,
57+
)
58+
self.assertIsNotNone(vit_encoder.get_layer("rescaling"))
59+
with self.assertRaises(ValueError):
60+
config = vit_encoder.get_config()
61+
config["include_rescaling"] = False
62+
vit_encoder = PaliGemmaVit.from_config(config)
63+
vit_encoder.get_layer("rescaling")
64+
4865
def test_vision_embeddings(self):
4966
embeddings_layer = PaliGemmaVitEmbeddings(
5067
image_size=16,

0 commit comments

Comments
 (0)