Skip to content

Commit 00d7d43

Browse files
bssrdf authored and ggerganov committed
add tensor type checking as part of cuda graph properties (llama/19186)
1 parent ccee88e commit 00d7d43

File tree

2 files changed: 6 additions and 0 deletions

src/ggml-cuda/common.cuh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,6 +1124,7 @@ struct ggml_tensor_extra_gpu {
11241124
struct ggml_cuda_graph_node_properties {
11251125
void * node_data;
11261126
ggml_op node_op;
1127+
enum ggml_type node_type;
11271128
int32_t flags;
11281129
int64_t ne[GGML_MAX_DIMS];
11291130
size_t nb[GGML_MAX_DIMS];

src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2920,6 +2920,7 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties
29202920
memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
29212921
props->node_data = node->data;
29222922
props->node_op = node->op;
2923+
props->node_type = node->type;
29232924
props->flags = node->flags;
29242925
for (int i = 0; i < GGML_MAX_DIMS; i++) {
29252926
props->ne[i] = node->ne[i];
@@ -2944,6 +2945,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
29442945
return false;
29452946
}
29462947

2948+
if (node->type != props->node_type) {
2949+
return false;
2950+
}
2951+
29472952
for (int i = 0; i < GGML_MAX_DIMS; i++) {
29482953
if (node->ne[i] != props->ne[i]) {
29492954
return false;

0 commit comments

Comments
 (0)