|
1 | 1 | # coding=utf-8 |
2 | | -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. |
| 2 | +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. |
3 | 3 | # |
4 | 4 | # Licensed under the Apache License, Version 2.0 (the "License"); |
5 | 5 | # you may not use this file except in compliance with the License. |
|
16 | 16 |
|
17 | 17 | from collections import OrderedDict |
18 | 18 | from typing import Mapping |
| 19 | +from typing import Dict, List, Optional, Set, Tuple, Union, Callable, Any |
19 | 20 | import functools |
20 | 21 | import torch.nn as nn |
21 | 22 |
|
|
29 | 30 | logger = logging.get_logger(__name__) |
30 | 31 |
|
31 | 32 |
|
32 | | -class Aimv2Config(PretrainedConfig): |
| 33 | +class AIMv2Config(PretrainedConfig): |
33 | 34 | r""" |
34 | | - This is the configuration class to store the configuration of a [`Aimv2Model`]. It is used to instantiate a AIM-v2 |
| 35 | + This is the configuration class to store the configuration of an [`AIMv2Model`]. It is used to instantiate an AIM-v2 |
35 | 36 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the |
36 | 37 | defaults will yield a similar configuration to that of the AIM-v2 [apple/aimv2-large-patch14-224](...) |
37 | 38 | architecture. |
@@ -65,76 +66,120 @@ class Aimv2Config(PretrainedConfig): |
65 | 66 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. |
66 | 67 | layer_norm_eps (`float`, *optional*, defaults to 1e-5): |
67 | 68 | The epsilon used by the layer normalization layers. |
| 69 | + qkv_bias (`bool`, *optional*, defaults to `False`): |
| 70 | + Whether or not to use bias in query, key, value. |
| 71 | + use_bias (`bool`, *optional*, defaults to `False`): |
| 72 | + Whether or not to use bias in all linear layers. |
68 | 73 | use_cls_token (`bool`, *optional*, defaults to `False`): |
69 | 74 | Whether or not to use a classification token. |
70 | | - use_mask_token (`bool`, *optional*, defaults to `False`): |
71 | | - Whether or not to use a mask token. |
72 | | - use_pos_embed (`str`, *optional*, defaults to `"absolute"`): |
| 75 | + pos_embed_type (`str`, *optional*, defaults to `"absolute"`): |
73 | 76 | Positional embedding type. Choose from 'absolute', 'sincos', or 'none'. |
| 77 | + use_rms_norm (`bool`, *optional*, defaults to `False`): |
| 78 | + Whether or not to use RMS norm. Not currently accepted by `__init__`: the parameter is commented out in the signature. |
| 79 | + post_trunk_norm (`bool`, *optional*, defaults to `True`): |
| 80 | + Whether or not to use norm layer after the transformer blocks (layers). |
| 81 | + probe_layers (`int` or `Tuple[int, ...]`, *optional*, defaults to 6): |
| 82 | + The layer ids to use for selecting features. |
| 83 | + reduce (`bool`, *optional*, defaults to `False`): |
| 84 | + Whether or not to reduce features using mean. |
| 85 | + ffn_target_type (`str`, *optional*, defaults to `"swiglu"`): |
| 86 | + Type of feedforward network (FFN) to use. |
| 87 | + is_causal (`bool`, *optional*, defaults to `False`): |
| 88 | + Whether or not to use causal attention. |
74 | 89 | norm_layer (`[torch.nn.Module]`, *optional*, defaults to `torch.nn.RMSNorm`): |
75 | 90 | Normalization layer to use. |
76 | | - Example: |
| 91 | + num_queries (`int`, *optional*, defaults to 1): |
| 92 | + Number of query tokens for attention pooling. |
| 93 | + use_batch_norm (`bool`, *optional*, defaults to `True`): |
| 94 | + Whether to use batch normalization in attention pooling. |
| 95 | + proj_bias (`bool`, *optional*, defaults to `False`): |
| 96 | + Whether to use bias in the projection layer of the attention pooling. |
| 97 | + average_pool (`bool`, *optional*, defaults to `True`): |
| 98 | + Whether to use average pooling in the attention pooling. |
| 99 | + num_labels (`int`, *optional*, defaults to 1000): |
| 100 | + The number of labels for classification tasks. |
| 101 | + **kwargs: |
| 102 | + Remaining keyword arguments are passed to the superclass. |
| 103 | +
|
| 104 | + Example: |
77 | 105 |
|
78 | 106 | ```python |
79 | | - >>> from aim.v2.configuration_aimv2 import Aimv2Config |
80 | | - >>> from aim.v2.modeling_aimv2 import Aimv2Model |
| 107 | + >>> from aim.v2.configuration_aimv2 import AIMv2Config |
81 | 108 |
|
82 | 109 | >>> # Initializing an aimv2-large-patch14-224 style configuration |
83 | | - >>> configuration = Aimv2Config() |
84 | | -
|
85 | | - >>> # Initializing a model (with random weights) from the aimv2-large-patch14-224 style configuration |
86 | | - >>> model = Aimv2Model(configuration) |
| 110 | + >>> configuration = AIMv2Config() |
87 | 111 |
|
88 | 112 | >>> # Inspecting the configuration |
89 | | - >>> configuration = model.config |
| 113 | + >>> print(configuration) |
90 | 114 | ``` |
91 | 115 | """ |
| 116 | + |
92 | 117 | model_type = "aimv2" |
93 | 118 |
|
94 | 119 | def __init__( |
95 | 120 | self, |
96 | | - image_size: int = 224, |
97 | | - patch_size: int = 14, |
| 121 | + image_size: Union[int, Tuple[int, int]] = 224, |
| 122 | + patch_size: Union[int, Tuple[int, int]] = 14, |
98 | 123 | num_channels: int = 3, |
99 | 124 | hidden_size: int = 1024, |
100 | 125 | num_hidden_layers: int = 24, |
101 | 126 | num_attention_heads: int = 16, |
102 | | - intermediate_size: int = 4096, |
103 | | - hidden_act: str = "gelu", |
104 | | - hidden_dropout_prob: float = 0.0, |
105 | | - attention_probs_dropout_prob: float = 0.0, |
| 127 | + #mlp_ratio: float = 4.0, |
| 128 | + hidden_act: Union[str, Callable] = "gelu", |
| 129 | + hidden_dropout_prob: float = 0.1, |
| 130 | + attention_probs_dropout_prob: float = 0.1, |
106 | 131 | initializer_range: float = 0.02, |
| 132 | + intermediate_size=2816, |
107 | 133 | layer_norm_eps: float = 1e-5, |
108 | | - use_cls_token: bool = False, |
109 | | - use_mask_token: bool = False, |
110 | | - use_pos_embed: str = "absolute", |
111 | 134 | qkv_bias: bool = False, |
112 | | - norm_layer=nn.LayerNorm, |
| 135 | + use_bias: bool = False, |
| 136 | + use_cls_token: bool = False, |
| 137 | + pos_embed_type: str = "absolute", |
| 138 | + #use_rms_norm: bool = False, |
| 139 | + post_trunk_norm: bool = True, |
| 140 | + probe_layers: Union[int, Tuple[int, ...]] = 6, |
| 141 | + reduce: bool = False, |
| 142 | + ffn_target_type: str = "swiglu", |
| 143 | + is_causal: bool = False, |
| 144 | + norm_layer: Optional[Callable[[int], nn.Module]] = nn.RMSNorm, |
| 145 | + num_queries: int = 1, |
| 146 | + use_batch_norm: bool = True, |
| 147 | + proj_bias: bool = False, |
| 148 | + average_pool: bool = True, |
| 149 | + num_labels: int = 1000, |
113 | 150 | **kwargs, |
114 | 151 | ): |
115 | 152 | super().__init__(**kwargs) |
| 153 | + |
116 | 154 | self.image_size = image_size |
117 | 155 | self.patch_size = patch_size |
118 | 156 | self.num_channels = num_channels |
119 | 157 | self.hidden_size = hidden_size |
120 | 158 | self.num_hidden_layers = num_hidden_layers |
121 | 159 | self.num_attention_heads = num_attention_heads |
122 | | - self.intermediate_size = intermediate_size |
| 160 | + #self.mlp_ratio = mlp_ratio |
123 | 161 | self.hidden_act = hidden_act |
124 | 162 | self.hidden_dropout_prob = hidden_dropout_prob |
125 | 163 | self.attention_probs_dropout_prob = attention_probs_dropout_prob |
126 | 164 | self.initializer_range = initializer_range |
| 165 | + self.intermediate_size=intermediate_size |
127 | 166 | self.layer_norm_eps = layer_norm_eps |
128 | | - self.use_cls_token = use_cls_token |
129 | | - self.use_mask_token = use_mask_token |
130 | | - self.use_pos_embed = use_pos_embed # we will use "sincos" or "absolute" |
131 | 167 | self.qkv_bias = qkv_bias |
132 | | - # If norm_layer is provided, use it, otherwise, default to nn.LayerNorm with the specified eps |
133 | | - self.norm_layer = ( |
134 | | - norm_layer |
135 | | - if norm_layer is not None |
136 | | - else functools.partial(nn.LayerNorm, eps=layer_norm_eps) |
137 | | - ) |
| 168 | + self.use_bias = use_bias |
| 169 | + self.use_cls_token = use_cls_token |
| 170 | + self.pos_embed_type = pos_embed_type |
| 171 | + #self.use_rms_norm = use_rms_norm |
| 172 | + self.post_trunk_norm = post_trunk_norm |
| 173 | + self.probe_layers = probe_layers |
| 174 | + self.reduce = reduce |
| 175 | + self.ffn_target_type = ffn_target_type |
| 176 | + self.is_causal = is_causal |
| 177 | + self.norm_layer = norm_layer |
| 178 | + self.num_queries = num_queries |
| 179 | + self.use_batch_norm = use_batch_norm |
| 180 | + self.proj_bias = proj_bias |
| 181 | + self.average_pool = average_pool |
| 182 | + self.num_labels = num_labels |
138 | 183 |
|
139 | 184 |
|
140 | 185 | class AIMv2OnnxConfig(OnnxConfig): |
|
0 commit comments