@@ -1,9 +1,12 @@
 #include "ggml-metal-common.h"
 
 #include "ggml-impl.h"
+#include "ggml-backend-impl.h"
 
 #include <vector>
 
+// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
 struct ggml_mem_range {
     uint64_t pb; // buffer id
 
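For orientation, here is a minimal sketch of what the complete range record plausibly looks like, inferred from the fields used later in this diff (`pb`, `p0`, `p1`, `pt` and the `MEM_RANGE_TYPE_SRC` constant); the exact upstream layout and enumerator values are assumptions:

```cpp
// sketch only: a half-open interval [p0, p1) inside buffer pb, tagged by how it is accessed
enum ggml_mem_range_type {
    MEM_RANGE_TYPE_SRC = 0, // ops read from the range
    MEM_RANGE_TYPE_DST = 1, // ops write to the range
};

struct ggml_mem_range {
    uint64_t pb; // buffer id

    uint64_t p0; // start address (inclusive)
    uint64_t p1; // end address (exclusive)

    ggml_mem_range_type pt; // src or dst
};
```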
@@ -36,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
     mrs->ranges.clear();
 }
 
-static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
-    mrs->ranges.push_back(mrp);
+static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
+    mrs->ranges.push_back(mr);
 
     return true;
 }
@@ -48,28 +51,32 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
 
     GGML_ASSERT(!tensor->view_src);
 
-    ggml_mem_range mrp;
+    ggml_mem_range mr;
 
     if (tensor->buffer) {
-        // when the tensor is allocated, use the actual memory address range of the buffer
-        mrp = {
+        // when the tensor is allocated, use the actual memory address range in the buffer
+        //
+        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
+        // this can be larger than the tensor size if the buffer type allocates extra memory
+        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+        mr = {
             /*.pb =*/ (uint64_t) tensor->buffer,
             /*.p0 =*/ (uint64_t) tensor->data,
-            /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
+            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
             /*.pt =*/ pt,
         };
     } else {
-        // otherwise, the tensor ptr is used as an unique id of the memory ranges
+        // otherwise, the pointer address is used as a unique id of the memory ranges
         // that the tensor will be using when it is allocated
-        mrp = {
+        mr = {
             /*.pb =*/ (uint64_t) tensor,
             /*.p0 =*/ 0,    //
             /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
             /*.pt =*/ pt,
         };
     };
 
-    return mrp;
+    return mr;
 }
 
 static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
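The switch from `ggml_nbytes()` to `ggml_backend_buft_get_alloc_size()` widens the range to cover any padding the buffer type reserves past the logical tensor size. A standalone, deliberately contrived demonstration of why that matters, using the same half-open overlap predicate as `ggml_mem_ranges_check()` below (the sizes and addresses here are made up):

```cpp
#include <cstdint>
#include <cstdio>

// same predicate as ggml_mem_ranges_check() uses, for half-open intervals [a0, a1) and [b0, b1)
static bool ranges_conflict(uint64_t a0, uint64_t a1, uint64_t b0, uint64_t b1) {
    return a0 < b1 && a1 >= b0;
}

int main() {
    const uint64_t base = 0x1000;

    // tensor A: 1000 logical bytes, but a hypothetical buffer type pads the
    // allocation to 1024, and an op may touch the padded tail
    const uint64_t a_nbytes = 1000;
    const uint64_t a_alloc  = 1024;

    // tensor B: lives inside A's padded tail
    const uint64_t b0 = base + 1008, b1 = b0 + 16;

    printf("by nbytes:     %d\n", ranges_conflict(base, base + a_nbytes, b0, b1)); // 0 -> hazard missed
    printf("by alloc size: %d\n", ranges_conflict(base, base + a_alloc,  b0, b1)); // 1 -> hazard caught

    return 0;
}
```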
@@ -83,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
 static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
 
     if (mrs->debug > 2) {
-        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
     }
 
-    return ggml_mem_ranges_add(mrs, mrp);
+    return ggml_mem_ranges_add(mrs, mr);
 }
 
 static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
 
     if (mrs->debug > 2) {
-        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
     }
 
-    return ggml_mem_ranges_add(mrs, mrp);
+    return ggml_mem_ranges_add(mrs, mr);
 }
 
 bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
@@ -114,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     return ggml_mem_ranges_add_dst(mrs, tensor);
 }
 
-static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
     for (size_t i = 0; i < mrs->ranges.size(); i++) {
         const auto & cmp = mrs->ranges[i];
 
-        if (mrp.pb != cmp.pb) {
+        // two memory ranges cannot intersect if they are in different buffers
+        if (mr.pb != cmp.pb) {
             continue;
         }
 
-        if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+        // intersecting source ranges are allowed
+        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
             continue;
         }
 
-        if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
+        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
             if (mrs->debug > 2) {
                 GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                         __func__,
-                        mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
-                        mrp.pb, mrp.p0, mrp.p1,
+                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mr.pb, mr.p0, mr.p1,
                         cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                         cmp.pb, cmp.p0, cmp.p1);
             }
@@ -146,19 +155,19 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
 static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
 
-    const bool res = ggml_mem_ranges_check(mrs, mrp);
+    const bool res = ggml_mem_ranges_check(mrs, mr);
 
     return res;
 }
 
 static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
 
-    const bool res = ggml_mem_ranges_check(mrs, mrp);
+    const bool res = ggml_mem_ranges_check(mrs, mr);
 
     return res;
 }
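The three rules in `ggml_mem_ranges_check()` (different buffers never conflict, src/src overlap is allowed, src/dst and dst/dst overlap is a hazard) can be exercised in isolation. A self-contained restatement with the same predicate, not the ggml API itself; the buffer ids and address ranges are made up:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

enum range_type { SRC, DST };
struct range { uint64_t pb, p0, p1; range_type pt; };

static bool conflicts(const std::vector<range> & set, const range & r) {
    for (const auto & cmp : set) {
        if (r.pb != cmp.pb)                  continue;     // different buffers never intersect
        if (r.pt == SRC && cmp.pt == SRC)    continue;     // concurrent reads are always allowed
        if (r.p0 < cmp.p1 && r.p1 >= cmp.p0) return true;  // read/write or write/write hazard
    }
    return false;
}

int main() {
    // ranges of an op already in the concurrent set: reads buf 1 [0, 200), writes buf 1 [256, 512)
    const std::vector<range> set = {
        { 1,   0, 200, SRC },
        { 1, 256, 512, DST },
    };

    printf("read [0, 200)      : %s\n", conflicts(set, { 1,   0, 200, SRC }) ? "hazard" : "ok"); // ok: src vs src
    printf("read [300, 400)    : %s\n", conflicts(set, { 1, 300, 400, SRC }) ? "hazard" : "ok"); // hazard: reads what the set writes
    printf("write buf 2 [0, 64): %s\n", conflicts(set, { 2,   0,  64, DST }) ? "hazard" : "ok"); // ok: different buffer

    return 0;
}
```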
@@ -222,6 +231,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
             }
         }
 
+        // keep track of the sources of the fused nodes as well
         for (const auto * fused : node.fused) {
             for (int i = 0; i < GGML_MAX_SRC; i++) {
                 if (fused->src[i]) {
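A quick illustration of why the fused sources must be tracked (the fusion pattern here is hypothetical, not taken from this commit):

```cpp
// suppose ADD and MUL were fused into a single node:
//
//   t2 = add(t0, t1);
//   t3 = mul(t2, w);   // fused into the ADD node, t2 never materializes separately
//
// the fused group reads t0, t1 and also w, so w must be recorded as a source;
// otherwise a reordered node that writes w would not be detected as a hazard
```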
@@ -290,7 +300,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
 
     std::vector<bool> used(n, false);
 
+    // the memory ranges for the set of currently concurrent nodes
     ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+
+    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
     ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
 
     for (int i0 = 0; i0 < n; i0++) {
@@ -329,7 +342,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
 
             const bool is_empty = node1.is_empty();
 
-            // to add a concurrent node, it has to be:
+            // to reorder a node and add it to the concurrent set, it has to be:
             // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
             // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
             if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
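Putting `mrs0` and `mrs1` together, the surrounding loop plausibly has this shape. A comment-form sketch reconstructed from the visible context, not the exact upstream code; `h_add` is a hypothetical stand-in for the add helper paired with `h_check`:

```cpp
// for (int i0 = 0; i0 < n; i0++) {
//     if (used[i0]) continue;
//
//     order.push_back(i0);             // node i0 seeds a new concurrent set
//     used[i0] = true;
//     h_add(mrs0, nodes[i0]);
//
//     ggml_mem_ranges_reset(mrs1);
//     for (int i1 = i0 + 1; i1 < n; i1++) {
//         if (used[i1]) continue;
//
//         const auto & node1 = nodes[i1];
//         const bool is_empty = node1.is_empty();
//
//         if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
//             order.push_back(i1);     // safe: pull node i1 forward into the set
//             used[i1] = true;
//             h_add(mrs0, node1);
//         } else {
//             h_add(mrs1, node1);      // skipped: later picks must also be concurrent with it
//         }
//     }
//
//     ggml_mem_ranges_reset(mrs0);     // implicit barrier before the next set
// }
```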
@@ -419,8 +432,8 @@ void ggml_metal_graph_optimize(ggml_cgraph * gf) {
         nodes.push_back(std::move(node));
     }
 
-    // reorder to improve concurrency
 #if 1
+    // reorder to improve concurrency
     const auto order = ggml_metal_graph_optimize_reorder(nodes);
 #else
     std::vector<int> order(nodes.size());