
Commit 70f9d88

Merge pull request #3 from lanhin/tensor_export_turnoff
Tensor export turnoff
2 parents 3ec476b + 46aaeba commit 70f9d88

6 files changed: +177 additions, -74 deletions

dpu/dpu_main.c

Lines changed: 69 additions & 56 deletions
@@ -11,16 +11,26 @@
 #include <alloc.h>
 #include <barrier.h>
 #include <seqread.h>
+#include <mutex_pool.h>
 
 #define PIM_KERNEL_DPU 1
 #include "../ggml/include/ggml.h"
 #define GGML_COMMON_DECL_C
 #include "../ggml/src/ggml-common.h"
 
 #define PRINT 0
+#define SEGMENT_PER_ROW 4
+
+// Find the lowest index for the rank-th group
+#define BLOCK_LOW(rank, size, n) ((rank) * (n) / (size))
+
+// Find the highest index for the rank-th group
+#define BLOCK_HIGH(rank, size, n) (BLOCK_LOW((rank) + 1, (size), (n)) - 1)
 
 __mram_ptr float *ptable_f32_f16;
 
+__host int16_t mul_table_int4_int8[1<<4][1<<8];
+
 inline static float lookup_fp16_to_fp32(uint16_t f) {
     uint16_t s;
     memcpy(&s, &f, sizeof(uint16_t));
@@ -35,6 +45,7 @@ inline static float lookup_fp16_to_fp32(uint16_t f) {
 
 // Barrier
 BARRIER_INIT(my_barrier, NR_TASKLETS);
+MUTEX_POOL_INIT(g_psumf_mutex_pool, NR_TASKLETS);
 
 /*
 DPU MRAM Memory:
@@ -91,8 +102,9 @@ int wram2mram(__mram_ptr void *pmram,void *pwram,uint32_t size)
 }
 
 
-// set psumf to global value for each thread access
-static float *psumf = NULL;
+// set g_psumf to global value for each thread access
+static float *g_psumf = NULL;
+static block_q8_0 *g_pinput_cache = NULL;
 
 void init(unsigned int tasklet_id) {
 #if PRINT
@@ -140,9 +152,11 @@ int main() {
 #endif
 
     // set start line, end line and line number in each thread
-    uint16_t weight_rows_per_thread = cache_meta->rows_per_dpu / NR_TASKLETS;
-    uint16_t weight_start_row = tasklet_id * weight_rows_per_thread;
-    uint16_t weight_end_row = weight_start_row + weight_rows_per_thread;
+    uint16_t segments_num = cache_meta->rows_per_dpu * SEGMENT_PER_ROW;
+    uint16_t segment_start = BLOCK_LOW(tasklet_id, NR_TASKLETS, segments_num);
+    uint16_t segment_end = BLOCK_HIGH(tasklet_id, NR_TASKLETS, segments_num);
+
+    assert(segment_start <= segment_end && "There are not enough segments to allocate to the tasklets");
 
     // todo: if a remainder row exists, the first thread in every dpu can take one more row
     uint16_t weight_rows_cur_thread;
@@ -184,83 +198,82 @@ int main() {
         return -1;
     }
     int nb = pinputcache->ne[0]/QK8_0;
+
+    assert(SEGMENT_PER_ROW <= nb && nb % SEGMENT_PER_ROW == 0
+           && "Too many segments are allocated to each row.");
+
     int qk = QK8_0;
     input_row_size = nb*sizeof(block_q8_0);
     __mram_ptr void *pweight_base = (__mram_ptr void *)(weightmetadatabase + sizeof(struct pim_meta));
     __mram_ptr void *pinput_base = DPU_MRAM_HEAP_POINTER + cache_meta->input_offset + sizeof(pim_matrix_des);
 
     if (tasklet_id == 0) {
-        psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
+        g_psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
+        g_pinput_cache = (block_q8_0 *) mem_alloc(sizeof(block_q8_0) * nb);
+        memset(g_psumf, 0 ,sizeof(float)*input_cols*weight_rows_cur_thread);
     }
-    barrier_wait(&my_barrier);
 
-    // psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
-    memset(psumf, 0 ,sizeof(float)*input_cols*weight_rows_cur_thread);
-
 #if PRINT
     printf("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n",input_cols,weight_rows_cur_thread,nb,input_row_size);
 #endif
-    block_q4_0 *pweight_cache = (block_q4_0 *) mem_alloc(sizeof(block_q4_0)*nb);
-    block_q8_0 *pinput_cache = (block_q8_0 *) mem_alloc(sizeof(block_q8_0)*nb);
+
+    uint16_t segment_nb_size = nb / SEGMENT_PER_ROW;
+    block_q4_0 *pweight_cache = (block_q4_0 *) mem_alloc(sizeof(block_q4_0) * segment_nb_size);
 
     // weight_rows_cur_thread = 16;
     for(int l = 0;l < input_cols;l++) {
-        __mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
-        mram2wram(pinput, pinput_cache, sizeof(block_q8_0)*nb);
-#if PRINT
-        printf("input:\n");
-        for (int i = 0; i < nb; i++) {
-            printf("d=%u\n",pinput[i].d);
-            for (int kkk=0;kkk<QK8_0;kkk++) {
-                printf("%d ",pinput[i].qs[kkk]);
-            }
-            printf("\n");
+        if (tasklet_id == 0) {
+            __mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
+            mram2wram(pinput, g_pinput_cache, sizeof(block_q8_0)*nb);
         }
-        printf("pweight_base: %p\n", pweight_base);
-#endif
-        // for(int k = 0;k < weight_rows_cur_thread;k++) {
-        for (int k = weight_start_row; k < weight_end_row; ++k) {
-            __mram_ptr block_q4_0 *pweight = pweight_base + pinputcache->layerid * cache_meta->layer_len + k * nb * sizeof(block_q4_0);
-            mram2wram(pweight, pweight_cache, sizeof(block_q4_0)*nb);
-#if PRINT
-            if (k % 64 == 0) {
-                printf("pweight_cache[%d].d=%d\n pweight_cache[%d].qs=", k*128, pweight_cache[0].d, k*128);
-                for (int kkk=0;kkk<QK4_0/2;kkk++) {
-                    int v0 = (pweight_cache[0].qs[kkk] & 0x0f) - 8;
-                    int v1 = (pweight_cache[0].qs[kkk] >> 4) - 8;
-                    printf(" %d, %d", v0, v1);
-                }
-                printf("\n");
-            }
-#endif
 
-            for (int i = 0; i < nb; i++) {
-                //printf("input_col:%d, current inner weight row idx:%d\n",l,k);
+        barrier_wait(&my_barrier);
+
+        __mram_ptr block_q4_0 *pweight_addr = pweight_base + pinputcache->layerid * cache_meta->layer_len;
 
+        for (int k = segment_start; k <= segment_end; ++k) {
+            __mram_ptr block_q4_0 *pweight = pweight_addr + k * segment_nb_size;
+            mram2wram(pweight, pweight_cache, sizeof(block_q4_0) * segment_nb_size);
+
+            block_q8_0 *pinput_cache = g_pinput_cache + k % SEGMENT_PER_ROW * segment_nb_size;
+
+            for (int i = 0; i < segment_nb_size; i++) {
                 int sumi = 0;
                 for (int j = 0; j < qk/2; ++j) {
-                    const int v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
-                    const int v1 = (pweight_cache[i].qs[j] >> 4) - 8;
+                    const int8_t v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
+                    const int8_t v1 = (pweight_cache[i].qs[j] >> 4) - 8;
 
-                    sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
+                    // sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
+                    sumi += mul_table_int4_int8[v0 + 8][pinput_cache[i].qs[j] - INT8_MIN] +
+                            mul_table_int4_int8[v1 + 8][pinput_cache[i].qs[j + qk/2] - INT8_MIN];
                 }
 
-                psumf[l*weight_rows_cur_thread + k] += sumi*FP16_TO_FP32(pweight_cache[i].d)*FP16_TO_FP32(pinput_cache[i].d);
+                int psumf_idx = l * weight_rows_cur_thread + k / SEGMENT_PER_ROW;
+                float sum = sumi * FP16_TO_FP32(pweight_cache[i].d) * FP16_TO_FP32(pinput_cache[i].d);
+                mutex_pool_lock(&g_psumf_mutex_pool, psumf_idx);
+                g_psumf[psumf_idx] += sum;
+                // g_psumf[psumf_idx] += sumi;
+                mutex_pool_unlock(&g_psumf_mutex_pool, psumf_idx);
             }
         }
    }

-    offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
-#if PRINT
-    for(int iii=0;iii<cache_meta->rows_per_dpu;iii+=128) {
-        printf("psumf[%d]=%f\n",iii,psumf[iii]);
+    barrier_wait(&my_barrier);
+
+    if (tasklet_id == 0){
+        offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
+#if PRINT
+        for(int iii=0;iii<cache_meta->rows_per_dpu;iii+=128) {
+            printf("g_psumf[%d]=%f\n",iii,g_psumf[iii]);
+        }
+
+        printf("output offset=%d\n",offset);
+#endif
+        // Write C Matrix to current MRAM block
+        // Note: with input_cols > 1, the results should be rearranged on host
+        wram2mram((__mram_ptr void *) (DPU_MRAM_HEAP_POINTER + offset), g_psumf, sizeof(float)*input_cols*weight_rows_cur_thread);
    }

-    printf("output offset=%d\n",offset);
-#endif
-    // Write C Matrix to current MRAM block
-    // Note: with input_cols > 1, the results should be rearranged on host
-    wram2mram((__mram_ptr void *) (DPU_MRAM_HEAP_POINTER + offset), psumf, sizeof(float)*input_cols*weight_rows_cur_thread);
    return 0;
 }
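The dpu_main.c changes replace the old per-tasklet row split with a finer segment split: each weight row is cut into SEGMENT_PER_ROW pieces, the BLOCK_LOW/BLOCK_HIGH macros hand each tasklet a contiguous range of segment indices, and because several tasklets can now contribute to the same output row (all segments sharing the same k / SEGMENT_PER_ROW), the accumulation into g_psumf is protected by the new mutex pool. The standalone host-side sketch below is illustrative only (the row and tasklet counts are made-up example values, not part of the commit); it shows how the two macros tile the segment index space across tasklets without gaps or overlap:

#include <assert.h>
#include <stdio.h>

/* Same partition macros as in dpu_main.c */
#define BLOCK_LOW(rank, size, n)  ((rank) * (n) / (size))
#define BLOCK_HIGH(rank, size, n) (BLOCK_LOW((rank) + 1, (size), (n)) - 1)

int main(void) {
    /* Example numbers only: 16 tasklets (matching -DNR_TASKLETS=16 in pim_build.sh),
       a hypothetical 32 rows per DPU, 4 segments per row as in the diff. */
    const int nr_tasklets = 16, rows_per_dpu = 32, segment_per_row = 4;
    const int segments_num = rows_per_dpu * segment_per_row;

    int next = 0;
    for (int t = 0; t < nr_tasklets; ++t) {
        int lo = BLOCK_LOW(t, nr_tasklets, segments_num);
        int hi = BLOCK_HIGH(t, nr_tasklets, segments_num);
        printf("tasklet %2d -> segments [%d, %d], rows [%d, %d]\n",
               t, lo, hi, lo / segment_per_row, hi / segment_per_row);
        assert(lo == next && hi >= lo);  /* contiguous, non-empty range */
        next = hi + 1;
    }
    assert(next == segments_num);        /* every segment assigned exactly once */
    return 0;
}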

dpu/pim_build.sh

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
 #!/bin/bash
-dpu-upmem-dpurte-clang -Wall -Wextra -O2 -DNR_TASKLETS=8 -DBL=11 -o gemv_dpu dpu_main.c
+dpu-upmem-dpurte-clang -Wall -Wextra -O3 -DNR_TASKLETS=16 -DBL=11 -o gemv_dpu dpu_main.c

examples/tensor/ts.cpp

Lines changed: 19 additions & 6 deletions
@@ -3,12 +3,12 @@
 #include <iomanip>
 #include <chrono>
 
-#include <vector>
-
-#define NR_DPUS 2048
+#define NR_DPUS 512
 #define NR_LAYER 2
 #define DPU_BINARY "./dpu/gemv_dpu"
 
+int16_t mul_table_int4_int8[1<<4][1<<8];
+
 void fp_table_init(void) {
     for (int i = 0; i < (1 << 16); ++i) {
         union {
@@ -19,12 +19,22 @@ void fp_table_init(void) {
     }
 }
 
+void mul_table_int4_int8_init(void) {
+    for(int i = 0; i < (1 << 4); ++i){
+        for(int j = 0; j < (1 << 8); ++j){
+            mul_table_int4_int8[i][j] = (i - 8) * (j + INT8_MIN);
+        }
+    }
+}
+
+#ifdef PIM_KERNEL
 int gemv_dpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct ggml_tensor * in_q, struct ggml_tensor * res) {
     uint32_t pim_offset = 0;
     struct dpu_set_t dpu;
 
     std::chrono::high_resolution_clock::time_point ex_tp1 = std::chrono::high_resolution_clock::now();
 
+    DPU_ASSERT(dpu_broadcast_to(context->dpu_set, "mul_table_int4_int8", 0, (void *)(mul_table_int4_int8), sizeof(mul_table_int4_int8), DPU_XFER_DEFAULT));
     //ggml_table_f32_f16 tbl is transferred to pim
     DPU_ASSERT(dpu_broadcast_to(context->dpu_set, DPU_MRAM_HEAP_POINTER_NAME, pim_offset, (void *)(ggml_table_f32_f16), sizeof(ggml_table_f32_f16), DPU_XFER_DEFAULT));
     pim_offset += sizeof(ggml_table_f32_f16);
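The commit moves the int4 x int8 products into a precomputed 16x256 table: mul_table_int4_int8_init() fills it on the host, dpu_broadcast_to() copies it into the DPU symbol of the same name, and the DPU inner loop then indexes it with the biased weight nibble (v + 8) and the shifted activation (q - INT8_MIN) instead of multiplying. A minimal host-side check of that equivalence (a sketch for illustration, not part of the commit):

#include <assert.h>
#include <stdint.h>

static int16_t mul_table_int4_int8[1 << 4][1 << 8];

static void mul_table_int4_int8_init(void) {
    for (int i = 0; i < (1 << 4); ++i)          /* i encodes the int4 value v as v + 8 */
        for (int j = 0; j < (1 << 8); ++j)      /* j encodes the int8 value q as q - INT8_MIN */
            mul_table_int4_int8[i][j] = (int16_t)((i - 8) * (j + INT8_MIN));
}

int main(void) {
    mul_table_int4_int8_init();
    for (int v = -8; v <= 7; ++v)                   /* all de-biased 4-bit weight values */
        for (int q = INT8_MIN; q <= INT8_MAX; ++q)  /* all int8 activation values */
            assert(mul_table_int4_int8[v + 8][q - INT8_MIN] == v * q);
    return 0;
}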
@@ -104,8 +114,8 @@ int gemv_dpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct
 
     dur = ex_tp2 - ex_tp1;
 
-    std::cout << "dpu: execution time: " << std::chrono::duration_cast<std::chrono::milliseconds>(dur).count() << " ms" << std::endl;
-    std::cout << "dpu: execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;
+    // std::cout << "execution time: " << std::chrono::duration_cast<std::chrono::milliseconds>(dur).count() << " ms" << std::endl;
+    std::cout << "execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;
 
     // Check results
     float *mul_mat_res = (float *)res->data;
@@ -116,6 +126,7 @@ int gemv_dpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct
 
     return 0;
 }
+#endif
 
 
 void gemv_cpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct ggml_tensor * in_q, struct ggml_tensor * res_comp) {
@@ -163,7 +174,9 @@ void gemv_cpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct
 int main(int argc, char** argv) {
     // init fp table for fp16 dump
     fp_table_init();
+    mul_table_int4_int8_init();
 
+#ifdef PIM_KERNEL
     // WQ-PIM allocate dpu
     struct pim_context *pqcontext = (struct pim_context *)malloc(sizeof(struct pim_context));
     memset(pqcontext,0,sizeof(struct pim_context));
@@ -213,6 +226,6 @@ int main(int argc, char** argv) {
     // float first_res = mul_add_q4_0_q8_0(ts_a, ts_bq);
     // std::cout<<"first element: "<<std::fixed << std::setprecision(6)<<first_res<<std::endl;
 
-
+#endif
     return 0;
 }

0 commit comments
