@@ -746,6 +746,150 @@ static void serialize_graph(const ggml_cgraph * cgraph, std::vector<uint8_t> & o
746746 memcpy (out_tensors, tensors.data (), n_tensors * sizeof (rpc_tensor));
747747}
748748
749+ // Helper function to validate graph operands before computation
750+ static bool validate_graph_operands (const ggml_cgraph *graph) {
751+ GGML_PRINT_DEBUG (" [%s] Validating graph with %d nodes\n " , __func__, graph->n_nodes );
752+ for (uint32_t i = 0 ; i < (uint32_t )graph->n_nodes ; ++i) {
753+ const ggml_tensor* node = graph->nodes [i];
754+ // Initial null check added for safety.
755+ if (node == nullptr ) {
756+ GGML_LOG_ERROR (" [%s] Graph node %d is null.\n " , __func__, i);
757+ return false ;
758+ }
759+
760+ // Lambda to check for required source operands up to max_src_idx
761+ auto check_src = [&](int max_src_idx) -> bool {
762+ for (int s_idx = 0 ; s_idx <= max_src_idx; ++s_idx) {
763+ if (node->src [s_idx] == nullptr ) {
764+ GGML_LOG_ERROR (" [%s] Graph node %d (op %s, name '%s') missing required input src[%d].\n " , __func__, i, ggml_op_name (node->op ), node->name , s_idx);
765+ return false ;
766+ }
767+ }
768+ return true ;
769+ };
770+
771+ // Switch based on operation to determine required src operands
772+ switch (node->op ) {
773+ // Ops requiring src[0]
774+ case GGML_OP_DUP:
775+ case GGML_OP_SQR:
776+ case GGML_OP_SQRT:
777+ case GGML_OP_LOG:
778+ case GGML_OP_SIN:
779+ case GGML_OP_COS:
780+ case GGML_OP_SUM:
781+ case GGML_OP_SUM_ROWS:
782+ case GGML_OP_MEAN:
783+ case GGML_OP_ARGMAX:
784+ case GGML_OP_NORM:
785+ case GGML_OP_RMS_NORM:
786+ case GGML_OP_GROUP_NORM:
787+ case GGML_OP_L2_NORM:
788+ case GGML_OP_CONT:
789+ case GGML_OP_RESHAPE:
790+ case GGML_OP_VIEW:
791+ case GGML_OP_PERMUTE:
792+ case GGML_OP_TRANSPOSE:
793+ case GGML_OP_DIAG:
794+ case GGML_OP_DIAG_MASK_INF:
795+ case GGML_OP_DIAG_MASK_ZERO:
796+ case GGML_OP_SOFT_MAX:
797+ case GGML_OP_CLAMP:
798+ case GGML_OP_POOL_1D:
799+ case GGML_OP_POOL_2D:
800+ case GGML_OP_UPSCALE:
801+ case GGML_OP_PAD:
802+ case GGML_OP_PAD_REFLECT_1D:
803+ case GGML_OP_TIMESTEP_EMBEDDING:
804+ case GGML_OP_ARGSORT:
805+ case GGML_OP_LEAKY_RELU:
806+ case GGML_OP_WIN_PART:
807+ case GGML_OP_WIN_UNPART:
808+ case GGML_OP_GET_REL_POS:
809+ case GGML_OP_MAP_CUSTOM1:
810+ case GGML_OP_UNARY:
811+ if (!check_src (0 )) { return false ; }
812+ break ;
813+
814+ // Ops requiring src[0], src[1]
815+ case GGML_OP_ADD:
816+ case GGML_OP_ADD1:
817+ case GGML_OP_SUB:
818+ case GGML_OP_MUL:
819+ case GGML_OP_DIV:
820+ case GGML_OP_ACC:
821+ case GGML_OP_SCALE:
822+ case GGML_OP_RMS_NORM_BACK:
823+ case GGML_OP_CROSS_ENTROPY_LOSS:
824+ case GGML_OP_COUNT_EQUAL:
825+ case GGML_OP_REPEAT:
826+ case GGML_OP_REPEAT_BACK:
827+ case GGML_OP_CONCAT:
828+ case GGML_OP_SILU_BACK:
829+ case GGML_OP_MUL_MAT:
830+ case GGML_OP_OUT_PROD:
831+ case GGML_OP_SET:
832+ case GGML_OP_CPY:
833+ case GGML_OP_GET_ROWS:
834+ case GGML_OP_SOFT_MAX_BACK:
835+ case GGML_OP_ROPE:
836+ case GGML_OP_ROPE_BACK:
837+ case GGML_OP_CONV_TRANSPOSE_1D:
838+ case GGML_OP_IM2COL:
839+ case GGML_OP_IM2COL_BACK:
840+ case GGML_OP_CONV_TRANSPOSE_2D:
841+ case GGML_OP_POOL_2D_BACK:
842+ case GGML_OP_MAP_CUSTOM2:
843+ if (!check_src (1 )) { return false ; }
844+ break ;
845+
846+ // Ops requiring src[0], src[1], src[2]
847+ case GGML_OP_MUL_MAT_ID:
848+ case GGML_OP_GET_ROWS_BACK:
849+ case GGML_OP_ADD_REL_POS:
850+ case GGML_OP_MAP_CUSTOM3:
851+ case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
852+ case GGML_OP_SSM_CONV:
853+ if (!check_src (2 )) { return false ; }
854+ break ;
855+
856+ // Ops requiring src[0], src[1], src[2], src[3]
857+ case GGML_OP_FLASH_ATTN_EXT:
858+ if (!check_src (2 )) { return false ; }
859+ break ;
860+
861+ // Ops requiring src[0], src[1], src[2], src[3], src[4]
862+ case GGML_OP_FLASH_ATTN_BACK:
863+ case GGML_OP_OPT_STEP_ADAMW:
864+ if (!check_src (4 )) { return false ; }
865+ break ;
866+
867+ // Ops requiring src[0], src[1], src[2], src[3], src[4], src[5]
868+ case GGML_OP_SSM_SCAN:
869+ case GGML_OP_RWKV_WKV6:
870+ case GGML_OP_GATED_LINEAR_ATTN:
871+ if (!check_src (5 )) { return false ; }
872+ break ;
873+
874+ // Ops requiring src[0], src[1], src[2], src[3], src[4], src[5], src[6]
875+ case GGML_OP_RWKV_WKV7:
876+ if (!check_src (6 )) { return false ; }
877+ break ;
878+
879+ // Ops with no required src[] inputs or handled by default
880+ case GGML_OP_NONE:
881+ case GGML_OP_ARANGE:
882+ case GGML_OP_CUSTOM:
883+ case GGML_OP_COUNT:
884+ default :
885+ // Assume valid or cannot validate generically
886+ break ;
887+ }
888+ }
889+ GGML_PRINT_DEBUG (" [%s] Graph validation successful\n " , __func__);
890+ return true ;
891+ }
892+
749893static enum ggml_status ggml_backend_rpc_graph_compute (ggml_backend_t backend, ggml_cgraph * cgraph) {
750894 ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context ;
751895 std::vector<uint8_t > input;
@@ -1354,6 +1498,11 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
13541498 return false ;
13551499 }
13561500 }
1501+
1502+ if (!validate_graph_operands (graph)) {
1503+ return false ;
1504+ }
1505+
13571506 ggml_status status = ggml_backend_graph_compute (backend, graph);
13581507 response.result = status;
13591508 return true ;
0 commit comments