Skip to content

Commit cb58448

Browse files
author
Kye
committed
[FEAT]-[Module]: [return_loss_text]: Add [return_loss_text] function for enhanced loss computation readability
[FEAT]-[Module]: [calc_z_loss]: Introduce [calc_z_loss] function to calculate Z loss in model training [FEAT]-[Module]: [max_neg_value]: Implement [max_neg_value] function for negative value handling in computations [FEAT]-[Module]: [TextTokenEmbedding]: Deploy [TextTokenEmbedding] for improved text token embedding functionality [FEAT]-[Module]: [dropout_seq]: Add [dropout_seq] function for sequence dropout in neural network layers [FEAT]-[Module]: [transformer_generate]: Introduce [transformer_generate] function for efficient transformer text generation [FEAT]-[Module]: [vit_output_head]: Add [vit_output_head] for Vision Transformer model output handling [FEAT]-[Module]: [patch_linear_flatten]: Implement [patch_linear_flatten] for streamlined linear patch flattening in ViT [FEAT]-[Module]: [ScalableImgSelfAttention]: Introduce [ScalableImgSelfAttention] for scalable image self-attention mechanism ]
1 parent b9b67a7 commit cb58448

File tree

10 files changed

+530
-15
lines changed

10 files changed

+530
-15
lines changed

playground/models/spectra.py

Whitespace-only changes.

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zetascale"
3-
version = "2.3.3"
3+
version = "2.3.5"
44
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
55
authors = ["Zeta Team <kye@apac.ai>"]
66
license = "MIT"
@@ -35,6 +35,7 @@ tqdm = "4.66.2"
3535
rich = "13.7.1"
3636
colt5-attention = "*"
3737
argparse = "^1.4.0"
38+
local-attention = "*"
3839

3940
[build-system]
4041
requires = ["poetry-core>=1.0.0"]

zeta/nn/attention/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
from zeta.nn.attention.spatial_linear_attention import SpatialLinearAttention
2323
from zeta.structs.transformer import Attention, AttentionLayers
2424
from zeta.nn.attention.multi_grouped_attn import MultiGroupedQueryAttn
25-
26-
# from zeta.nn.attention.flash_attention2 import FlashAttentionTwo
27-
# from zeta.nn.attention.mgqa import MGQA
28-
25+
from zeta.nn.attention.scalable_img_self_attn import ScalableImgSelfAttention
2926

3027
__all__ = [
3128
"Attend",
@@ -48,4 +45,5 @@
4845
"Attention",
4946
"AttentionLayers",
5047
"MultiGroupedQueryAttn",
48+
"ScalableImgSelfAttention",
5149
]
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
from torch import nn, Tensor
3+
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm
4+
from einops import rearrange
5+
6+
7+
class ScalableImgSelfAttention(nn.Module):
    """
    Multi-head self-attention over image feature maps with optional spatial
    downsampling of the keys/values for scalability.

    Queries are computed at full resolution via a 1x1 convolution; keys and
    values are produced by strided convolutions, shrinking the attention
    matrix by ``reduction_factor**2`` on the key/value side.

    Args:
        dim (int): Number of input channels.
        heads (int, optional): Number of attention heads. Defaults to 8.
        dim_key (int, optional): Per-head query/key dimension. Defaults to 32.
        dim_value (int, optional): Per-head value dimension. Defaults to 32.
        dropout (float, optional): Dropout rate applied to the attention
            weights and the output projection. Defaults to 0.0.
        reduction_factor (int, optional): Spatial stride used when projecting
            keys and values. Defaults to 1 (no reduction).

    Attributes:
        dim (int): The input dimension of the image.
        heads (int): The number of attention heads.
        dim_key (int): The dimension of the key vectors.
        dim_value (int): The dimension of the value vectors.
        reduction_factor (int): The reduction factor for downscaling.
        scale (float): The scaling factor applied to the attention logits.
        attend (nn.Softmax): The softmax over the key axis.
        dropout (nn.Dropout): Dropout applied to the attention weights.
        norm (ChanLayerNorm): Channel-wise layer normalization.
        to_q (nn.Conv2d): Query projection.
        to_k (nn.Conv2d): Key projection (strided).
        to_v (nn.Conv2d): Value projection (strided).
        to_out (nn.Sequential): Output projection.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_key: int = 32,
        dim_value: int = 32,
        dropout: float = 0.0,
        reduction_factor: int = 1,
        *args,
        **kwargs,
    ):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.dim_key = dim_key
        self.dim_value = dim_value
        self.reduction_factor = reduction_factor

        self.scale = dim_key**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.norm = ChanLayerNorm(dim)

        # Projections: queries stay full-resolution; keys/values use a
        # strided conv so the attention matrix shrinks by reduction_factor**2.
        self.to_q = nn.Conv2d(dim, dim_key * heads, 1, bias=False)
        self.to_k = nn.Conv2d(
            dim,
            dim_key * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )
        self.to_v = nn.Conv2d(
            dim,
            dim_value * heads,
            reduction_factor,
            stride=reduction_factor,
            bias=False,
        )

        self.to_out = nn.Sequential(
            nn.Conv2d(dim_value * heads, dim, 1), nn.Dropout(dropout)
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the ScalableImgSelfAttention module.

        Args:
            x (Tensor): Input of shape (batch, channels, height, width).

        Returns:
            Tensor: Output of shape (batch, channels, height, width).
        """
        # FIX: the original `h, w, h = *x.shape[-2:], self.heads` clobbered
        # the height with the head count, corrupting the final merge reshape.
        b, _, height, width = x.shape
        heads = self.heads

        x = self.norm(x)

        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        # Split heads: (b, heads*d, x, y) -> (b, heads, x*y, d).
        # FIX: the original mapped over a 2-tuple (q, k) while unpacking
        # three names — a ValueError — and never reshaped v at all.
        def split_heads(t: Tensor) -> Tensor:
            bt, _, ht, wt = t.shape
            return t.view(bt, heads, -1, ht * wt).transpose(-1, -2)

        q, k, v = map(split_heads, (q, k, v))

        # Similarity (scaled dot product); k/v may have fewer tokens than q
        # when reduction_factor > 1.
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        # Attention
        attn = self.attend(dots)
        attn = self.dropout(attn)

        # Aggregate values
        out = torch.matmul(attn, v)

        # Merge heads back to a feature map at the full (query) resolution:
        # (b, heads, x*y, d) -> (b, heads*d, x, y).
        out = out.transpose(-1, -2).reshape(b, -1, height, width)
        return self.to_out(out)
