@@ -5809,8 +5809,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58095809 if (ggml_hash_insert (& cgraph -> visited_hash_set , node ) == GGML_HASHSET_ALREADY_EXISTS ) {
58105810 return ;
58115811 }
5812-
5813- bool incr_use_count = false;
5812+ cgraph -> use_counts [ggml_hash_find (& cgraph -> visited_hash_set , node )] = 0 ;
58145813
58155814 for (int i = 0 ; i < GGML_MAX_SRC ; ++ i ) {
58165815 const int k =
@@ -5825,14 +5824,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58255824 // Update the use count for this operand.
58265825 // Skip if it's a leaf node
58275826 if (!(s -> op == GGML_OP_NONE && !(s -> flags & GGML_TENSOR_FLAG_PARAM ))) {
5828- // the src can be the node itself (happens in ggml_cast)
5829- if (s == node ) {
5830- incr_use_count = true;
5831- } else {
5832- size_t s_idx = cgraph -> hash_to_node [ggml_hash_find (& cgraph -> visited_hash_set , s )];
5833- GGML_ASSERT (cgraph -> nodes [s_idx ] == s );
5834- cgraph -> use_counts [s_idx ]++ ;
5835- }
5827+ cgraph -> use_counts [ggml_hash_find (& cgraph -> visited_hash_set , s )]++ ;
58365828 }
58375829 }
58385830 }
@@ -5855,10 +5847,6 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58555847 }
58565848
58575849 cgraph -> nodes [cgraph -> n_nodes ] = node ;
5858- if (incr_use_count ) {
5859- cgraph -> use_counts [cgraph -> n_nodes ]++ ;
5860- }
5861- cgraph -> hash_to_node [ggml_hash_find (& cgraph -> visited_hash_set , node )] = cgraph -> n_nodes ;
58625850 cgraph -> n_nodes ++ ;
58635851 }
58645852}
@@ -5995,11 +5983,10 @@ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
59955983static size_t ggml_graph_nbytes (size_t size , bool grads ) {
59965984 size_t hash_size = ggml_hash_size (size * 2 );
59975985 void * p = 0 ;
5998- incr_ptr_aligned (& p , sizeof (struct ggml_cgraph ), 1 );
5999- incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // nodes
6000- incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // leafs
6001- incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t )); // use_counts
6002- incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t )); // hash_to_node
5986+ incr_ptr_aligned (& p , sizeof (struct ggml_cgraph ), 1 );
5987+ incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // nodes
5988+ incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // leafs
5989+ incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t )); // use_counts
60035990 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // hash keys
60045991 if (grads ) {
60055992 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // grads
@@ -6031,8 +6018,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60316018
60326019 struct ggml_tensor * * nodes_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60336020 struct ggml_tensor * * leafs_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
6034- int32_t * use_counts_ptr = incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t ));
6035- int32_t * hash_to_node_ptr = incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t ));
6021+ int32_t * use_counts_ptr = incr_ptr_aligned (& p , hash_size * sizeof (int32_t ), sizeof (int32_t ));
60366022 struct ggml_tensor * * hash_keys_ptr = incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60376023 struct ggml_tensor * * grads_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
60386024 struct ggml_tensor * * grad_accs_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
@@ -6051,7 +6037,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60516037 /*.grad_accs =*/ grad_accs_ptr ,
60526038 /*.leafs =*/ leafs_ptr ,
60536039 /*.use_counts =*/ use_counts_ptr ,
6054- /*.hash_to_node =*/ hash_to_node_ptr ,
60556040 /*.hash_table =*/ { hash_size , hash_used , hash_keys_ptr },
60566041 /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT ,
60576042 };
@@ -6061,8 +6046,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60616046 memset (cgraph -> grads , 0 , hash_size * sizeof (struct ggml_tensor * ));
60626047 memset (cgraph -> grad_accs , 0 , hash_size * sizeof (struct ggml_tensor * ));
60636048 }
6064- memset (cgraph -> use_counts , 0 , size * sizeof (int32_t ));
6065- memset (cgraph -> hash_to_node , -1 , size * sizeof (int32_t ));
60666049
60676050 return cgraph ;
60686051}
@@ -6080,9 +6063,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60806063 /*.grads =*/ NULL , // gradients would need visited_hash_set
60816064 /*.grad_accs =*/ NULL ,
60826065 /*.leafs =*/ NULL ,
6083- /*.use_counts =*/ cgraph0 -> use_counts + i0 ,
6084- /*.hash_to_node =*/ NULL ,
6085- /*.visited_hash_set =*/ { 0 , NULL , NULL },
6066+ /*.use_counts =*/ cgraph0 -> use_counts ,
6067+ /*.visited_hash_set =*/ cgraph0 -> visited_hash_set ,
60866068 /*.order =*/ cgraph0 -> order ,
60876069 };
60886070
@@ -6104,14 +6086,13 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
61046086
61056087 for (int i = 0 ; i < src -> n_nodes ; ++ i ) {
61066088 dst -> nodes [i ] = src -> nodes [i ];
6107- dst -> use_counts [i ] = src -> use_counts [i ];
61086089 }
61096090
61106091 for (size_t i = 0 ; i < src -> visited_hash_set .size ; ++ i ) {
61116092 // copy all hashset keys (tensors) that are in use
61126093 if (ggml_bitset_get (src -> visited_hash_set .used , i )) {
61136094 ggml_hash_insert (& dst -> visited_hash_set , src -> visited_hash_set .keys [i ]);
6114- dst -> hash_to_node [i ] = src -> hash_to_node [i ];
6095+ dst -> use_counts [i ] = src -> use_counts [i ];
61156096 }
61166097 }
61176098
0 commit comments