Skip to content

Commit 9ddd425

Browse files
committed
use hash table lookup to find node index
1 parent 8c50a9b commit 9ddd425

File tree

2 files changed

+22 -16 lines changed

2 files changed

+22 -16 lines changed

ggml/src/ggml-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ struct ggml_cgraph {
302302
struct ggml_tensor ** grad_accs; // accumulators for node gradients
303303
struct ggml_tensor ** leafs; // tensors with constant data
304304
int32_t * use_counts;// number of uses of each tensor
305+
int32_t * hash_to_node;// map hash index to node index
305306

306307
struct ggml_hash_set visited_hash_set;
307308

ggml/src/ggml.c

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5810,6 +5810,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58105810
return;
58115811
}
58125812

5813+
bool incr_use_count = false;
5814+
58135815
for (int i = 0; i < GGML_MAX_SRC; ++i) {
58145816
const int k =
58155817
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -5825,19 +5827,11 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58255827
if (!(s->op == GGML_OP_NONE && !(s->flags & GGML_TENSOR_FLAG_PARAM))) {
58265828
// the src can be the node itself (happens in ggml_cast)
58275829
if (s == node) {
5828-
cgraph->use_counts[cgraph->n_nodes]++;
5830+
incr_use_count = true;
58295831
} else {
5830-
// Search backward to find the src. This usually takes very few
5831-
// (most often one) iteration(s). Probably comparable to hashing
5832-
// on average..
5833-
int j = cgraph->n_nodes - 1;
5834-
for (; j >= 0; --j) {
5835-
if (s == cgraph->nodes[j]) {
5836-
break;
5837-
}
5838-
}
5839-
GGML_ASSERT(j >= 0);
5840-
cgraph->use_counts[j]++;
5832+
size_t s_idx = cgraph->hash_to_node[ggml_hash_find(&cgraph->visited_hash_set, s)];
5833+
GGML_ASSERT(cgraph->nodes[s_idx] == s);
5834+
cgraph->use_counts[s_idx]++;
58415835
}
58425836
}
58435837
}
@@ -5861,6 +5855,10 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58615855
}
58625856

58635857
cgraph->nodes[cgraph->n_nodes] = node;
5858+
if (incr_use_count) {
5859+
cgraph->use_counts[cgraph->n_nodes]++;
5860+
}
5861+
cgraph->hash_to_node[ggml_hash_find(&cgraph->visited_hash_set, node)] = cgraph->n_nodes;
58645862
cgraph->n_nodes++;
58655863
}
58665864
}
@@ -5997,10 +5995,11 @@ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
59975995
static size_t ggml_graph_nbytes(size_t size, bool grads) {
59985996
size_t hash_size = ggml_hash_size(size * 2);
59995997
void * p = 0;
6000-
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
6001-
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
6002-
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6003-
incr_ptr_aligned(&p, size * sizeof(int32_t), sizeof(int32_t)); // use_counts
5998+
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
5999+
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
6000+
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6001+
incr_ptr_aligned(&p, size * sizeof(int32_t), sizeof(int32_t)); // use_counts
6002+
incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // hash_to_node
60046003
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
60056004
if (grads) {
60066005
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -6033,6 +6032,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60336032
struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
60346033
struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
60356034
int32_t * use_counts_ptr= incr_ptr_aligned(&p, size * sizeof(int32_t), sizeof(int32_t));
6035+
int32_t * hash_to_node_ptr= incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
60366036
struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
60376037
struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
60386038
struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
@@ -6051,6 +6051,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60516051
/*.grad_accs =*/ grad_accs_ptr,
60526052
/*.leafs =*/ leafs_ptr,
60536053
/*.use_counts =*/ use_counts_ptr,
6054+
/*.hash_to_node =*/ hash_to_node_ptr,
60546055
/*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
60556056
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
60566057
};
@@ -6061,6 +6062,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60616062
memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
60626063
}
60636064
memset(cgraph->use_counts, 0, size*sizeof(int32_t));
6065+
memset(cgraph->hash_to_node, -1, size*sizeof(int32_t));
60646066

60656067
return cgraph;
60666068
}
@@ -6079,6 +6081,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60796081
/*.grad_accs =*/ NULL,
60806082
/*.leafs =*/ NULL,
60816083
/*.use_counts =*/ cgraph0->use_counts + i0,
6084+
/*.hash_to_node =*/ NULL,
60826085
/*.visited_hash_set =*/ { 0, NULL, NULL },
60836086
/*.order =*/ cgraph0->order,
60846087
};
@@ -6101,12 +6104,14 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
61016104

61026105
for (int i = 0; i < src->n_nodes; ++i) {
61036106
dst->nodes[i] = src->nodes[i];
6107+
dst->use_counts[i] = src->use_counts[i];
61046108
}
61056109

61066110
for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
61076111
// copy all hashset keys (tensors) that are in use
61086112
if (ggml_bitset_get(src->visited_hash_set.used, i)) {
61096113
ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6114+
dst->hash_to_node[i] = src->hash_to_node[i];
61106115
}
61116116
}
61126117

0 commit comments

Comments
 (0)