@@ -102,20 +102,12 @@ struct llama_hparams {
 
     // Sliding Window Attention (SWA)
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
-
-    uint32_t n_swa = 0;         // the size of the sliding window (0 - no SWA)
-    uint32_t n_swa_pattern = 1; // this value n means that every nth layer is dense (i.e. non-SWA)
-                                // by default n == 1, all layers are dense
-                                // note that if n_swa_pattern == 0, all layers are SWA
-                                // example: n_swa_pattern = 3
-                                //   il == 0: swa
-                                //   il == 1: swa
-                                //   il == 2: dense
-                                //   il == 3: swa
-                                //   il == 4: swa
-                                //   il == 5: dense
-                                //   il == 6: swa
-                                //   etc ...
+    // the size of the sliding window (0 - no SWA)
+    uint32_t n_swa = 0;
+    // if swa_layers[il] == true, then layer il is SWA
+    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // by default, all layers are dense
+    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
 
     // for State Space Models
     uint32_t ssm_d_conv = 0;
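
With the layout stored per layer instead of as a single stride, a model whose SWA layers do not follow any regular "every nth layer is dense" rule can be described by filling swa_layers directly. A minimal sketch of such loader code, assuming n_layer is the usual layer-count field of llama_hparams and layer_uses_swa is a hypothetical predicate derived from model metadata:

// hypothetical loader snippet: mark each layer as SWA or dense individually
for (uint32_t il = 0; il < hparams.n_layer; ++il) {
    hparams.swa_layers[il] = layer_uses_swa(il); // true -> SWA, false -> dense
}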
@@ -153,6 +145,23 @@ struct llama_hparams {
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
+    // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // note that if n_pattern == 0, all layers are SWA
+    //           if n_pattern == 1, all layers are dense
+    // example: n_pattern = 3
+    //   il == 0: swa
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   il == 4: swa
+    //   il == 5: dense
+    //   il == 6: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern);
+
+    // return true if one of the layers is SWA
+    bool is_swa_any() const;
+
     uint32_t n_head(uint32_t il = 0) const;
 
     uint32_t n_head_kv(uint32_t il = 0) const;
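
The two declarations above only show the interface; the definitions live in the corresponding .cpp file. Below is a minimal sketch of what they could look like, consistent with the comment block: n_pattern == 0 marks every layer as SWA, and otherwise every n_pattern-th layer (those with il % n_pattern == n_pattern - 1) stays dense. n_layer is assumed to be the layer-count field of llama_hparams.

void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
    for (uint32_t il = 0; il < n_layer; ++il) {
        // n_pattern == 0 short-circuits to true for every layer (all SWA);
        // otherwise layers with il % n_pattern == n_pattern - 1 are dense
        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
    }
}

bool llama_hparams::is_swa_any() const {
    // SWA-specific setup is only needed if at least one layer
    // actually uses a sliding window
    for (uint32_t il = 0; il < n_layer; ++il) {
        if (swa_layers[il]) {
            return true;
        }
    }
    return false;
}

For the n_pattern = 3 example in the comment, this yields swa, swa, dense, swa, swa, dense, ... across the layers, matching the table above.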