diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
index 76ac0631b4..f8c1bd7194 100644
--- a/.kokoro/github/ubuntu/gpu/build.sh
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -69,6 +69,7 @@ then
       keras_cv/models/object_detection/retinanet \
       keras_cv/models/object_detection/yolo_v8 \
       keras_cv/models/object_detection_3d \
+      keras_cv/models/video_classification \
       keras_cv/models/segmentation \
       keras_cv/models/stable_diffusion
 else
@@ -82,6 +83,7 @@ else
       keras_cv/models/classification \
       keras_cv/models/object_detection/retinanet \
       keras_cv/models/object_detection/yolo_v8 \
+      keras_cv/models/video_classification \
       keras_cv/models/object_detection_3d \
       keras_cv/models/segmentation \
       keras_cv/models/stable_diffusion
diff --git a/keras_cv/models/video_classification/__init__.py b/keras_cv/models/video_classification/__init__.py
new file mode 100644
index 0000000000..320da488c1
--- /dev/null
+++ b/keras_cv/models/video_classification/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.models.video_classification.vivit import ViViT
diff --git a/keras_cv/models/video_classification/vivit.py b/keras_cv/models/video_classification/vivit.py
new file mode 100644
index 0000000000..3858f423aa
--- /dev/null
+++ b/keras_cv/models/video_classification/vivit.py
@@ -0,0 +1,197 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.models.task import Task
+from keras_cv.models.video_classification.vivit_layers import PositionalEncoder
+from keras_cv.models.video_classification.vivit_layers import TubeletEmbedding
+
+
+@keras_cv_export(
+    [
+        "keras_cv.models.ViViT",
+        "keras_cv.models.video_classification.ViViT",
+    ]
+)
+class ViViT(Task):
+    """A Keras model implementing a Video Vision Transformer
+    for video classification.
+
+    References:
+        - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
+          (ICCV 2021)
+
+    Args:
+        inp_shape: tuple, the shape of the input video frames.
+        num_classes: int, the number of classes for video classification.
+        transformer_layers: int, the number of transformer layers in the
+            model. Defaults to 8.
+        patch_size: tuple, the size of the spatio-temporal patches for each
+            dimension. Defaults to (8, 8, 8).
+        num_heads: int, the number of heads for the multi-head
+            self-attention mechanism. Defaults to 8.
+        projection_dim: int, number of dimensions in the projection space.
+            Defaults to 128.
+        layer_norm_eps: float, epsilon value for layer normalization.
+            Defaults to 1e-6.
+
+    Examples:
+    ```python
+    import numpy as np
+
+    import keras
+    import keras_cv
+
+    INPUT_SHAPE = (32, 32, 32, 1)
+    NUM_CLASSES = 11
+    PATCH_SIZE = (8, 8, 8)
+    LAYER_NORM_EPS = 1e-6
+    PROJECTION_DIM = 128
+    NUM_HEADS = 8
+    NUM_LAYERS = 8
+
+    frames = np.random.uniform(size=(5, 32, 32, 32, 1))
+    labels = np.ones(shape=(5))
+
+    # Instantiate model
+    model = keras_cv.models.ViViT(
+        projection_dim=PROJECTION_DIM,
+        patch_size=PATCH_SIZE,
+        inp_shape=INPUT_SHAPE,
+        transformer_layers=NUM_LAYERS,
+        num_heads=NUM_HEADS,
+        layer_norm_eps=LAYER_NORM_EPS,
+        num_classes=NUM_CLASSES,
+    )
+
+    # Compile model
+    model.compile(
+        optimizer="adam",
+        loss="sparse_categorical_crossentropy",
+        metrics=[
+            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+        ],
+    )
+
+    # Build model
+    model.build(INPUT_SHAPE)
+
+    # Train model
+    model.fit(frames, labels, epochs=3)
+    ```
+    """
+
+    def __init__(
+        self,
+        inp_shape,
+        num_classes,
+        projection_dim=128,
+        patch_size=(8, 8, 8),
+        transformer_layers=8,
+        num_heads=8,
+        layer_norm_eps=1e-6,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.projection_dim = projection_dim
+        self.patch_size = patch_size
+        self.tubelet_embedder = TubeletEmbedding(
+            embed_dim=self.projection_dim, patch_size=self.patch_size
+        )
+        self.positional_encoder = PositionalEncoder(
+            embed_dim=self.projection_dim
+        )
+        self.layer_norm = keras.layers.LayerNormalization(
+            epsilon=layer_norm_eps
+        )
+        self.attention_output = keras.layers.MultiHeadAttention(
+            num_heads=num_heads,
+            key_dim=projection_dim // num_heads,
+            dropout=0.1,
+        )
+        self.dense_1 = keras.layers.Dense(
+            units=projection_dim * 4, activation=keras.ops.gelu
+        )
+        self.dense_2 = keras.layers.Dense(
+            units=projection_dim, activation=keras.ops.gelu
+        )
+        self.add = keras.layers.Add()
+        self.pooling = keras.layers.GlobalAvgPool1D()
+        self.dense_output = keras.layers.Dense(
+            units=num_classes, activation="softmax"
+        )
+
+        self.inp_shape = inp_shape
+        self.num_heads = num_heads
+        self.num_classes = num_classes
+        self.transformer_layers = transformer_layers
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.tubelet_embedder.build(input_shape)
+        flattened_patch_shape = self.tubelet_embedder.compute_output_shape(
+            input_shape
+        )
+        self.positional_encoder.build(flattened_patch_shape)
+        self.layer_norm.build([None, None, self.projection_dim])
+        self.attention_output.build(
+            query_shape=[None, None, self.projection_dim],
+            value_shape=[None, None, self.projection_dim],
+        )
+        self.add.build(
+            [
+                (None, None, self.projection_dim),
+                (None, None, self.projection_dim),
+            ]
+        )
+        self.dense_1.build([None, None, self.projection_dim])
+        self.dense_2.build([None, None, self.projection_dim * 4])
+        self.pooling.build([None, None, self.projection_dim])
+        self.dense_output.build([None, self.projection_dim])
+
+    def call(self, x):
+        patches = self.tubelet_embedder(x)
+        encoded_patches = self.positional_encoder(patches)
+        for _ in range(self.transformer_layers):
+            x1 = self.layer_norm(encoded_patches)
+            attention_output = self.attention_output(x1, x1)
+            x2 = self.add([attention_output, encoded_patches])
+            x3 = self.layer_norm(x2)
+            x4 = self.dense_1(x3)
+            x5 = self.dense_2(x4)
+            encoded_patches = self.add([x5, x2])
+        representation = self.layer_norm(encoded_patches)
+        pooled_representation = self.pooling(representation)
+        outputs = self.dense_output(pooled_representation)
+        return outputs
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "inp_shape": self.inp_shape,
+                "num_classes": self.num_classes,
+                "projection_dim": self.projection_dim,
+                "patch_size": self.patch_size,
+                "transformer_layers": self.transformer_layers,
+                "num_heads": self.num_heads,
+            }
+        )
+        return config
diff --git a/keras_cv/models/video_classification/vivit_layers.py b/keras_cv/models/video_classification/vivit_layers.py
new file mode 100644
index 0000000000..53c2a0fc79
--- /dev/null
+++ b/keras_cv/models/video_classification/vivit_layers.py
@@ -0,0 +1,129 @@
+# Copyright 2024 The KerasCV Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from keras_cv.api_export import keras_cv_export
+from keras_cv.backend import keras
+from keras_cv.backend import ops
+
+
+@keras_cv_export(
+    "keras_cv.layers.TubeletEmbedding",
+    package="keras_cv.layers",
+)
+class TubeletEmbedding(keras.layers.Layer):
+    """A Keras layer for computing spatio-temporal tubelet embeddings from
+    input video frames.
+
+    References:
+        - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
+          (ICCV 2021)
+
+    Args:
+        embed_dim: int, number of dimensions in the embedding space.
+            Defaults to 128.
+        patch_size: tuple, the size of the spatio-temporal patch. Specifies
+            the size for each dimension. Defaults to (8, 8, 8).
+    """
+
+    def __init__(self, embed_dim=128, patch_size=(8, 8, 8), **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.patch_size = patch_size
+        self.projection = keras.layers.Conv3D(
+            filters=self.embed_dim,
+            kernel_size=self.patch_size,
+            strides=self.patch_size,
+            data_format="channels_last",
+            padding="valid",
+        )
+        self.flatten = keras.layers.Reshape(target_shape=(-1, self.embed_dim))
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        self.projection.build(
+            (
+                None,
+                input_shape[0],
+                input_shape[1],
+                input_shape[2],
+                input_shape[3],
+            )
+        )
+        projected_patch_shape = self.projection.compute_output_shape(
+            (
+                None,
+                input_shape[0],
+                input_shape[1],
+                input_shape[2],
+                input_shape[3],
+            )
+        )
+        self.flatten.build(projected_patch_shape)
+
+    def compute_output_shape(self, input_shape):
+        projected_patch_shape = self.projection.compute_output_shape(
+            (
+                None,
+                input_shape[0],
+                input_shape[1],
+                input_shape[2],
+                input_shape[3],
+            )
+        )
+        return self.flatten.compute_output_shape(projected_patch_shape)
+
+    def call(self, videos):
+        projected_patches = self.projection(videos)
+        flattened_patches = self.flatten(projected_patches)
+        return flattened_patches
+
+
+@keras_cv_export(
+    "keras_cv.layers.PositionalEncoder",
+    package="keras_cv.layers",
+)
+class PositionalEncoder(keras.layers.Layer):
+    """A Keras layer that adds positional information to the encoded video
+    tokens.
+
+    References:
+        - [ViViT: A Video Vision Transformer](https://arxiv.org/abs/2103.15691)
+          (ICCV 2021)
+
+    Args:
+        embed_dim: int, number of dimensions in the embedding space.
+            Defaults to 128.
+    """
+
+    def __init__(self, embed_dim=128, **kwargs):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+
+    def build(self, input_shape):
+        super().build(input_shape)
+        _, num_tokens, _ = input_shape
+        self.position_embedding = keras.layers.Embedding(
+            input_dim=num_tokens, output_dim=self.embed_dim
+        )
+        self.position_embedding.build(input_shape)
+        self.positions = ops.arange(start=0, stop=num_tokens, step=1)
+
+    def call(self, encoded_tokens):
+        encoded_positions = self.position_embedding(self.positions)
+        encoded_tokens = encoded_tokens + encoded_positions
+        return encoded_tokens
diff --git a/keras_cv/models/video_classification/vivit_test.py b/keras_cv/models/video_classification/vivit_test.py
new file mode 100644
index 0000000000..ed561debc3
--- /dev/null
+++ b/keras_cv/models/video_classification/vivit_test.py
@@ -0,0 +1,158 @@
+# Copyright 2024 The KerasCV Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+import pytest
+import tensorflow as tf
+
+from keras_cv.backend import keras
+from keras_cv.backend import ops
+from keras_cv.backend.config import keras_3
+from keras_cv.models.video_classification.vivit import ViViT
+from keras_cv.tests.test_case import TestCase
+
+
+class ViViT_Test(TestCase):
+    def test_vivit_construction(self):
+        input_shape = (28, 28, 28, 1)
+        num_classes = 11
+        patch_size = (8, 8, 8)
+        layer_norm_eps = 1e-6
+        projection_dim = 128
+        num_heads = 8
+        num_layers = 8
+
+        model = ViViT(
+            projection_dim=projection_dim,
+            patch_size=patch_size,
+            inp_shape=input_shape,
+            transformer_layers=num_layers,
+            num_heads=num_heads,
+            layer_norm_eps=layer_norm_eps,
+            num_classes=num_classes,
+        )
+        model.compile(
+            optimizer="adam",
+            loss="sparse_categorical_crossentropy",
+            metrics=[
+                keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+                keras.metrics.SparseTopKCategoricalAccuracy(
+                    5, name="top-5-accuracy"
+                ),
+            ],
+        )
+
+    def test_vivit_call(self):
+        input_shape = (28, 28, 28, 1)
+        num_classes = 11
+        patch_size = (8, 8, 8)
+        layer_norm_eps = 1e-6
+        projection_dim = 128
+        num_heads = 8
+        num_layers = 8
+
+        model = ViViT(
+            projection_dim=projection_dim,
+            patch_size=patch_size,
+            inp_shape=input_shape,
+            transformer_layers=num_layers,
+            num_heads=num_heads,
+            layer_norm_eps=layer_norm_eps,
+            num_classes=num_classes,
+        )
+        model.build(input_shape)
+        frames = np.random.uniform(size=(5, 28, 28, 28, 1))
+        _ = model(frames)
+
+    def test_weights_change(self):
+        input_shape = (28, 28, 28, 1)
+        num_classes = 11
+        patch_size = (8, 8, 8)
+        layer_norm_eps = 1e-6
+        projection_dim = 128
+        num_heads = 8
+        num_layers = 8
+
+        frames = np.random.uniform(size=(5, 28, 28, 28, 1))
+        labels = np.ones(shape=(5))
+        ds = tf.data.Dataset.from_tensor_slices((frames, labels))
+        ds = ds.repeat(2)
+        ds = ds.batch(2)
+
+        model = ViViT(
+            projection_dim=projection_dim,
+            patch_size=patch_size,
+            inp_shape=input_shape,
+            transformer_layers=num_layers,
+            num_heads=num_heads,
+            layer_norm_eps=layer_norm_eps,
+            num_classes=num_classes,
+        )
+
+        model.compile(
+            optimizer="adam",
+            loss="sparse_categorical_crossentropy",
+            metrics=[
+                keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
+                keras.metrics.SparseTopKCategoricalAccuracy(
+                    5, name="top-5-accuracy"
+                ),
+            ],
+        )
+        model.build(input_shape)
+        representation_layer = model.get_layer(index=-8)  # Accesses MHSA layer
+        original_weights = representation_layer.get_weights()
+        model.fit(ds, epochs=1)
+        updated_weights = representation_layer.get_weights()
+
+        for w1, w2 in zip(original_weights, updated_weights):
+            self.assertNotAllEqual(w1, w2)
+            self.assertFalse(ops.any(ops.isnan(w2)))
+
+    @pytest.mark.large  # Saving is slow, so mark these large.
+    def test_saved_model(self):
+        input_shape = (28, 28, 28, 1)
+        num_classes = 11
+        patch_size = (8, 8, 8)
+        layer_norm_eps = 1e-6
+        projection_dim = 128
+        num_heads = 8
+        num_layers = 8
+
+        model = ViViT(
+            projection_dim=projection_dim,
+            patch_size=patch_size,
+            inp_shape=input_shape,
+            transformer_layers=num_layers,
+            num_heads=num_heads,
+            layer_norm_eps=layer_norm_eps,
+            num_classes=num_classes,
+        )
+        model.build(input_shape)
+        input_batch = np.random.uniform(size=(5, 28, 28, 28, 1))
+        model_output = model(input_batch)
+
+        save_path = os.path.join(self.get_temp_dir(), "model.keras")
+        if keras_3():
+            model.save(save_path)
+        else:
+            model.save(save_path, save_format="keras_v3")
+        restored_model = keras.models.load_model(save_path)
+
+        self.assertIsInstance(restored_model, ViViT)
+
+        restored_output = restored_model(input_batch)
+        self.assertAllClose(model_output, restored_output)
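
As a quick sanity check of the two new layers in isolation (a minimal sketch, not part of the patch itself): `TubeletEmbedding` cuts a video into non-overlapping tubelets with a strided `Conv3D` and flattens them into a token sequence, and `PositionalEncoder` adds a learned embedding per token index. Note that `TubeletEmbedding.build()` above expects the per-sample shape without the batch dimension, matching how `ViViT.build()` calls it, so the sketch builds the layer explicitly before calling it; the shapes follow the `(32, 32, 32, 1)` example from the `ViViT` docstring.

```python
# Sketch only; assumes this patch is applied as-is.
import numpy as np

from keras_cv.models.video_classification.vivit_layers import (
    PositionalEncoder,
    TubeletEmbedding,
)

videos = np.random.uniform(size=(2, 32, 32, 32, 1)).astype("float32")

# Build with the per-sample shape (no batch dim), as ViViT.build() does.
tubelet_embedder = TubeletEmbedding(embed_dim=128, patch_size=(8, 8, 8))
tubelet_embedder.build((32, 32, 32, 1))

# (2, 32, 32, 32, 1) -> Conv3D with kernel == strides == (8, 8, 8)
# -> (2, 4, 4, 4, 128) -> flattened to a (2, 64, 128) token sequence.
tokens = tubelet_embedder(videos)

# Adds a learned position embedding per token index (built from the
# token count seen at build time).
positional_encoder = PositionalEncoder(embed_dim=128)
encoded = positional_encoder(tokens)

print(tokens.shape, encoded.shape)  # (2, 64, 128) (2, 64, 128)
```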