Skip to content

Commit 1630435

Browse files
committed
change use_counts to be indexed by hash table slot
1 parent 9ddd425 commit 1630435

File tree

3 files changed

+14
-33
lines changed

3 files changed

+14
-33
lines changed

ggml/src/ggml-backend.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,8 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
818818
if (sched->debug > 1) {
819819
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
820820
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
821-
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node), graph->use_counts[i]);
821+
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
822+
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
822823
for (int j = 0; j < GGML_MAX_SRC; j++) {
823824
struct ggml_tensor * src = node->src[j];
824825
if (src == NULL) {

ggml/src/ggml-impl.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,7 @@ struct ggml_cgraph {
301301
struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
302302
struct ggml_tensor ** grad_accs; // accumulators for node gradients
303303
struct ggml_tensor ** leafs; // tensors with constant data
304-
int32_t * use_counts;// number of uses of each tensor
305-
int32_t * hash_to_node;// map hash index to node index
304+
int32_t * use_counts;// number of uses of each tensor, indexed by hash table slot
306305

307306
struct ggml_hash_set visited_hash_set;
308307

@@ -474,7 +473,7 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
474473
static inline bool ggml_node_has_N_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t N) {
475474
const struct ggml_tensor * node = cgraph->nodes[node_idx];
476475
// check the use count against how many we're replacing
477-
if (cgraph->use_counts[node_idx] != N) {
476+
if (cgraph->use_counts[ggml_hash_find(&cgraph->visited_hash_set, node)] != N) {
478477
return false;
479478
}
480479

ggml/src/ggml.c

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -5809,8 +5809,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58095809
if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
58105810
return;
58115811
}
5812-
5813-
bool incr_use_count = false;
5812+
cgraph->use_counts[ggml_hash_find(&cgraph->visited_hash_set, node)] = 0;
58145813

58155814
for (int i = 0; i < GGML_MAX_SRC; ++i) {
58165815
const int k =
@@ -5825,14 +5824,7 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58255824
// Update the use count for this operand.
58265825
// Skip if it's a leaf node
58275826
if (!(s->op == GGML_OP_NONE && !(s->flags & GGML_TENSOR_FLAG_PARAM))) {
5828-
// the src can be the node itself (happens in ggml_cast)
5829-
if (s == node) {
5830-
incr_use_count = true;
5831-
} else {
5832-
size_t s_idx = cgraph->hash_to_node[ggml_hash_find(&cgraph->visited_hash_set, s)];
5833-
GGML_ASSERT(cgraph->nodes[s_idx] == s);
5834-
cgraph->use_counts[s_idx]++;
5835-
}
5827+
cgraph->use_counts[ggml_hash_find(&cgraph->visited_hash_set, s)]++;
58365828
}
58375829
}
58385830
}
@@ -5855,10 +5847,6 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
58555847
}
58565848

58575849
cgraph->nodes[cgraph->n_nodes] = node;
5858-
if (incr_use_count) {
5859-
cgraph->use_counts[cgraph->n_nodes]++;
5860-
}
5861-
cgraph->hash_to_node[ggml_hash_find(&cgraph->visited_hash_set, node)] = cgraph->n_nodes;
58625850
cgraph->n_nodes++;
58635851
}
58645852
}
@@ -5995,11 +5983,10 @@ static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
59955983
static size_t ggml_graph_nbytes(size_t size, bool grads) {
59965984
size_t hash_size = ggml_hash_size(size * 2);
59975985
void * p = 0;
5998-
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
5999-
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
6000-
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
6001-
incr_ptr_aligned(&p, size * sizeof(int32_t), sizeof(int32_t)); // use_counts
6002-
incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // hash_to_node
5986+
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
5987+
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
5988+
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
5989+
incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
60035990
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
60045991
if (grads) {
60055992
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -6031,8 +6018,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60316018

60326019
struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
60336020
struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
6034-
int32_t * use_counts_ptr= incr_ptr_aligned(&p, size * sizeof(int32_t), sizeof(int32_t));
6035-
int32_t * hash_to_node_ptr= incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
6021+
int32_t * use_counts_ptr= incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
60366022
struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
60376023
struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
60386024
struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
@@ -6051,7 +6037,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60516037
/*.grad_accs =*/ grad_accs_ptr,
60526038
/*.leafs =*/ leafs_ptr,
60536039
/*.use_counts =*/ use_counts_ptr,
6054-
/*.hash_to_node =*/ hash_to_node_ptr,
60556040
/*.visited_hash_set =*/ { hash_size, hash_used, hash_keys_ptr },
60566041
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
60576042
};
@@ -6061,8 +6046,6 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
60616046
memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *));
60626047
memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *));
60636048
}
6064-
memset(cgraph->use_counts, 0, size*sizeof(int32_t));
6065-
memset(cgraph->hash_to_node, -1, size*sizeof(int32_t));
60666049

60676050
return cgraph;
60686051
}
@@ -6080,9 +6063,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
60806063
/*.grads =*/ NULL, // gradients would need visited_hash_set
60816064
/*.grad_accs =*/ NULL,
60826065
/*.leafs =*/ NULL,
6083-
/*.use_counts =*/ cgraph0->use_counts + i0,
6084-
/*.hash_to_node =*/ NULL,
6085-
/*.visited_hash_set =*/ { 0, NULL, NULL },
6066+
/*.use_counts =*/ cgraph0->use_counts,
6067+
/*.visited_hash_set =*/ cgraph0->visited_hash_set,
60866068
/*.order =*/ cgraph0->order,
60876069
};
60886070

@@ -6104,14 +6086,13 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
61046086

61056087
for (int i = 0; i < src->n_nodes; ++i) {
61066088
dst->nodes[i] = src->nodes[i];
6107-
dst->use_counts[i] = src->use_counts[i];
61086089
}
61096090

61106091
for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
61116092
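// copy all hashset keys (tensors) that are in use
// NOTE(review): i is the key's slot in src; ggml_hash_insert may place the key at a
// different slot in dst if the two hash sets differ in size, so the use count should
// presumably be stored at the slot returned by the insert rather than at i — TODO confirm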
61126093
if (ggml_bitset_get(src->visited_hash_set.used, i)) {
61136094
ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
6114-
dst->hash_to_node[i] = src->hash_to_node[i];
6095+
dst->use_counts[i] = src->use_counts[i];
61156096
}
61166097
}
61176098

0 commit comments

Comments
 (0)