@@ -135,8 +135,12 @@ struct create_tensors_helper : public create_tensors_helper_interface {
     void create_std_attn(int i, const LLM_TN & tn, llama_layer & layer, int n_embd, int n_embd_gqa, ggml_context * ctx_split);
     void create_std_ffn (int i, const LLM_TN & tn, llama_layer & layer, int n_ff, int n_embd, ggml_context * ctx_split);
 
-    inline ggml_context * ctx_for_layer(int i) const;
-    inline ggml_context * ctx_for_layer_split(int i) const;
+    inline ggml_context * ctx_for_layer(int i) const {
+        return ctx_map.at(model.buft_layer[i].buft);
+    }
+    inline ggml_context * ctx_for_layer_split(int i) const {
+        return ctx_map.at(model.buft_layer[i].buft_matrix);
+    }
 
     std::map<ggml_backend_buffer_type_t, int> buft_layer_count;
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -145,6 +149,23 @@ struct create_tensors_helper : public create_tensors_helper_interface {
     ggml_context * ctx_input;
     ggml_context * ctx_output;
     ggml_context * ctx_output_split;
+
+    inline ggml_context * ctx_for_buft(ggml_backend_buffer_type_t buft) {
+        if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second;
+
+        ggml_init_params params = { /*.mem_size =*/ ctx_size, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, };
+
+        ggml_context * ctx = ggml_init(params);
+        if (!ctx) {
+            throw std::runtime_error(format("failed to create ggml context"));
+        }
+
+        ctx_map[buft] = ctx;
+        model.ctxs.emplace_back(ctx);
+
+        return ctx;
+
+    }
 };
 
 create_tensors_helper::create_tensors_helper(llama_model_loader & _ml, llama_model & _model) : ml(_ml), model(_model) {
@@ -183,38 +204,14 @@ ggml_tensor * create_tensors_helper::create_tensor(ggml_context * ctx, const std
             std::regex pattern(overrides->pattern);
             if (std::regex_search(name, pattern)) {
                 LLAMA_LOG_INFO("Tensor %s buffer type overriden to %s\n", name.c_str(), ggml_backend_buft_name(overrides->buft));
-                if (auto it = ctx_map.find(overrides->buft); it != ctx_map.end()) ctx = it->second;
-                else {
-                    ggml_init_params params = {
-                        /*.mem_size   =*/ ctx_size,
-                        /*.mem_buffer =*/ NULL,
-                        /*.no_alloc   =*/ true,
-                    };
-
-                    ggml_context * ctx = ggml_init(params);
-                    if (!ctx) {
-                        throw std::runtime_error(format("failed to create ggml context"));
-                    }
-
-                    ctx_map[overrides->buft] = ctx;
-                    model.ctxs.emplace_back(ctx);
-
-                }
+                ctx = ctx_for_buft(overrides->buft);
                 break;
             }
         }
     }
     return ml.create_tensor(ctx, name, ne, flags);
 }
 
-ggml_context * create_tensors_helper::ctx_for_layer(int i) const {
-    return ctx_map.at(model.buft_layer[i].buft);
-}
-
-ggml_context * create_tensors_helper::ctx_for_layer_split(int i) const {
-    return ctx_map.at(model.buft_layer[i].buft_matrix);
-}
-
 #define LOADING_PRELUDE \
     [[maybe_unused]] const auto & hparams = model.hparams; \
     [[maybe_unused]] const int64_t n_layer = hparams.n_layer; \
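The hunks above fold the duplicated lazy-context-creation logic from create_tensor() into a single ctx_for_buft() helper and inline the two ctx_for_layer accessors. The sketch below is only an illustration of the same find-or-create pattern, not code from the diff: BuftId, Context, and ContextCache are hypothetical stand-ins for ggml_backend_buffer_type_t, ggml_context, and the helper struct.

```cpp
// Minimal sketch of the find-or-create pattern behind ctx_for_buft();
// BuftId, Context, and ContextCache are illustrative stand-ins only.
#include <map>
#include <memory>
#include <vector>

using BuftId = int;   // stands in for ggml_backend_buffer_type_t
struct Context {};    // stands in for ggml_context

struct ContextCache {
    std::map<BuftId, Context *> ctx_map;          // one context per buffer type
    std::vector<std::unique_ptr<Context>> owned;  // keeps contexts alive, like model.ctxs

    // Return the cached context for `buft`, creating and registering it on first use.
    Context * ctx_for(BuftId buft) {
        if (auto it = ctx_map.find(buft); it != ctx_map.end()) return it->second;

        auto ctx = std::make_unique<Context>();   // stands in for ggml_init(params)
        ctx_map[buft] = ctx.get();
        owned.emplace_back(std::move(ctx));
        return ctx_map[buft];
    }
};

int main() {
    ContextCache cache;
    // Repeated lookups with the same buffer type yield the same context.
    return cache.ctx_for(0) == cache.ctx_for(0) ? 0 : 1;
}
```

As in the diff, the newly created context is registered in a container owned by the model (model.ctxs in the original, `owned` here) so its lifetime matches the model rather than the helper.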