Skip to content

Commit 86182fb

Browse files
committed
fix(rpc): validate graph operands
The RPC server could crash if a client sent a graph with operations missing required source tensors (e.g., ADD with NULL src[1]). This adds a validation step after graph construction but before backend execution to check for required non-null src operands based on the ggml_op. Signed-off-by: Ville Vesilehto <[email protected]>
1 parent 5f5e39e commit 86182fb

File tree

1 file changed

+149
-0
lines changed

1 file changed

+149
-0
lines changed

ggml/src/ggml-rpc/ggml-rpc.cpp

Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,150 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
746746
memcpy(out_tensors, tensors.data(), n_tensors * sizeof(rpc_tensor));
747747
}
748748

749+
// Helper function to validate graph operands before computation
750+
static bool validate_graph_operands(const ggml_cgraph *graph) {
751+
GGML_PRINT_DEBUG("[%s] Validating graph with %d nodes\n", __func__, graph->n_nodes);
752+
for (uint32_t i = 0; i < (uint32_t)graph->n_nodes; ++i) {
753+
const ggml_tensor* node = graph->nodes[i];
754+
// Initial null check added for safety.
755+
if (node == nullptr) {
756+
GGML_LOG_ERROR("[%s] Graph node %d is null.\n", __func__, i);
757+
return false;
758+
}
759+
760+
// Lambda to check for required source operands up to max_src_idx
761+
auto check_src = [&](int max_src_idx) -> bool {
762+
for (int s_idx = 0; s_idx <= max_src_idx; ++s_idx) {
763+
if (node->src[s_idx] == nullptr) {
764+
GGML_LOG_ERROR("[%s] Graph node %d (op %s, name '%s') missing required input src[%d].\n", __func__, i, ggml_op_name(node->op), node->name, s_idx);
765+
return false;
766+
}
767+
}
768+
return true;
769+
};
770+
771+
// Switch based on operation to determine required src operands
772+
switch (node->op) {
773+
// Ops requiring src[0]
774+
case GGML_OP_DUP:
775+
case GGML_OP_SQR:
776+
case GGML_OP_SQRT:
777+
case GGML_OP_LOG:
778+
case GGML_OP_SIN:
779+
case GGML_OP_COS:
780+
case GGML_OP_SUM:
781+
case GGML_OP_SUM_ROWS:
782+
case GGML_OP_MEAN:
783+
case GGML_OP_ARGMAX:
784+
case GGML_OP_NORM:
785+
case GGML_OP_RMS_NORM:
786+
case GGML_OP_GROUP_NORM:
787+
case GGML_OP_L2_NORM:
788+
case GGML_OP_CONT:
789+
case GGML_OP_RESHAPE:
790+
case GGML_OP_VIEW:
791+
case GGML_OP_PERMUTE:
792+
case GGML_OP_TRANSPOSE:
793+
case GGML_OP_DIAG:
794+
case GGML_OP_DIAG_MASK_INF:
795+
case GGML_OP_DIAG_MASK_ZERO:
796+
case GGML_OP_SOFT_MAX:
797+
case GGML_OP_CLAMP:
798+
case GGML_OP_POOL_1D:
799+
case GGML_OP_POOL_2D:
800+
case GGML_OP_UPSCALE:
801+
case GGML_OP_PAD:
802+
case GGML_OP_PAD_REFLECT_1D:
803+
case GGML_OP_TIMESTEP_EMBEDDING:
804+
case GGML_OP_ARGSORT:
805+
case GGML_OP_LEAKY_RELU:
806+
case GGML_OP_WIN_PART:
807+
case GGML_OP_WIN_UNPART:
808+
case GGML_OP_GET_REL_POS:
809+
case GGML_OP_MAP_CUSTOM1:
810+
case GGML_OP_UNARY:
811+
if (!check_src(0)) { return false; }
812+
break;
813+
814+
// Ops requiring src[0], src[1]
815+
case GGML_OP_ADD:
816+
case GGML_OP_ADD1:
817+
case GGML_OP_SUB:
818+
case GGML_OP_MUL:
819+
case GGML_OP_DIV:
820+
case GGML_OP_ACC:
821+
case GGML_OP_SCALE:
822+
case GGML_OP_RMS_NORM_BACK:
823+
case GGML_OP_CROSS_ENTROPY_LOSS:
824+
case GGML_OP_COUNT_EQUAL:
825+
case GGML_OP_REPEAT:
826+
case GGML_OP_REPEAT_BACK:
827+
case GGML_OP_CONCAT:
828+
case GGML_OP_SILU_BACK:
829+
case GGML_OP_MUL_MAT:
830+
case GGML_OP_OUT_PROD:
831+
case GGML_OP_SET:
832+
case GGML_OP_CPY:
833+
case GGML_OP_GET_ROWS:
834+
case GGML_OP_SOFT_MAX_BACK:
835+
case GGML_OP_ROPE:
836+
case GGML_OP_ROPE_BACK:
837+
case GGML_OP_CONV_TRANSPOSE_1D:
838+
case GGML_OP_IM2COL:
839+
case GGML_OP_IM2COL_BACK:
840+
case GGML_OP_CONV_TRANSPOSE_2D:
841+
case GGML_OP_POOL_2D_BACK:
842+
case GGML_OP_MAP_CUSTOM2:
843+
if (!check_src(1)) { return false; }
844+
break;
845+
846+
// Ops requiring src[0], src[1], src[2]
847+
case GGML_OP_MUL_MAT_ID:
848+
case GGML_OP_GET_ROWS_BACK:
849+
case GGML_OP_ADD_REL_POS:
850+
case GGML_OP_MAP_CUSTOM3:
851+
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
852+
case GGML_OP_SSM_CONV:
853+
if (!check_src(2)) { return false; }
854+
break;
855+
856+
// Ops requiring src[0], src[1], src[2], src[3]
857+
case GGML_OP_FLASH_ATTN_EXT:
858+
if (!check_src(2)) { return false; }
859+
break;
860+
861+
// Ops requiring src[0], src[1], src[2], src[3], src[4]
862+
case GGML_OP_FLASH_ATTN_BACK:
863+
case GGML_OP_OPT_STEP_ADAMW:
864+
if (!check_src(4)) { return false; }
865+
break;
866+
867+
// Ops requiring src[0], src[1], src[2], src[3], src[4], src[5]
868+
case GGML_OP_SSM_SCAN:
869+
case GGML_OP_RWKV_WKV6:
870+
case GGML_OP_GATED_LINEAR_ATTN:
871+
if (!check_src(5)) { return false; }
872+
break;
873+
874+
// Ops requiring src[0], src[1], src[2], src[3], src[4], src[5], src[6]
875+
case GGML_OP_RWKV_WKV7:
876+
if (!check_src(6)) { return false; }
877+
break;
878+
879+
// Ops with no required src[] inputs or handled by default
880+
case GGML_OP_NONE:
881+
case GGML_OP_ARANGE:
882+
case GGML_OP_CUSTOM:
883+
case GGML_OP_COUNT:
884+
default:
885+
// Assume valid or cannot validate generically
886+
break;
887+
}
888+
}
889+
GGML_PRINT_DEBUG("[%s] Graph validation successful\n", __func__);
890+
return true;
891+
}
892+
749893
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
750894
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
751895
std::vector<uint8_t> input;
@@ -1354,6 +1498,11 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13541498
return false;
13551499
}
13561500
}
1501+
1502+
if (!validate_graph_operands(graph)) {
1503+
return false;
1504+
}
1505+
13571506
ggml_status status = ggml_backend_graph_compute(backend, graph);
13581507
response.result = status;
13591508
return true;

0 commit comments

Comments
 (0)