@@ -5810,6 +5810,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58105810 return ;
58115811 }
58125812
5813+ bool incr_use_count = false;
5814+
58135815 for (int i = 0 ; i < GGML_MAX_SRC ; ++ i ) {
58145816 const int k =
58155817 (cgraph -> order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT ) ? i :
@@ -5825,19 +5827,11 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58255827 if (!(s -> op == GGML_OP_NONE && !(s -> flags & GGML_TENSOR_FLAG_PARAM ))) {
58265828 // the src can be the node itself (happens in ggml_cast)
58275829 if (s == node ) {
5828- cgraph -> use_counts [ cgraph -> n_nodes ] ++ ;
5830+ incr_use_count = true ;
58295831 } else {
5830- // Search backward to find the src. This usually takes very few
5831- // (most often one) iteration(s). Probably comparable to hashing
5832- // on average..
5833- int j = cgraph -> n_nodes - 1 ;
5834- for (; j >= 0 ; -- j ) {
5835- if (s == cgraph -> nodes [j ]) {
5836- break ;
5837- }
5838- }
5839- GGML_ASSERT (j >= 0 );
5840- cgraph -> use_counts [j ]++ ;
5832+ size_t s_idx = cgraph -> hash_to_node [ggml_hash_find (& cgraph -> visited_hash_set , s )];
5833+ GGML_ASSERT (cgraph -> nodes [s_idx ] == s );
5834+ cgraph -> use_counts [s_idx ]++ ;
58415835 }
58425836 }
58435837 }
@@ -5861,6 +5855,10 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58615855 }
58625856
58635857 cgraph -> nodes [cgraph -> n_nodes ] = node ;
5858+ if (incr_use_count ) {
5859+ cgraph -> use_counts [cgraph -> n_nodes ]++ ;
5860+ }
5861+ cgraph -> hash_to_node [ggml_hash_find (& cgraph -> visited_hash_set , node )] = cgraph -> n_nodes ;
58645862 cgraph -> n_nodes ++ ;
58655863 }
58665864}
@@ -5997,10 +5995,11 @@ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
59975995static size_t ggml_graph_nbytes (size_t size , bool grads ) {
59985996 size_t hash_size = ggml_hash_size (size * 2 );
59995997 void * p = 0 ;
6000- incr_ptr_aligned (& p , sizeof (struct ggml_cgraph ), 1 );
6001- incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // nodes
6002- incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // leafs
6003- incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t )); // use_counts
5998+ incr_ptr_aligned (& p , sizeof (struct ggml_cgraph ), 1 );
5999+ incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // nodes
6000+ incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // leafs
6001+ incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t )); // use_counts
6002+ incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t )); // hash_to_node
60046003 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // hash keys
60056004 if (grads ) {
60066005 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // grads
@@ -6033,6 +6032,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60336032 struct ggml_tensor * * nodes_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60346033 struct ggml_tensor * * leafs_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60356034 int32_t * use_counts_ptr = incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t ));
6035+ int32_t * hash_to_node_ptr = incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t ));
60366036 struct ggml_tensor * * hash_keys_ptr = incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60376037 struct ggml_tensor * * grads_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
60386038 struct ggml_tensor * * grad_accs_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
@@ -6051,6 +6051,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60516051 /*.grad_accs =*/ grad_accs_ptr ,
60526052 /*.leafs =*/ leafs_ptr ,
60536053 /*.use_counts =*/ use_counts_ptr ,
6054+ /*.hash_to_node =*/ hash_to_node_ptr ,
60546055 /*.hash_table =*/ { hash_size , hash_used , hash_keys_ptr },
60556056 /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT ,
60566057 };
@@ -6061,6 +6062,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60616062 memset (cgraph -> grad_accs , 0 , hash_size * sizeof (struct ggml_tensor * ));
60626063 }
60636064 memset (cgraph -> use_counts , 0 , size * sizeof (int32_t ));
6065+ memset (cgraph -> hash_to_node , -1 , size * sizeof (int32_t ));
60646066
60656067 return cgraph ;
60666068}
@@ -6079,6 +6081,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60796081 /*.grad_accs =*/ NULL ,
60806082 /*.leafs =*/ NULL ,
60816083 /*.use_counts =*/ cgraph0 -> use_counts + i0 ,
6084+ /*.hash_to_node =*/ NULL ,
60826085 /*.visited_hash_set =*/ { 0 , NULL , NULL },
60836086 /*.order =*/ cgraph0 -> order ,
60846087 };
@@ -6101,12 +6104,14 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
61016104
61026105 for (int i = 0 ; i < src -> n_nodes ; ++ i ) {
61036106 dst -> nodes [i ] = src -> nodes [i ];
6107+ dst -> use_counts [i ] = src -> use_counts [i ];
61046108 }
61056109
61066110 for (size_t i = 0 ; i < src -> visited_hash_set .size ; ++ i ) {
61076111 // copy all hashset keys (tensors) that are in use
61086112 if (ggml_bitset_get (src -> visited_hash_set .used , i )) {
61096113 ggml_hash_insert (& dst -> visited_hash_set , src -> visited_hash_set .keys [i ]);
6114+ dst -> hash_to_node [i ] = src -> hash_to_node [i ];
61106115 }
61116116 }
61126117
0 commit comments