Skip to content

Commit 6dd71af

Browse files
authored
Merge pull request #2995 from Flamefire/fix_thread_buffer_init
Don't overwrite blas_thread_buffer if already set
2 parents 7e9cb39 + 60005eb commit 6dd71af

File tree

1 file changed

+22
-26
lines changed

1 file changed

+22
-26
lines changed

driver/others/blas_server_omp.c

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,28 @@ static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
7676
static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
7777
#endif
7878

79-
void goto_set_num_threads(int num_threads) {
79+
// Bring blas_thread_buffer in line with the current value of blas_cpu_number.
// For every row i in [0, MAX_PARALLEL_NUMBER):
//   - slots j in [0, blas_cpu_number) that are still NULL get a buffer from
//     blas_memory_alloc(); slots that already hold a buffer are left untouched.
//   - slots j in [blas_cpu_number, MAX_CPU_NUMBER) that hold a buffer are
//     released with blas_memory_free() and reset to NULL.
// Takes no arguments and returns nothing; it reads blas_cpu_number and
// mutates the file-scope blas_thread_buffer table.
// NOTE(review): the result of blas_memory_alloc() is stored without a check
// here — confirm how that allocator reports failure.
static void adjust_thread_buffers() {
8080

8181
int i=0, j=0;
8282

83+
// Adjust the buffers of each parallel slot: allocate the missing ones, free the surplus ones.
84+
for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
85+
for(j=0; j < blas_cpu_number; j++){
86+
if(blas_thread_buffer[i][j] == NULL){ // only fill empty slots; an existing buffer is never overwritten
87+
blas_thread_buffer[i][j] = blas_memory_alloc(2); // NOTE(review): meaning of the argument 2 is defined by blas_memory_alloc (not visible here)
88+
}
89+
}
90+
for(; j < MAX_CPU_NUMBER; j++){ // j deliberately continues from blas_cpu_number, where the loop above stopped
91+
if(blas_thread_buffer[i][j] != NULL){
92+
blas_memory_free(blas_thread_buffer[i][j]);
93+
blas_thread_buffer[i][j] = NULL; // mark the slot empty so a later call can allocate it again
94+
}
95+
}
96+
}
97+
}
98+
99+
void goto_set_num_threads(int num_threads) {
100+
83101
if (num_threads < 1) num_threads = blas_num_threads;
84102

85103
if (num_threads > MAX_CPU_NUMBER) num_threads = MAX_CPU_NUMBER;
@@ -92,20 +110,7 @@ void goto_set_num_threads(int num_threads) {
92110

93111
omp_set_num_threads(blas_cpu_number);
94112

95-
//adjust buffer for each thread
96-
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
97-
for(j=0; j<blas_cpu_number; j++){
98-
if(blas_thread_buffer[i][j]==NULL){
99-
blas_thread_buffer[i][j]=blas_memory_alloc(2);
100-
}
101-
}
102-
for(; j<MAX_CPU_NUMBER; j++){
103-
if(blas_thread_buffer[i][j]!=NULL){
104-
blas_memory_free(blas_thread_buffer[i][j]);
105-
blas_thread_buffer[i][j]=NULL;
106-
}
107-
}
108-
}
113+
adjust_thread_buffers();
109114
#if defined(ARCH_MIPS64)
110115
//set parameters for different number of threads.
111116
blas_set_parameter();
@@ -119,20 +124,11 @@ void openblas_set_num_threads(int num_threads) {
119124

120125
int blas_thread_init(void){
121126

122-
int i=0, j=0;
123-
124127
blas_get_cpu_number();
125128

126-
blas_server_avail = 1;
129+
adjust_thread_buffers();
127130

128-
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
129-
for(j=0; j<blas_num_threads; j++){
130-
blas_thread_buffer[i][j]=blas_memory_alloc(2);
131-
}
132-
for(; j<MAX_CPU_NUMBER; j++){
133-
blas_thread_buffer[i][j]=NULL;
134-
}
135-
}
131+
blas_server_avail = 1;
136132

137133
return 0;
138134
}

0 commit comments

Comments
 (0)