@@ -524,6 +524,7 @@ struct ggml_numa_nodes {
524524#ifdef GGML_USE_NUMA_MIGRATE
525525 int * node_num_of_cpu ;
526526 int * cpu_core_mapping ; // x logic core, y physical core
527+ int * thread_start_id ;
527528 int logic_core_cnts ;
528529 int cores_per_numa [GGML_NUMA_MIGRATE_NODES ];
529530#endif
@@ -646,18 +647,22 @@ int ggml_get_node_from_cpu(int ith) {
646647}
647648
648649int ggml_get_start_id_in_node (int ith ) {
650+ return g_state .numa .thread_start_id [ith ];
651+ }
652+
653+ static void ggml_set_start_id_in_node (int ith ) {
649654 int total_cpus = 0 ;
650655 int prev_total_cpus = 0 ;
651656 for (int node = 0 ; node < GGML_NUMA_MIGRATE_NODES ; node ++ ) {
652657 prev_total_cpus = total_cpus ;
653658 total_cpus += g_state .numa .cores_per_numa [node ];
654659 if (ith < total_cpus ) {
655- return (ith - prev_total_cpus );
660+ g_state .numa .thread_start_id [ith ] = (ith - prev_total_cpus );
661+ return ;
656662 }
657663 }
658664
659665 assert (0 );
660- return -1 ;
661666}
662667
663668int ggml_cores_per_numa (int ith ) {
@@ -676,7 +681,8 @@ void ggml_barrier_numa_aware(struct ggml_threadpool * tp, int ith, int node_n) {
676681 return ;
677682 }
678683 if (n_threads != g_state .numa .logic_core_cnts ) {
679- printf ("bolt-test: n_threads: %d, g_state.numa.logic_core_cnts: %d\n" , n_threads , g_state .numa .logic_core_cnts );
684+ printf ("WARNING: n_threads: %d not equal to core counts: %d, please check thread numbers and GGML_NUMA_CORE_IDS\n" ,
685+ n_threads , g_state .numa .logic_core_cnts );
680686 ggml_barrier (tp );
681687 return ;
682688 }
@@ -808,6 +814,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
808814#ifdef GGML_USE_NUMA_MIGRATE
809815 g_state .numa .node_num_of_cpu = (int * )malloc (g_state .numa .total_cpus * sizeof (int ));
810816 g_state .numa .cpu_core_mapping = (int * )malloc (g_state .numa .total_cpus * sizeof (int ));
817+ g_state .numa .thread_start_id = (int * )malloc (g_state .numa .total_cpus * sizeof (int ));
811818 int logic_core_index = 0 ;
812819
813820 const char * env_var = getenv ("GGML_NUMA_CORE_IDS" );
@@ -862,6 +869,11 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) {
862869
863870 fclose (fp );
864871 }
872+
873+ for (int i = 0 ; i < g_state .numa .logic_core_cnts ; i ++ ) {
874+ ggml_set_start_id_in_node (i );
875+ }
876+
865877#endif
866878
867879 if (ggml_is_numa ()) {
0 commit comments