@@ -60,26 +60,60 @@ namespace chatllm
     LayerBufAllocator::LayerBufAllocator(ggml_backend_allocator alloc, Backend *backend): LayerBufAllocator(alloc, alloc, backend) {}
     LayerBufAllocator::LayerBufAllocator(ggml_backend_allocator alloc_matrix, ggml_backend_allocator alloc_others, Backend *backend)
         : alloc_matrix(alloc_matrix), alloc_others(alloc_others), backend(backend)
-    {}
+    {
+        CHATLLM_CHECK(alloc_matrix == alloc_others) << "TODO: alloc_matrix must be alloc_others now.";
+    }

     BackendBuffer *LayerBufAllocator::alloc(size_t size, Usage usage)
     {
         total += size;
-        ggml_backend_buffer_t buf = nullptr;
+        ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(get_allocator(usage), size);
+
+        CHATLLM_CHECK(buf) << __FUNCTION__ << "() failed to allocate buffer";
+
+        auto r = new BackendBuffer(buf);
+        buffers.emplace_back(r);
+        return r;
+    }
+
+    bool LayerBufAllocator::alloc(ggml::tensor *tensor)
+    {
+        BackendBuffer *buf = alloc(ggml::nbytes(tensor), detect_usage(tensor));
+        if (nullptr == buf) return false;
+
+        buf->assign_to(tensor);
+        return true;
+    }
+
+    bool LayerBufAllocator::supported_by_backend(Backend *backend, ggml::tensor *tensor)
+    {
+        ggml_backend_allocator allocator = get_allocator(tensor);
+        return ggml_backend_supports_buft(backend->backend, allocator);
+    }
+
+    BackendBufAllocator::Usage LayerBufAllocator::detect_usage(ggml::tensor *tensor)
+    {
+        int dims = ggml::n_dims(tensor);
+        return dims >= 2 ? Usage::Matrix : Usage::Others;
+    }
+
+    ggml_backend_allocator LayerBufAllocator::get_allocator(Usage usage)
+    {
         switch (usage)
         {
         case Usage::Matrix:
-            buf = ggml_backend_buft_alloc_buffer(alloc_matrix, size);
-            break;
+            return alloc_matrix;
         case Usage::Others:
-            buf = ggml_backend_buft_alloc_buffer(alloc_others, size);
-            break;
+            return alloc_others;
+        default:
+            CHATLLM_CHECK(false);
+            return nullptr;
         }
-        CHATLLM_CHECK(buf) << __FUNCTION__ << "() failed to allocate buffer";
+    }

-        auto r = new BackendBuffer(buf);
-        buffers.emplace_back(r);
-        return r;
+    ggml_backend_allocator LayerBufAllocator::get_allocator(ggml::tensor *tensor)
+    {
+        return get_allocator(detect_usage(tensor));
     }

     size_t LayerBufAllocator::get_alignment(Usage usage) const
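The new `detect_usage`/`get_allocator` pair centralizes the Matrix-vs-Others routing that the old `alloc` switch duplicated: any tensor with two or more dimensions is treated as matrix data, everything else (biases, norm weights, 1-D state) falls into the general pool. A minimal standalone sketch of that heuristic, with `FakeTensor` standing in for `ggml::tensor` and `ggml::n_dims`:

```cpp
#include <cassert>

// Stand-ins for the real types; the actual code uses ggml::tensor and
// ggml_backend_allocator (a ggml backend buffer type).
enum class Usage { Matrix, Others };

struct FakeTensor { int n_dims; };

// Mirrors LayerBufAllocator::detect_usage(): anything 2-D or higher is
// routed to the matrix allocator, everything else to the "others" pool.
static Usage detect_usage(const FakeTensor &t)
{
    return t.n_dims >= 2 ? Usage::Matrix : Usage::Others;
}

int main()
{
    assert(detect_usage({2}) == Usage::Matrix); // weight matrix
    assert(detect_usage({1}) == Usage::Others); // bias vector
    return 0;
}
```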
@@ -377,7 +411,7 @@ namespace chatllm
         for (auto &cfg : gpu_cfgs) n_gpu_layers += cfg.n_layers;
         const bool use_gpu = n_gpu_layers > 0;

-        buf_compute_meta.resize(ggml_tensor_overhead()*graph_max_nodes_num + ggml_graph_overhead_custom(graph_max_nodes_num, false));
+        buf_compute_meta.resize(ggml_tensor_overhead() * graph_max_nodes_num + ggml_graph_overhead_custom(graph_max_nodes_num, false));

         backend_cpu = ggml_backend_cpu_init();
         CHATLLM_CHECK(backend_cpu != nullptr) << __func__ << ": failed to initialize CPU backend";
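The `buf_compute_meta` sizing follows the usual ggml convention: per-tensor metadata for up to `graph_max_nodes_num` nodes, plus the graph's own bookkeeping, both measured through the public ggml API. A sketch of the same arithmetic, assuming a hypothetical `GRAPH_MAX_NODES` constant in place of the member `graph_max_nodes_num`:

```cpp
#include <vector>
#include <cstdint>
#include "ggml.h"

// Assumed placeholder for the context's graph_max_nodes_num.
static const size_t GRAPH_MAX_NODES = 8192;

// Every tensor in the graph needs ggml_tensor_overhead() bytes of
// metadata, and the graph structure itself (node/leaf arrays, no
// gradients) needs ggml_graph_overhead_custom(n, /*grads=*/false).
std::vector<uint8_t> make_compute_meta_buffer()
{
    size_t size = ggml_tensor_overhead() * GRAPH_MAX_NODES
                + ggml_graph_overhead_custom(GRAPH_MAX_NODES, false);
    return std::vector<uint8_t>(size);
}
```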
@@ -409,7 +443,6 @@ namespace chatllm
 #elif defined(GGML_USE_CUDA) || defined(GGML_USE_VULKAN) || defined(GGML_USE_SYCL) || defined(GGML_USE_CANN)
         if (use_gpu)
         {
-            const int total = ComputeManager::get_device_count();
             for (auto cfg : gpu_cfgs)
             {
                 int device = cfg.id >= 0 ? cfg.id : 0;
@@ -466,9 +499,9 @@ namespace chatllm
         return ggml_backend_sched_reserve(sched, gf);
     }

-    void BackendContext::alloc_graph(ggml_cgraph *gf)
+    bool BackendContext::alloc_graph(ggml_cgraph *gf)
     {
-        ggml_backend_sched_alloc_graph(sched, gf);
+        return ggml_backend_sched_alloc_graph(sched, gf);
     }

     void BackendContext::compute_graph(ggml_cgraph *gf, int n_threads)
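`ggml_backend_sched_alloc_graph` returns `false` when the scheduler cannot place the graph's tensors, and `alloc_graph` now forwards that result instead of discarding it. A hedged sketch of a caller acting on the flag; the reset-and-bail error path is an assumption, not part of this commit:

```cpp
#include "ggml-backend.h"

// Hypothetical caller: try to allocate the graph on the scheduler and
// bail out rather than computing with unallocated tensors. `sched` and
// `gf` are assumed to be set up elsewhere.
bool try_alloc(ggml_backend_sched_t sched, struct ggml_cgraph *gf)
{
    if (!ggml_backend_sched_alloc_graph(sched, gf))
    {
        // Allocation failed, e.g. not enough backend memory; reset the
        // scheduler state before any retry.
        ggml_backend_sched_reset(sched);
        return false;
    }
    return true;
}
```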
@@ -538,11 +571,6 @@ namespace chatllm

     void ComputeContext::cb_op_tensor(ggml::tensor *tensor)
     {
-        if (get_sched() && get_backend())
-        {
-            if (ggml_backend_supports_op(get_backend()->backend, tensor) || ggml_backend_offload_op(get_backend()->backend, tensor))
-                ggml_backend_sched_set_tensor_backend(get_sched(), tensor, get_backend()->backend);
-        }
     }

     ggml_backend_sched_t ComputeContext::get_sched(void)
@@ -570,9 +598,9 @@ namespace chatllm
         backend_context->compute_graph(get_cgraph(), n_threads);
     }

-    void ComputeContext::allocate(void)
+    bool ComputeContext::allocate(void)
     {
-        backend_context->alloc_graph(get_cgraph());
+        return backend_context->alloc_graph(get_cgraph());
     }

     bool ComputeContext::reserve_memory(void)
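With `ComputeContext::allocate()` propagating the scheduler status, call sites can refuse to run an unallocated graph. An assumed call-site pattern follows; the `run_graph` helper, the `compute(n_threads)` method name, and the error handling are illustrative, not taken from this commit:

```cpp
#include "backend.h" // assumed header declaring chatllm::ComputeContext

using chatllm::ComputeContext;

// Illustrative helper: allocate the graph first, and only compute once
// every tensor has been backed by a buffer.
static bool run_graph(ComputeContext &ctx, int n_threads)
{
    if (!ctx.allocate())
        return false;       // scheduler could not place the graph
    ctx.compute(n_threads); // safe: allocation succeeded
    return true;
}
```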