from typing import Optional, Type

import tensorflow as tf
from tensorflow.keras import layers


class SelfAttention(layers.Layer):
    """A self-attention layer for convolutional neural networks.

    This layer takes as input a tensor of shape (batch_size, height, width, channels)
    and applies self-attention across the spatial positions of the feature map.

    Args:
        num_heads (int): The number of attention heads to use. Defaults to 8.
        wrapper (Type[tf.keras.layers.Wrapper]): An optional wrapper layer class
            applied to the internal convolutional layers.

    Raises:
        TypeError: If `wrapper` is provided and is not a subclass of `tf.keras.layers.Wrapper`.
    """
    def __init__(self, num_heads: int = 8, wrapper: Optional[Type[tf.keras.layers.Wrapper]] = None, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.wrapper = wrapper

        # `wrapper` is expected to be a class (not an instance); it is instantiated
        # around each 1x1 convolution in `_conv`.
        if wrapper is not None and not (isinstance(wrapper, type) and issubclass(wrapper, tf.keras.layers.Wrapper)):
            raise TypeError("wrapper must be a class derived from tf.keras.layers.Wrapper")
24+
    def get_config(self) -> dict:
        # Note: `wrapper` is a class object and is not serialized here.
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
        })
        return config
31+
    def build(self, input_shape):
        _, h, w, c = input_shape
        # 1x1 convolutions project the input into query/key/value feature maps.
        # Queries and keys use a reduced channel dimension (c // num_heads).
        self.query_conv = self._conv(filters=c // self.num_heads)
        self.key_conv = self._conv(filters=c // self.num_heads)
        self.value_conv = self._conv(filters=c)
        # Learnable scalar that blends the attention output with the input (initialized to 0).
        self.gamma = self.add_weight(name="gamma", shape=[1], initializer=tf.zeros_initializer(), trainable=True)
38+
    def _conv(self, filters: int) -> tf.keras.layers.Layer:
        """Helper function to create a 1x1 convolutional layer with the given number of filters.

        Args:
            filters (int): The number of filters to use.

        Returns:
            tf.keras.layers.Layer: The created convolutional layer, optionally wrapped.
        """
        conv = layers.Conv2D(filters=filters, kernel_size=1, strides=1, padding='same')
        if self.wrapper:
            conv = self.wrapper(conv)

        return conv
53+
    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """Apply the self-attention mechanism to the input tensor.

        Args:
            inputs (tf.Tensor): The input tensor of shape (batch_size, height, width, channels).

        Returns:
            tf.Tensor: The output tensor after the self-attention mechanism is applied.
        """
        _, h, w, c = inputs.shape
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)

        # Flatten the spatial dimensions so each of the h * w positions becomes a token.
        q_reshaped = tf.reshape(q, [-1, h * w, c // self.num_heads])
        k_reshaped = tf.reshape(k, [-1, h * w, c // self.num_heads])
        v_reshaped = tf.reshape(v, [-1, h * w, c])

        # Compute the attention scores by taking the dot product of the query and key tensors.
        attention_scores = tf.matmul(q_reshaped, k_reshaped, transpose_b=True)

        # Scale the attention scores by the square root of the query/key channel dimension.
        attention_scores = attention_scores / tf.sqrt(tf.cast(c // self.num_heads, dtype=tf.float32))

        # Apply a softmax over the key positions to obtain the attention weights.
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)

        # Apply the attention weights to the value tensor to obtain the attention output.
        attention_output = tf.matmul(attention_weights, v_reshaped)

        # Reshape the attended value tensor back to the original input tensor shape.
        attention_output = tf.reshape(attention_output, [-1, h, w, c])

        # Scale the attention output by the learnable gamma and add the residual input.
        attention_output = self.gamma * attention_output + inputs

        return attention_output
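

# --- Minimal usage sketch (illustrative, not part of the original layer code) ---
# Assumes only standard tf.keras APIs. It builds a tiny convolutional model with a
# SelfAttention block on a 32x32x64 feature map; the `wrapper` argument could instead
# be set to a tf.keras.layers.Wrapper subclass (e.g., a spectral-normalization wrapper
# in recent TF/Keras versions) to wrap the internal 1x1 convolutions.
if __name__ == "__main__":
    inputs = tf.keras.Input(shape=(32, 32, 64))
    x = layers.Conv2D(64, kernel_size=3, padding="same", activation="relu")(inputs)
    x = SelfAttention(num_heads=8)(x)  # 64 channels divide evenly across 8 heads
    outputs = layers.Conv2D(3, kernel_size=1, padding="same")(x)
    model = tf.keras.Model(inputs, outputs)
    model.summary()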