Skip to content

Commit f0fcfd4

Browse files
fix logic for RoPE support, CUDA graphs
1 parent 8778dd2 commit f0fcfd4

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2488,7 +2488,7 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
2488 2488
#endif
2489 2489
}
2490 2490

2491-
if (node->op == GGML_OP_MUL_MAT_ID) {
2491+
if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) {
2492 2492
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
2493 2493
#ifndef NDEBUG
2494 2494
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
@@ -3202,9 +3202,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
3202 3202
}
3203 3203
case GGML_OP_ROPE:
3204 3204
case GGML_OP_ROPE_BACK: {
3205-
const size_t ts = ggml_type_size(op->src[0]->type);
3206-
const int64_t ne0_012 = op->src[0]->ne[0] * op->src[0]->ne[1] * op->src[0]->ne[2];
3207-
return op->src[0]->nb[0] == ts && op->src[0]->nb[3] == ne0_012*ts;
3205+
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
3208 3206
}
3209 3207
case GGML_OP_IM2COL:
3210 3208
case GGML_OP_POOL_2D:

0 commit comments

Comments
 (0)