55
66#include < vector>
77
8+ // represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
9+ // the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
810struct ggml_mem_range {
911 uint64_t pb; // buffer id
1012
@@ -37,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
3739 mrs->ranges .clear ();
3840}
3941
// Records a single memory range in the set.
//
// The range is appended to mrs->ranges unchanged; no overlap check is performed here.
// Always returns true.
40- static bool ggml_mem_ranges_add (ggml_mem_ranges * mrs, ggml_mem_range mrp ) {
41- mrs->ranges .push_back (mrp );
42+ static bool ggml_mem_ranges_add (ggml_mem_ranges * mrs, ggml_mem_range mr ) {
43+ mrs->ranges .push_back (mr );
4244
4345 return true ;
4446}
@@ -49,32 +51,32 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
4951
5052 GGML_ASSERT (!tensor->view_src );
5153
52- ggml_mem_range mrp ;
54+ ggml_mem_range mr ;
5355
5456 if (tensor->buffer ) {
5557 // when the tensor is allocated, use the actual memory address range in the buffer
5658 //
57- // take the actual allocated size
59+ // take the actual allocated size with ggml_backend_buft_get_alloc_size()
5860 // this can be larger than the tensor size if the buffer type allocates extra memory
5961 // ref: https://github.com/ggml-org/llama.cpp/pull/15966
60- mrp = {
62+ mr = {
6163 /* .pb =*/ (uint64_t ) tensor->buffer ,
6264 /* .p0 =*/ (uint64_t ) tensor->data ,
6365 /* .p1 =*/ (uint64_t ) tensor->data + ggml_backend_buft_get_alloc_size (tensor->buffer ->buft , tensor),
6466 /* .pt =*/ pt,
6567 };
6668 } else {
67- // otherwise, the tensor ptr is used as an unique id of the memory ranges
69+ // otherwise, the pointer address is used as a unique id of the memory ranges
6870 // that the tensor will be using when it is allocated
69- mrp = {
71+ mr = {
7072 /* .pb =*/ (uint64_t ) tensor,
7173 /* .p0 =*/ 0 , //
7274 /* .p1 =*/ 1024 , // [0, 1024) is a dummy range, not used
7375 /* .pt =*/ pt,
7476 };
7577 };
7678
77- return mrp ;
79+ return mr ;
7880}
7981
8082static ggml_mem_range ggml_mem_range_from_tensor_src (const ggml_tensor * tensor) {
@@ -88,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
// Adds the memory range of `tensor`, as built by ggml_mem_range_from_tensor_src(), to the set `mrs`.
//
// `tensor` must be non-null (asserted).
// Returns the result of the range-based ggml_mem_ranges_add() overload.
8890static bool ggml_mem_ranges_add_src (ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
8991 GGML_ASSERT (tensor);
9092
91- ggml_mem_range mrp = ggml_mem_range_from_tensor_src (tensor);
93+ ggml_mem_range mr = ggml_mem_range_from_tensor_src (tensor);
9294
// the range is only logged when the debug level of the set is greater than 2
// NOTE(review): pb is declared as uint64_t but is printed with %lld (p0/p1 are presumably uint64_t as well)
//               consider casting the arguments or using PRIu64 - confirm
9395 if (mrs->debug > 2 ) {
94- GGML_LOG_DEBUG (" %s: add src range buf=%lld, [%lld, %lld)\n " , __func__, mrp .pb , mrp .p0 , mrp .p1 );
96+ GGML_LOG_DEBUG (" %s: add src range buf=%lld, [%lld, %lld)\n " , __func__, mr .pb , mr .p0 , mr .p1 );
9597 }
9698
97- return ggml_mem_ranges_add (mrs, mrp );
99+ return ggml_mem_ranges_add (mrs, mr );
98100}
99101
// Adds the memory range of `tensor`, as built by ggml_mem_range_from_tensor_dst(), to the set `mrs`.
//
// `tensor` must be non-null (asserted).
// Returns the result of the range-based ggml_mem_ranges_add() overload.
100102static bool ggml_mem_ranges_add_dst (ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
101103 GGML_ASSERT (tensor);
102104
103- ggml_mem_range mrp = ggml_mem_range_from_tensor_dst (tensor);
105+ ggml_mem_range mr = ggml_mem_range_from_tensor_dst (tensor);
104106
// the range is only logged when the debug level of the set is greater than 2
// NOTE(review): pb is declared as uint64_t but is printed with %lld (p0/p1 are presumably uint64_t as well)
//               consider casting the arguments or using PRIu64 - confirm
105107 if (mrs->debug > 2 ) {
106- GGML_LOG_DEBUG (" %s: add dst range buf=%lld, [%lld, %lld)\n " , __func__, mrp .pb , mrp .p0 , mrp .p1 );
108+ GGML_LOG_DEBUG (" %s: add dst range buf=%lld, [%lld, %lld)\n " , __func__, mr .pb , mr .p0 , mr .p1 );
107109 }
108110
109- return ggml_mem_ranges_add (mrs, mrp );
111+ return ggml_mem_ranges_add (mrs, mr );
110112}
111113
112114bool ggml_mem_ranges_add (ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
@@ -119,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
119121 return ggml_mem_ranges_add_dst (mrs, tensor);
120122}
121123
122- static bool ggml_mem_ranges_check (const ggml_mem_ranges * mrs, ggml_mem_range mrp ) {
124+ static bool ggml_mem_ranges_check (const ggml_mem_ranges * mrs, ggml_mem_range mr ) {
123125 for (size_t i = 0 ; i < mrs->ranges .size (); i++) {
124126 const auto & cmp = mrs->ranges [i];
125127
126- if (mrp.pb != cmp.pb ) {
128+ // two memory ranges cannot intersect if they are in different buffers
129+ if (mr.pb != cmp.pb ) {
127130 continue ;
128131 }
129132
130- if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
133+ // intersecting source ranges are allowed
134+ if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
131135 continue ;
132136 }
133137
134- if (mrp .p0 < cmp.p1 && mrp .p1 >= cmp.p0 ) {
138+ if (mr .p0 < cmp.p1 && mr .p1 >= cmp.p0 ) {
135139 if (mrs->debug > 2 ) {
136140 GGML_LOG_DEBUG (" %s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n " ,
137141 __func__,
138- mrp .pt == MEM_RANGE_TYPE_SRC ? " src" : " dst" ,
139- mrp .pb , mrp .p0 , mrp .p1 ,
142+ mr .pt == MEM_RANGE_TYPE_SRC ? " src" : " dst" ,
143+ mr .pb , mr .p0 , mr .p1 ,
140144 cmp.pt == MEM_RANGE_TYPE_SRC ? " src" : " dst" ,
141145 cmp.pb , cmp.p0 , cmp.p1 );
142146 }
@@ -151,19 +155,19 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
// Checks the memory range of `tensor`, as built by ggml_mem_range_from_tensor_src(), against the ranges in `mrs`.
//
// `tensor` must be non-null (asserted). The set is not modified (it is taken as const).
// Returns the result of ggml_mem_ranges_check() unchanged.
// NOTE(review): presumably the built range is tagged MEM_RANGE_TYPE_SRC, in which case ggml_mem_ranges_check()
//               skips the other source ranges and only non-source ranges in the same buffer can conflict - confirm
151155static bool ggml_mem_ranges_check_src (const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
152156 GGML_ASSERT (tensor);
153157
154- ggml_mem_range mrp = ggml_mem_range_from_tensor_src (tensor);
158+ ggml_mem_range mr = ggml_mem_range_from_tensor_src (tensor);
155159
156- const bool res = ggml_mem_ranges_check (mrs, mrp );
160+ const bool res = ggml_mem_ranges_check (mrs, mr );
157161
158162 return res;
159163}
160164
// Checks the memory range of `tensor`, as built by ggml_mem_range_from_tensor_dst(), against the ranges in `mrs`.
//
// `tensor` must be non-null (asserted). The set is not modified (it is taken as const).
// Returns the result of ggml_mem_ranges_check() unchanged.
// NOTE(review): presumably the built range is tagged as a destination range, so it is compared against every
//               range in the same buffer (the src-vs-src exemption in ggml_mem_ranges_check() would not apply) - confirm
161165static bool ggml_mem_ranges_check_dst (const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
162166 GGML_ASSERT (tensor);
163167
164- ggml_mem_range mrp = ggml_mem_range_from_tensor_dst (tensor);
168+ ggml_mem_range mr = ggml_mem_range_from_tensor_dst (tensor);
165169
166- const bool res = ggml_mem_ranges_check (mrs, mrp );
170+ const bool res = ggml_mem_ranges_check (mrs, mr );
167171
168172 return res;
169173}
@@ -227,6 +231,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
227231 }
228232 }
229233
234+ // keep track of the sources of the fused nodes as well
230235 for (const auto * fused : node.fused ) {
231236 for (int i = 0 ; i < GGML_MAX_SRC; i++) {
232237 if (fused->src [i]) {
@@ -295,7 +300,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
295300
296301 std::vector<bool > used (n, false );
297302
303+ // the memory ranges for the set of currently concurrent nodes
298304 ggml_mem_ranges * mrs0 = ggml_mem_ranges_init (0 );
305+
306+ // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
299307 ggml_mem_ranges * mrs1 = ggml_mem_ranges_init (0 );
300308
301309 for (int i0 = 0 ; i0 < n; i0++) {
@@ -424,8 +432,8 @@ void ggml_metal_graph_optimize(ggml_cgraph * gf) {
424432 nodes.push_back (std::move (node));
425433 }
426434
427- // reorder to improve concurrency
428435#if 1
436+ // reorder to improve concurrency
429437 const auto order = ggml_metal_graph_optimize_reorder (nodes);
430438#else
431439 std::vector<int> order(nodes.size());
0 commit comments