Commit fdc5c9d

Introduced selfAttention in mltu.tensorflow.layers
1 parent d8a212d commit fdc5c9d

File tree

3 files changed: +96 -3 lines changed
  CHANGELOG.md
  mltu/__init__.py
  mltu/tensorflow/layers.py


CHANGELOG.md

Lines changed: 5 additions & 2 deletions
@@ -1,3 +1,8 @@
+## [1.0.7] - 2022-04-14
+### Added
+- Added `SelfAttention` layer into `mltu.tensorflow.layers` to use with Conv2D layers (need more testings).
+
+
 ## [1.0.6] - 2022-04-13
 ### Changed
 - Fixed bug in `mltu.dataProvider.DataProvider` object to work without `data_preprocessors` when feeding loaded data in memory
@@ -11,8 +16,6 @@
 - Fix `ImageReader` to work either with image path or `np.ndarray`
 - Added `metadata` support to `callbacks/tf2onnx` when converting to onnx format
 
-### Added
--
 
 ## [1.0.3] - 2022-03-20
 ### Changed

mltu/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
-__version__ = "1.0.6"
+__version__ = "1.0.7"
 
 from .annotations.image import Image

mltu/tensorflow/layers.py

Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import tensorflow as tf
from keras import layers


class SelfAttention(layers.Layer):
    """ A self-attention layer for convolutional neural networks.

    This layer takes as input a tensor of shape (batch_size, height, width, channels)
    and applies self-attention to the channels dimension.

    Args:
        num_heads (int): The number of attention heads to use. Defaults to 8.
        wrapper (tf.keras.layers.Wrapper): A wrapper layer to apply to the convolutional layers.

    Raises:
        TypeError: If `wrapper` is provided and is not a subclass of `tf.keras.layers.Wrapper`.
    """
    def __init__(self, num_heads: int = 8, wrapper: tf.keras.layers.Wrapper = None):
        super(SelfAttention, self).__init__()
        self.num_heads = num_heads
        self.wrapper = wrapper

        if wrapper and not issubclass(wrapper, tf.keras.layers.Wrapper):
            raise TypeError("wrapper must be a class derived from tf.keras.layers.Wrapper")

    def get_config(self) -> dict:
        config = super().get_config()
        config.update({
            "num_heads": self.num_heads,
        })
        return config

    def build(self, input_shape):
        _, h, w, c = input_shape
        self.query_conv = self._conv(filters=c // self.num_heads)
        self.key_conv = self._conv(filters=c // self.num_heads)
        self.value_conv = self._conv(filters=c)
        self.gamma = self.add_weight("gamma", shape=[1], initializer=tf.zeros_initializer(), trainable=True)

    def _conv(self, filters: int) -> tf.keras.layers.Layer:
        """ Helper function to create a convolutional layer with the given number of filters.

        Args:
            filters (int): The number of filters to use.

        Returns:
            tf.keras.layers.Layer: The created convolutional layer.
        """
        conv = layers.Conv2D(filters=filters, kernel_size=1, strides=1, padding='same')
        if self.wrapper:
            conv = self.wrapper(conv)

        return conv

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """ Apply the self-attention mechanism to the input tensor.

        Args:
            inputs (tf.Tensor): The input tensor of shape (batch_size, height, width, channels).

        Returns:
            tf.Tensor: The output tensor after the self-attention mechanism is applied.
        """
        _, h, w, c = inputs.shape
        q = self.query_conv(inputs)
        k = self.key_conv(inputs)
        v = self.value_conv(inputs)

        q_reshaped = tf.reshape(q, [-1, h * w, c // self.num_heads])
        k_reshaped = tf.reshape(k, [-1, h * w, c // self.num_heads])
        v_reshaped = tf.reshape(v, [-1, h * w, c])

        # Compute the attention scores by taking the dot product of the query and key tensors.
        attention_scores = tf.matmul(q_reshaped, k_reshaped, transpose_b=True)

        # Scale the attention scores by the square root of the number of channels.
        attention_scores = attention_scores / tf.sqrt(tf.cast(c // self.num_heads, dtype=tf.float32))

        # Apply a softmax function to the attention scores to obtain the attention weights.
        attention_weights = tf.nn.softmax(attention_scores, axis=-1)

        # Apply the attention weights to the value tensor to obtain the attention output.
        attention_output = tf.matmul(attention_weights, v_reshaped)

        # Reshape the attended value tensor to the original input tensor shape.
        attention_output = tf.reshape(attention_output, [-1, h, w, c])

        # Apply the gamma parameter to the attended value tensor and add it to the output tensor.
        attention_output = self.gamma * attention_output + inputs

        return attention_output
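For context, a minimal usage sketch (not part of the commit): it assumes mltu 1.0.7 is installed so the import below resolves, and the input size, filter count, and head count are arbitrary illustrative choices. It drops the new layer between two Conv2D blocks, as the changelog entry suggests, and checks that the layer preserves the feature-map shape.

import tensorflow as tf
from mltu.tensorflow.layers import SelfAttention

# Hypothetical toy model: a convolutional block with SelfAttention in the middle.
inputs = tf.keras.layers.Input(shape=(32, 128, 3))
x = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")(inputs)
x = SelfAttention(num_heads=8)(x)   # 64 channels, 8 heads -> 8-filter query/key projections
x = tf.keras.layers.Conv2D(64, 3, padding="same", activation="relu")(x)
model = tf.keras.Model(inputs, x)

# The attention output is added back to the input (scaled by the learned gamma),
# so the layer is shape-preserving.
print(model.output_shape)  # (None, 32, 128, 64)

The optional `wrapper` argument expects a `tf.keras.layers.Wrapper` subclass (the class itself, not an instance), which the layer applies to each of its internal 1x1 Conv2D projections.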

0 commit comments
