Skip to content

Commit c9cb226

Browse files
authored
Add CLIP text model (#643)
Ports the CLIP text model from Hugging Face. This is the first iteration so not much is changed from the original model. Things like dropout and checkpointing are removed. Add numeric verification tests for the various components of the stack when executing in eager mode. Verifications are made for float32 and bfloat16. There are tests for toy-sized components and the whole model as well as the Large pretrained variant. These tests does not include testing with IREE. Functionalities for mask creation are not yet ported.
1 parent 2f5bfab commit c9cb226

File tree

20 files changed

+1209
-35
lines changed

20 files changed

+1209
-35
lines changed

.github/workflows/ci-sharktank.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,13 @@ jobs:
122122
iree-base-runtime
123123
124124
- name: Run tests
125+
# TODO: unify with-t5-data and with-clip-data flags into a single flag
126+
# and make it possible to run only tests that require data.
125127
run: |
126128
pytest \
129+
--with-clip-data \
127130
--with-t5-data \
131+
sharktank/tests/models/clip/clip_test.py \
128132
sharktank/tests/models/t5/t5_test.py \
129133
--durations=0
130134

sharktank/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ def pytest_addoption(parser):
8888
help="Enable all llama benchmarking tests",
8989
)
9090

91+
parser.addoption(
92+
"--with-clip-data",
93+
action="store_true",
94+
default=False,
95+
help=(
96+
"Enable tests that use CLIP data like models that is not a part of the source "
97+
"code. The user is expected to provide the data"
98+
),
99+
)
91100
parser.addoption(
92101
"--with-t5-data",
93102
action="store_true",

sharktank/sharktank/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .kv_cache import BaseKVCache, DirectKVCache, PagedKVCache
1010
from .causal_llm import BaseCausalLMModel
1111
from .linear import LinearLayer
12-
from .norm import RMSNormLayer
12+
from .norm import RMSNormLayer, LayerNorm
1313
from .rotary_embedding import RotaryEmbeddingLayer
1414
from .token_embedding import TokenEmbeddingLayer
1515
from .llama_attention_block import LlamaAttentionBlock
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc.
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from torch import nn
8+
from .. import ops
9+
10+
# TODO: don't use nn.functional directly.
11+
ACT2FN = {
12+
"gelu": nn.functional.gelu,
13+
"gelu_new": ops.gelu_tanh_approximation,
14+
"relu": nn.functional.relu,
15+
"quick_gelu": ops.gelu_sigmoid_approximation,
16+
}

sharktank/sharktank/layers/configs/llm_configs.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
(and indeed, can bootstrap these off of GGUF files).
1515
"""
1616

17-
from dataclasses import dataclass, field
17+
from dataclasses import asdict, dataclass, field
1818
from typing import Any, Optional
1919
import torch
2020

21-
__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]
21+
__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]
2222

2323

2424
@dataclass
@@ -266,3 +266,49 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
266266
all_kwargs.update(kwargs)
267267

268268
return T5Config(**all_kwargs)
269+
270+
271+
@dataclass
272+
class ClipTextConfig:
273+
vocab_size: int = 49408
274+
hidden_size: int = 512
275+
intermediate_size: int = 2048
276+
projection_dim: int = 512
277+
num_hidden_layers: int = 12
278+
num_attention_heads: int = 8
279+
max_position_embeddings: int = 77
280+
hidden_act: str = "quick_gelu"
281+
layer_norm_eps: float = 1e-5
282+
# This differs from `CLIPTokenizer`'s default and from openai/clip
283+
# See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
284+
pad_token_id: int = 1
285+
bos_token_id: int = 49406
286+
eos_token_id: int = 49407
287+
output_attentions: bool = False
288+
output_hidden_states: bool = False
289+
use_return_dict: bool = True
290+
291+
@staticmethod
292+
def from_transformers_clip_text_config(
293+
config: "transformers.CLIPTextConfig",
294+
) -> "ClipTextConfig":
295+
return ClipTextConfig(
296+
vocab_size=config.vocab_size,
297+
hidden_size=config.hidden_size,
298+
intermediate_size=config.intermediate_size,
299+
projection_dim=config.projection_dim,
300+
num_hidden_layers=config.num_hidden_layers,
301+
num_attention_heads=config.num_attention_heads,
302+
max_position_embeddings=config.max_position_embeddings,
303+
hidden_act=config.hidden_act,
304+
layer_norm_eps=config.layer_norm_eps,
305+
pad_token_id=config.pad_token_id,
306+
bos_token_id=config.bos_token_id,
307+
eos_token_id=config.eos_token_id,
308+
output_attentions=config.output_attentions,
309+
output_hidden_states=config.output_hidden_states,
310+
use_return_dict=config.use_return_dict,
311+
)
312+
313+
def as_properties(self) -> dict[str, Any]:
314+
return asdict(self)

sharktank/sharktank/layers/norm.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,23 @@ def forward(self, x: torch.Tensor):
3939
# often in higher precision. Downcast back to expected.
4040
norm = ops.to(norm, orig_dtype)
4141
return norm
42+
43+
44+
class LayerNorm(ThetaLayer):
45+
def __init__(
46+
self,
47+
theta: Theta,
48+
*,
49+
weight_name: str = "weight",
50+
bias_name: str = "bias",
51+
eps: float = 1e-05,
52+
):
53+
super().__init__(theta)
54+
self.weight = self.theta_tensor(weight_name)
55+
self.bias = None
56+
if bias_name in self.theta.keys:
57+
self.bias = self.theta_tensor(bias_name)
58+
self.eps = eps
59+
60+
def forward(self, x: torch.Tensor):
61+
return ops.layer_norm(x, weight=self.weight, bias=self.bias, eps=self.eps)
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Copyright 2024 Advanced Micro Devices, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from .clip import *
8+
from .export import *

0 commit comments

Comments
 (0)