@@ -1090,6 +1090,38 @@ def norm_class(self) -> Type:
1090
1090
# 5 local layers for every global layer: every 6th layer (the global one) gets index 0, the rest get 1
1091
1091
rope_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (26 )],
1092
1092
),
1093
+ # https://huggingface.co/google/gemma-3-4b-it/blob/main/config.json
1094
+ dict (
1095
+ name = "Gemma-3-4b-it" ,
1096
+ hf_config = dict (org = "google" , name = "gemma-3-4b-it" ),
1097
+ scale_embeddings = True ,
1098
+ attention_scores_scalar = 256 ,
1099
+ vocab_size = 262144 ,
1100
+ block_size = 131072 ,
1101
+ sliding_window_size = 1024 ,
1102
+ # 5 local (sliding-window) layers for every global layer: every 6th layer (the global one) gets 0, the rest get 1
1103
+ sliding_window_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (34 )],
1104
+ intermediate_size = 10240 ,
1105
+ n_embd = 2560 ,
1106
+ n_layer = 34 ,
1107
+ n_head = 8 ,
1108
+ n_query_groups = 4 ,
1109
+ head_size = 256 ,
1110
+ rotary_percentage = 1.0 ,
1111
+ rope_adjustments = dict (factor = 8.0 ),
1112
+ parallel_residual = False ,
1113
+ bias = False ,
1114
+ norm_class_name = "RMSNorm" ,
1115
+ mlp_class_name = "GemmaMLP" ,
1116
+ gelu_approximate = "tanh" ,
1117
+ post_attention_norm = True ,
1118
+ post_mlp_norm = True ,
1119
+ norm_qk = True ,
1120
+ rope_base = 1000000 ,
1121
+ rope_local_base_freq = 10000 ,
1122
+ # 5 local layers for every global layer: every 6th layer (the global one) gets index 0, the rest get 1
1123
+ rope_indices = [0 if (i + 1 ) % 6 == 0 else 1 for i in range (34 )],
1124
+ ),
1093
1125
# https://huggingface.co/google/gemma-3-12b-it/blob/main/config.json
1094
1126
dict (
1095
1127
name = "Gemma-3-12b-it" ,
0 commit comments