Skip to content

Commit 7004727

Browse files
mryvaelanhin
authored and committed
use row_segment for supporting 16 tasklets
1 parent 3fa5935 commit 7004727

File tree

4 files changed

+80
-62
lines changed

4 files changed

+80
-62
lines changed

dpu/dpu_main.c

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,21 @@
1111
#include <alloc.h>
1212
#include <barrier.h>
1313
#include <seqread.h>
14+
#include <mutex_pool.h>
1415

1516
#define PIM_KERNEL_DPU 1
1617
#include "../ggml/include/ggml.h"
1718
#define GGML_COMMON_DECL_C
1819
#include "../ggml/src/ggml-common.h"
1920

2021
#define PRINT 0
22+
#define SEGMENT_PER_ROW 4
23+
24+
// Find the lowest index for the rank-th group
25+
#define BLOCK_LOW(rank, size, n) ((rank) * (n) / (size))
26+
27+
// Find the highest index for the rank-th group
28+
#define BLOCK_HIGH(rank, size, n) (BLOCK_LOW((rank) + 1, (size), (n)) - 1)
2129

2230
__mram_ptr float *ptable_f32_f16;
2331

@@ -35,6 +43,7 @@ inline static float lookup_fp16_to_fp32(uint16_t f) {
3543

3644
// Barrier
3745
BARRIER_INIT(my_barrier, NR_TASKLETS);
46+
MUTEX_POOL_INIT(g_psumf_mutex_pool, NR_TASKLETS);
3847

3948
/*
4049
DPU MRAM Memory:
@@ -91,8 +100,9 @@ int wram2mram(__mram_ptr void *pmram,void *pwram,uint32_t size)
91100
}
92101

93102

94-
// set psumf to global value for each thread access
95-
static float *psumf = NULL;
103+
// set g_psumf to global value for each thread access
104+
static float *g_psumf = NULL;
105+
static block_q8_0 *g_pinput_cache = NULL;
96106

97107
void init(unsigned int tasklet_id) {
98108
#if PRINT
@@ -140,9 +150,11 @@ int main() {
140150
#endif
141151

142152
// set start line, end line and line number in each thread
143-
uint16_t weight_rows_per_thread = cache_meta->rows_per_dpu / NR_TASKLETS;
144-
uint16_t weight_start_row = tasklet_id * weight_rows_per_thread;
145-
uint16_t weight_end_row = weight_start_row + weight_rows_per_thread;
153+
uint16_t segments_num = cache_meta->rows_per_dpu * SEGMENT_PER_ROW;
154+
uint16_t segment_start = BLOCK_LOW(tasklet_id, NR_TASKLETS, segments_num);
155+
uint16_t segment_end = BLOCK_HIGH(tasklet_id, NR_TASKLETS, segments_num);
156+
157+
assert(segment_start <= segment_end && "There are not enough segments to allocate to the tasklets");
146158

147159
// TODO: if leftover rows exist, the first thread in every DPU can take one more row
148160
uint16_t weight_rows_cur_thread;
@@ -184,83 +196,80 @@ int main() {
184196
return -1;
185197
}
186198
int nb = pinputcache->ne[0]/QK8_0;
199+
200+
assert(SEGMENT_PER_ROW <= nb && nb % SEGMENT_PER_ROW == 0
201+
&& "Too many segments are allocated to each row.");
202+
187203
int qk = QK8_0;
188204
input_row_size = nb*sizeof(block_q8_0);
189205
__mram_ptr void *pweight_base = (__mram_ptr void *)(weightmetadatabase + sizeof(struct pim_meta));
190206
__mram_ptr void *pinput_base = DPU_MRAM_HEAP_POINTER + cache_meta->input_offset + sizeof(pim_matrix_des);
191-
207+
192208
if (tasklet_id == 0) {
193-
psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
209+
g_psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
210+
g_pinput_cache = (block_q8_0 *) mem_alloc(sizeof(block_q8_0) * nb);
211+
memset(g_psumf, 0 ,sizeof(float)*input_cols*weight_rows_cur_thread);
194212
}
195-
barrier_wait(&my_barrier);
196213

197-
// psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
198-
memset(psumf, 0 ,sizeof(float)*input_cols*weight_rows_cur_thread);
199-
200214
#if PRINT
201215
printf("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n",input_cols,weight_rows_cur_thread,nb,input_row_size);
202216
#endif
203-
block_q4_0 *pweight_cache = (block_q4_0 *) mem_alloc(sizeof(block_q4_0)*nb);
204-
block_q8_0 *pinput_cache = (block_q8_0 *) mem_alloc(sizeof(block_q8_0)*nb);
217+
218+
uint16_t segment_nb_size = nb / SEGMENT_PER_ROW;
219+
block_q4_0 *pweight_cache = (block_q4_0 *) mem_alloc(sizeof(block_q4_0) * segment_nb_size);
205220

206221
// weight_rows_cur_thread = 16;
207222
for(int l = 0;l < input_cols;l++) {
208-
__mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
209-
mram2wram(pinput, pinput_cache, sizeof(block_q8_0)*nb);
210-
#if PRINT
211-
printf("input:\n");
212-
for (int i = 0; i < nb; i++) {
213-
printf("d=%u\n",pinput[i].d);
214-
for (int kkk=0;kkk<QK8_0;kkk++) {
215-
printf("%d ",pinput[i].qs[kkk]);
216-
}
217-
printf("\n");
223+
if (tasklet_id == 0) {
224+
__mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
225+
mram2wram(pinput, g_pinput_cache, sizeof(block_q8_0)*nb);
218226
}
219-
printf("pweight_base: %p\n", pweight_base);
220-
#endif
221-
// for(int k = 0;k < weight_rows_cur_thread;k++) {
222-
for (int k = weight_start_row; k < weight_end_row; ++k) {
223-
__mram_ptr block_q4_0 *pweight = pweight_base + pinputcache->layerid * cache_meta->layer_len + k * nb * sizeof(block_q4_0);
224-
mram2wram(pweight, pweight_cache, sizeof(block_q4_0)*nb);
225-
#if PRINT
226-
if (k % 64 == 0) {
227-
printf("pweight_cache[%d].d=%d\n pweight_cache[%d].qs=", k*128, pweight_cache[0].d, k*128);
228-
for (int kkk=0;kkk<QK4_0/2;kkk++) {
229-
int v0 = (pweight_cache[0].qs[kkk] & 0x0f) - 8;
230-
int v1 = (pweight_cache[0].qs[kkk] >> 4) - 8;
231-
printf(" %d, %d", v0, v1);
232-
}
233-
printf("\n");
234-
}
235-
#endif
236227

237-
for (int i = 0; i < nb; i++) {
238-
//printf("input_col:%d, current inner weight row idx:%d\n",l,k);
228+
barrier_wait(&my_barrier);
229+
230+
__mram_ptr block_q4_0 *pweight_addr = pweight_base + pinputcache->layerid * cache_meta->layer_len;
239231

232+
for (int k = segment_start; k <= segment_end; ++k) {
233+
__mram_ptr block_q4_0 *pweight = pweight_addr + k * segment_nb_size;
234+
mram2wram(pweight, pweight_cache, sizeof(block_q4_0) * segment_nb_size);
235+
236+
block_q8_0 *pinput_cache = g_pinput_cache + k % SEGMENT_PER_ROW * segment_nb_size;
237+
238+
for (int i = 0; i < segment_nb_size; i++) {
240239
int sumi = 0;
241240
for (int j = 0; j < qk/2; ++j) {
242241
const int v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
243242
const int v1 = (pweight_cache[i].qs[j] >> 4) - 8;
244243

245244
sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
246245
}
247-
248-
psumf[l*weight_rows_cur_thread + k] += sumi*FP16_TO_FP32(pweight_cache[i].d)*FP16_TO_FP32(pinput_cache[i].d);
246+
247+
int psumf_idx = l * weight_rows_cur_thread + k / SEGMENT_PER_ROW;
248+
float sum = sumi * FP16_TO_FP32(pweight_cache[i].d) * FP16_TO_FP32(pinput_cache[i].d);
249+
mutex_pool_lock(&g_psumf_mutex_pool, psumf_idx);
250+
g_psumf[psumf_idx] += sum;
251+
// g_psumf[psumf_idx] += sumi;
252+
mutex_pool_unlock(&g_psumf_mutex_pool, psumf_idx);
249253
}
250254
}
251255
}
252256
}
253257

254-
offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
255-
#if PRINT
256-
for(int iii=0;iii<cache_meta->rows_per_dpu;iii+=128) {
257-
printf("psumf[%d]=%f\n",iii,psumf[iii]);
258+
barrier_wait(&my_barrier);
259+
260+
if (tasklet_id == 0){
261+
offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
262+
#if PRINT
263+
for(int iii=0;iii<cache_meta->rows_per_dpu;iii+=128) {
264+
printf("g_psumf[%d]=%f\n",iii,g_psumf[iii]);
265+
}
266+
267+
printf("output offset=%d\n",offset);
268+
#endif
269+
// Write C Matrix to current MRAM block
270+
// Note: with input_cols > 1, the results should be rearranged on host
271+
wram2mram((__mram_ptr void *) (DPU_MRAM_HEAP_POINTER + offset), g_psumf, sizeof(float)*input_cols*weight_rows_cur_thread);
258272
}
259273

260-
printf("output offset=%d\n",offset);
261-
#endif
262-
// Write C Matrix to current MRAM block
263-
// Note: with input_cols > 1, the results should be rearranged on host
264-
wram2mram((__mram_ptr void *) (DPU_MRAM_HEAP_POINTER + offset), psumf, sizeof(float)*input_cols*weight_rows_cur_thread);
265274
return 0;
266275
}

dpu/pim_build.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
#!/bin/bash
2-
dpu-upmem-dpurte-clang -Wall -Wextra -O2 -DNR_TASKLETS=8 -DBL=11 -o gemv_dpu dpu_main.c
2+
dpu-upmem-dpurte-clang -Wall -Wextra -O3 -DNR_TASKLETS=16 -DBL=11 -o gemv_dpu dpu_main.c

examples/tensor/ts.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
#include <iomanip>
44
#include <chrono>
55

6-
#include <vector>
7-
8-
#define NR_DPUS 2048
6+
#define NR_DPUS 64
97
#define NR_LAYER 2
108
#define DPU_BINARY "./dpu/gemv_dpu"
119

@@ -105,8 +103,8 @@ int gemv_dpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct
105103

106104
dur = ex_tp2 - ex_tp1;
107105

108-
std::cout << "dpu: 执行用时:" << std::chrono::duration_cast<std::chrono::milliseconds>(dur).count() << " ms" << std::endl;
109-
std::cout << "dpu: 执行用时:" << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;
106+
// std::cout << "执行用时:" << std::chrono::duration_cast<std::chrono::milliseconds>(dur).count() << " ms" << std::endl;
107+
std::cout << "执行用时:" << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;
110108

111109
// Check results
112110
float *mul_mat_res = (float *)res->data;

ggml/src/ggml.c

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17413,12 +17413,23 @@ static int dpu_launch_gemv_async(
1741317413
uint32_t input_offset = res->inout_offset;
1741417414
dpu_set = *(res->dpu_set);
1741517415
// broadcast input metadata
17416+
17417+
#if PIM_DEBUG_PERF_PRINT
17418+
uint64_t t_start = get_time_us();
17419+
#endif
17420+
1741617421
DPU_ASSERT(dpu_broadcast_to(dpu_set, DPU_MRAM_HEAP_POINTER_NAME, input_offset, &input_descript, sizeof(pim_matrix_des), DPU_XFER_DEFAULT));
1741717422
input_offset += sizeof(pim_matrix_des);
1741817423

1741917424
// broadcast input data
1742017425
uint32_t bclen = ggml_row_size(vec_dot_type, input->ne[0])*input->ne[1]*input->ne[2]*input->ne[3];
1742117426
DPU_ASSERT(dpu_broadcast_to(dpu_set, DPU_MRAM_HEAP_POINTER_NAME, input_offset, wdata, bclen, DPU_XFER_DEFAULT));
17427+
17428+
#if PIM_DEBUG_PERF_PRINT
17429+
uint64_t t_us = get_time_us() - t_start;
17430+
printf("\n%s: PIM broadcast time = %ld \n", __FUNCTION__, t_us);
17431+
#endif
17432+
1742217433
input_offset += bclen;
1742317434

1742417435
res->inout_offset = input_offset;
@@ -17433,9 +17444,9 @@ static __inline__ void dpu_kernel_barrier(struct dpu_set_t dpu_set) {
1743317444
struct dpu_set_t dpu;
1743417445
dpu_sync(dpu_set);
1743517446
// print the DPU log
17436-
DPU_FOREACH(dpu_set, dpu) {
17437-
DPU_ASSERT(dpulog_read_for_dpu(dpu.dpu, stdout));
17438-
}
17447+
// DPU_FOREACH(dpu_set, dpu) {
17448+
// DPU_ASSERT(dpulog_read_for_dpu(dpu.dpu, stdout));
17449+
// }
1743917450
return;
1744017451
}
1744117452

0 commit comments

Comments
 (0)