@@ -14,11 +14,11 @@
 (and indeed, can bootstrap these off of GGUF files).
 """
 
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from typing import Any, Optional
 import torch
 
-__all__ = ["LlamaHParams", "LlamaModelConfig", "T5Config"]
+__all__ = ["ClipTextConfig", "LlamaHParams", "LlamaModelConfig", "T5Config"]
 
 
 @dataclass
@@ -266,3 +266,49 @@ def from_gguf_properties(properties: dict[str, Any], **kwargs):
         all_kwargs.update(kwargs)
 
         return T5Config(**all_kwargs)
+
+
+@dataclass
+class ClipTextConfig:
+    vocab_size: int = 49408
+    hidden_size: int = 512
+    intermediate_size: int = 2048
+    projection_dim: int = 512
+    num_hidden_layers: int = 12
+    num_attention_heads: int = 8
+    max_position_embeddings: int = 77
+    hidden_act: str = "quick_gelu"
+    layer_norm_eps: float = 1e-5
+    # This differs from `CLIPTokenizer`'s default and from openai/clip
+    # See https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538
+    pad_token_id: int = 1
+    bos_token_id: int = 49406
+    eos_token_id: int = 49407
+    output_attentions: bool = False
+    output_hidden_states: bool = False
+    use_return_dict: bool = True
+
+    @staticmethod
+    def from_transformers_clip_text_config(
+        config: "transformers.CLIPTextConfig",
+    ) -> "ClipTextConfig":
+        return ClipTextConfig(
+            vocab_size=config.vocab_size,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            projection_dim=config.projection_dim,
+            num_hidden_layers=config.num_hidden_layers,
+            num_attention_heads=config.num_attention_heads,
+            max_position_embeddings=config.max_position_embeddings,
+            hidden_act=config.hidden_act,
+            layer_norm_eps=config.layer_norm_eps,
+            pad_token_id=config.pad_token_id,
+            bos_token_id=config.bos_token_id,
+            eos_token_id=config.eos_token_id,
+            output_attentions=config.output_attentions,
+            output_hidden_states=config.output_hidden_states,
+            use_return_dict=config.use_return_dict,
+        )
+
+    def as_properties(self) -> dict[str, Any]:
+        return asdict(self)
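
For reference, a minimal usage sketch of the new dataclass follows. The import path for ClipTextConfig is hypothetical (adjust it to wherever this module lives in the package); everything else uses only the fields and methods added in this diff plus the stock transformers.CLIPTextConfig API.

    from transformers import CLIPTextConfig as HfCLIPTextConfig

    # Hypothetical import path; point this at the module patched above.
    from my_package.configs import ClipTextConfig

    # Mirror a Hugging Face CLIP text-encoder config into the local dataclass.
    hf_config = HfCLIPTextConfig()  # stock defaults: vocab_size=49408, etc.
    config = ClipTextConfig.from_transformers_clip_text_config(hf_config)

    # as_properties() flattens the dataclass into a plain dict via asdict(),
    # e.g. for serializing alongside other model metadata.
    properties = config.as_properties()
    assert properties["vocab_size"] == 49408
    assert properties["hidden_act"] == "quick_gelu"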