124+
125+
126+
# x = torch.randn(1, 3, 64, 64)
127+
# peg = ScalableImgSelfAttention(3)
128+
# out = peg(x)
129+
# print(out.shape)

zeta/nn/modules/__init__.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,20 @@
195195
NormalSparseMoE,
196196
HeirarchicalSparseMoE,
197197
)
198+
from zeta.nn.modules.return_loss_text import (
199+
return_loss_text,
200+
calc_z_loss,
201+
max_neg_value,
202+
TextTokenEmbedding,
203+
dropout_seq,
204+
transformer_generate,
205+
)
206+
from zeta.nn.modules.patch_linear_flatten import (
207+
vit_output_head,
208+
patch_linear_flatten,
209+
)
210+
from zeta.nn.modules.chan_layer_norm import ChanLayerNorm
211+
198212

199213
# from zeta.nn.modules.img_reshape import image_reshape
200214
# from zeta.nn.modules.flatten_features import flatten_features
@@ -392,4 +406,14 @@
392406
"Top2Gating",
393407
"NormalSparseMoE",
394408
"HeirarchicalSparseMoE",
409+
"return_loss_text",
410+
"calc_z_loss",
411+
"max_neg_value",
412+
"TextTokenEmbedding",
413+
"dropout_seq",
414+
"transformer_generate",
415+
"patch_linear_flatten",
416+
"vit_output_head",
417+
"posemb_sincos_2d",
418+
"ChanLayerNorm",
395419
]

