@@ -1608,7 +1608,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
16081608 /*.data =*/ obj_alloc_size > 0 ? (void * )(result + 1 ) : data ,
16091609 /*.name =*/ { 0 },
16101610 /*.extra =*/ NULL ,
1611- /*.use_count =*/ 0 ,
16121611 /*.padding =*/ { 0 },
16131612 };
16141613
@@ -5816,9 +5815,31 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58165815 (cgraph -> order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT ) ? i :
58175816 (cgraph -> order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT ) ? (GGML_MAX_SRC - 1 - i ) :
58185817 /* unknown order, just fall back to using i*/ i ;
5819- if (node -> src [k ]) {
5820- ggml_visit_parents (cgraph , node -> src [k ]);
5821- node -> src [k ]-> use_count ++ ;
5818+
5819+ struct ggml_tensor * s = node -> src [k ];
5820+ if (s ) {
5821+ ggml_visit_parents (cgraph , s );
5822+
5823+ // Update the use count for this operand.
5824+ // Skip if it's a leaf node
5825+ if (!(s -> op == GGML_OP_NONE && !(s -> flags & GGML_TENSOR_FLAG_PARAM ))) {
5826+ // the src can be the node itself (happens in ggml_cast)
5827+ if (s == node ) {
5828+ cgraph -> use_counts [cgraph -> n_nodes ]++ ;
5829+ } else {
5830+ // Search backward to find the src. This usually takes very few
5831+ // (most often one) iteration(s). Probably comparable to hashing
5832+ // on average..
5833+ int j = cgraph -> n_nodes - 1 ;
5834+ for (; j >= 0 ; -- j ) {
5835+ if (s == cgraph -> nodes [j ]) {
5836+ break ;
5837+ }
5838+ }
5839+ GGML_ASSERT (j >= 0 );
5840+ cgraph -> use_counts [j ]++ ;
5841+ }
5842+ }
58225843 }
58235844 }
58245845
@@ -5979,6 +6000,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
59796000 incr_ptr_aligned (& p , sizeof (struct ggml_cgraph ), 1 );
59806001 incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // nodes
59816002 incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // leafs
6003+ incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t )); // use_counts
59826004 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // hash keys
59836005 if (grads ) {
59846006 incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )); // grads
@@ -6010,6 +6032,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60106032
60116033 struct ggml_tensor * * nodes_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60126034 struct ggml_tensor * * leafs_ptr = incr_ptr_aligned (& p , size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
6035+ int32_t * use_counts_ptr = incr_ptr_aligned (& p , size * sizeof (int32_t ), sizeof (int32_t ));
60136036 struct ggml_tensor * * hash_keys_ptr = incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * ));
60146037 struct ggml_tensor * * grads_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
60156038 struct ggml_tensor * * grad_accs_ptr = grads ? incr_ptr_aligned (& p , hash_size * sizeof (struct ggml_tensor * ), sizeof (struct ggml_tensor * )) : NULL ;
@@ -6027,6 +6050,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60276050 /*.grads =*/ grads_ptr ,
60286051 /*.grad_accs =*/ grad_accs_ptr ,
60296052 /*.leafs =*/ leafs_ptr ,
6053+ /*.use_counts =*/ use_counts_ptr ,
60306054 /*.hash_table =*/ { hash_size , hash_used , hash_keys_ptr },
60316055 /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT ,
60326056 };
@@ -6036,6 +6060,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60366060 memset (cgraph -> grads , 0 , hash_size * sizeof (struct ggml_tensor * ));
60376061 memset (cgraph -> grad_accs , 0 , hash_size * sizeof (struct ggml_tensor * ));
60386062 }
6063+ memset (cgraph -> use_counts , 0 , size * sizeof (int32_t ));
60396064
60406065 return cgraph ;
60416066}
@@ -6053,6 +6078,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60536078 /*.grads =*/ NULL , // gradients would need visited_hash_set
60546079 /*.grad_accs =*/ NULL ,
60556080 /*.leafs =*/ NULL ,
6081+ /*.use_counts =*/ cgraph0 -> use_counts + i0 ,
60566082 /*.visited_hash_set =*/ { 0 , NULL , NULL },
60576083 /*.order =*/ cgraph0 -> order ,
60586084 };
0 commit comments