zeta/nn/modules/chan_layer_norm.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
from torch import nn, Tensor
3+
4+
5+
class ChanLayerNorm(nn.Module):
    """
    Channel-wise LayerNorm for NCHW feature maps.

    Normalizes each spatial position across the channel dimension (dim=1)
    and applies a learnable per-channel scale ``g`` and bias ``b``.
    """

    def __init__(self, dim: int, eps: float = 1e-5):
        """
        Initializes the ChanLayerNorm module.

        Args:
            dim (int): Number of channels to normalize over.
            eps (float, optional): Numerical-stability constant added to the
                variance before the square root. Defaults to 1e-5.
        """
        super().__init__()
        self.dim = dim
        self.eps = eps
        # Learnable per-channel affine parameters, broadcast over (B, C, H, W).
        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
        self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass of the ChanLayerNorm module.

        Args:
            x (Tensor): Input of shape (batch, dim, height, width).

        Returns:
            Tensor: Normalized tensor of the same shape.
        """
        # FIX: the original called the non-existent ``torch.car``; the
        # intended operation is the biased channel variance, ``torch.var``.
        var = torch.var(
            x,
            dim=1,
            unbiased=False,
            keepdim=True,
        )
        mean = torch.mean(x, dim=1, keepdim=True)
        return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import torch
2+
from torch import nn, Tensor
3+
from einops.layers.torch import Rearrange
4+
5+
6+
def posemb_sincos_2d(patches, temperature=10000, dtype=torch.float32):
    """
    Build fixed 2D sine/cosine positional embeddings for a patch grid.

    Args:
        patches (Tensor): Tensor of shape (batch, height, width, dim); only
            its shape and device are read — the values are ignored.
        temperature (int, optional): Frequency base for the sinusoids.
            Defaults to 10000.
        dtype (torch.dtype, optional): Output dtype. Defaults to
            ``torch.float32``.

    Returns:
        Tensor: Positional embeddings of shape (height * width, dim),
        rows ordered row-major over the (height, width) grid.
    """
    # FIX: the original unpacked `patches.dtype` into a local named `dtype`,
    # silently shadowing (and ignoring) the `dtype` parameter.
    _, h, w, dim = patches.shape
    device = patches.device

    assert (
        dim % 4
    ) == 0, "feature dimension must be multiple of 4 for sincos emb"

    y, x = torch.meshgrid(
        torch.arange(h, device=device),
        torch.arange(w, device=device),
        indexing="ij",
    )
    # Guard dim == 4: the original divided by (dim // 4 - 1) == 0,
    # yielding NaN frequencies.
    omega = torch.arange(dim // 4, device=device) / max(dim // 4 - 1, 1)
    omega = 1.0 / (temperature**omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)
24+
25+
26+
def vit_output_head(x: Tensor, dim: int, num_classes: int = None):
    """
    Project ViT features to class logits with a LayerNorm + Linear head.

    NOTE(review): the head is constructed freshly on every call, so its
    weights are randomly initialized and untrained — presumably intended
    for quick prototyping only; confirm before relying on it in training.

    Args:
        x (Tensor): Input features whose last dimension is ``dim``.
        dim (int): Feature dimension of ``x``.
        num_classes (int, optional): Number of output classes. Defaults to None.

    Returns:
        Tensor: Logits with the last dimension replaced by ``num_classes``.
    """
    head = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, num_classes))
    return head(x)
39+
40+
41+
def patch_linear_flatten(
    x: Tensor,
    patch_size: int,
    dim: int,
    image_size: int,
    channels: int = 3,
    add_pos_embeddings: bool = False,
    *args,
    **kwargs,
):
    """
    Split an image into square patches and linearly embed each patch.

    NOTE(review): the embedding layers are freshly (randomly) initialized on
    every call — suitable for prototyping, not for training.

    Args:
        x (Tensor): Input of shape (batch, channels, image_size, image_size).
        patch_size (int): Side length of each (square) patch.
        dim (int): Output embedding dimension per patch.
        image_size (int): Side length of the (square) input image.
        channels (int, optional): Number of input channels. Defaults to 3.
        add_pos_embeddings (bool, optional): If True, add fixed 2D sin/cos
            positional embeddings to the patch embeddings. Defaults to False.
        *args, **kwargs: Forwarded to ``posemb_sincos_2d``.

    Returns:
        Tensor: Patch embeddings of shape
        (batch, image_size // patch_size, image_size // patch_size, dim).
    """
    patch_height, patch_width = patch_size, patch_size
    patch_dim = channels * patch_height * patch_width

    # Patch embedding: (b, c, H, W) -> (b, H/p, W/p, dim)
    to_patch_embeddings = nn.Sequential(
        Rearrange(
            "b c (h p1) (w p2) -> b h w (p1 p2 c)",
            p1=patch_height,
            p2=patch_width,
        ),
        nn.LayerNorm(patch_dim),
        nn.Linear(patch_dim, dim),
        nn.LayerNorm(dim),
    )(x)

    if add_pos_embeddings:
        # FIX: the original computed the embeddings from the raw image tensor
        # (wrong shape for posemb_sincos_2d, which expects the patch grid)
        # and then discarded the sum — `a + +b` is a no-op expression.
        _, grid_h, grid_w, _ = to_patch_embeddings.shape
        pos_embeddings = posemb_sincos_2d(to_patch_embeddings, *args, **kwargs)
        to_patch_embeddings = to_patch_embeddings + pos_embeddings.view(
            grid_h, grid_w, dim
        )

    return to_patch_embeddings

zeta/nn/modules/peg.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from torch import nn, Tensor
2+
3+
4+
class PEG(nn.Module):
    """
    PEG (Positional Encoding Generator) module.

    Applies a depthwise convolution with a residual connection, injecting
    position information into a feature map without changing its shape.

    Args:
        dim (int): Number of channels (input and output).
        kernel_size (int, optional): Size of the depthwise kernel.
            Defaults to 3.
    """

    def __init__(self, dim: int, kernel_size: int = 3):
        super().__init__()
        # Depthwise conv (groups=dim) with "same" padding, so the spatial
        # size is preserved and the residual addition is valid.
        self.proj = nn.Conv2d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=dim,
            stride=1,
        )

    def forward(self, x: Tensor):
        """
        Forward pass of the PEG module.

        Args:
            x (Tensor): Input of shape (batch, dim, height, width).

        Returns:
            Tensor: ``x`` plus its depthwise-convolved features; same shape.
        """
        residual = x
        return residual + self.proj(x)

0 commit comments

Comments
 (0)