#include "trace_driver.h"
#include <iostream>
#include <iomanip>
#include <chrono>
#include <cassert>   // assert
#include <cstdint>   // int16_t, INT8_MIN
#include <pthread.h>

extern "C" {
#include "../../PIM-tensorStore/host/pim_llm.h"
}


#define NR_DPUS 512
#define NR_LAYER 2
#define NR_THREADS 3
#define DPU_BINARY "./PIM-tensorStore/build/dpu_task"
#define PIM_KERNEL

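// Product table for q4_0 x q8_0 dot products: entry [i][j] = (i - 8) * (j + INT8_MIN).
// It is broadcast to the DPUs, presumably so the kernel can look up products instead of multiplying.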
int16_t mul_table_int4_int8[1 << 4][1 << 8];

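// Fill ggml_table_f32_f16 with the fp32 value of every possible fp16 bit pattern (64K entries).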
void fp_table_init(void) {
    for (int i = 0; i < (1 << 16); ++i) {
        union {
            uint16_t u16;
            ggml_fp16_t fp16;
        } u = {(uint16_t) i};  // cast avoids a narrowing conversion in the braced init
        ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
    }
}

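// Populate the product table: row index i is the unsigned 4-bit nibble (offset by -8),
// column index j is the unsigned byte (offset by INT8_MIN), matching q4_0 / q8_0 storage.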
void mul_table_int4_int8_init(void) {
    for (int i = 0; i < (1 << 4); ++i) {
        for (int j = 0; j < (1 << 8); ++j) {
            mul_table_int4_int8[i][j] = (i - 8) * (j + INT8_MIN);
        }
    }
}

#ifdef PIM_KERNEL

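// Per-thread argument bundle for gemv_dpu_kernel: the DPU set, the weight tensor,
// the PIM-side pointers set up by gemv_load_weight, the quantized input, and the output tensor.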
struct param
{
    struct dpu_set_t dpu_set;
    struct ggml_tensor *w;
    remote_ptr table_f32_f16_pim_ptr;
    remote_ptr w_pim_ptr;
    struct ggml_tensor *in_q;
    struct ggml_tensor *res;
};

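// One-time setup for a DPU set: broadcast the two lookup tables to all DPUs,
// then scatter the weight matrix row-block by row-block (ne[1] / NR_DPUS rows per DPU).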
int gemv_load_weight(struct dpu_set_t dpu_set, struct ggml_tensor *w, remote_ptr *table_f32_f16_pim_ptr, remote_ptr *w_pim_ptr) {
    DPU_ASSERT(dpu_broadcast_to(dpu_set, "mul_table_int4_int8", 0, (void *)(mul_table_int4_int8), sizeof(mul_table_int4_int8), DPU_XFER_DEFAULT));

    // The ggml_table_f32_f16 lookup table is broadcast to every DPU.
    all_dpu_mm_reset();
    *table_f32_f16_pim_ptr = all_dpu_alloc(sizeof(ggml_table_f32_f16));
    assert((*table_f32_f16_pim_ptr).dpu_id == ALL_DPU && (*table_f32_f16_pim_ptr).dpu_addr == FREE_STORAGE_OFFSET);
    dpu_broadcast_direct(dpu_set, *table_f32_f16_pim_ptr, (void *)(ggml_table_f32_f16), sizeof(ggml_table_f32_f16));
    // DPU_ASSERT(dpu_broadcast_to(dpu_set, "table_f32_f16", 0, (void *)(ggml_table_f32_f16), sizeof(ggml_table_f32_f16), DPU_XFER_DEFAULT));
    std::cout << "ggml_table_f32_f16 len = " << sizeof(ggml_table_f32_f16) << std::endl;

    // Each DPU owns an equal share of the weight rows, so the row count must divide evenly.
    assert(w->ne[1] % NR_DPUS == 0);

    *w_pim_ptr = all_dpu_alloc(w->nb[1] * (w->ne[1] / NR_DPUS));
    assert((*w_pim_ptr).dpu_id == ALL_DPU && (*w_pim_ptr).dpu_addr == FREE_STORAGE_OFFSET + sizeof(ggml_table_f32_f16));

    // Scatter: DPU i receives rows [i * ne[1] / NR_DPUS, (i + 1) * ne[1] / NR_DPUS).
    void *src_w_ptrs[NR_DPUS];
    for (int i = 0; i < NR_DPUS; i++) {
        src_w_ptrs[i] = (void *)((unsigned char *)w->data + i * w->nb[1] * (w->ne[1] / NR_DPUS));
    }

    dpu_send_direct(dpu_set, *w_pim_ptr, src_w_ptrs, w->nb[1] * (w->ne[1] / NR_DPUS));
    return 0;
}

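// Worker routine (also usable as a pthread entry point): builds a GEMV q4_0 x q8_0 message,
// sends the quantized input, launches the DPU set, and reads the partial results back into res.
// Each phase is timed and printed in microseconds.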
void *gemv_dpu_kernel(void *arg) {
    std::chrono::high_resolution_clock::time_point ex_tp1;
    std::chrono::high_resolution_clock::time_point ex_tp2;
    std::chrono::duration<size_t, std::nano> dur;
    struct param *pa = (struct param *)arg;
    struct dpu_set_t dpu_set = pa->dpu_set;
    struct ggml_tensor *w = pa->w;
    remote_ptr table_f32_f16_pim_ptr = pa->table_f32_f16_pim_ptr;
    remote_ptr w_pim_ptr = pa->w_pim_ptr;
    struct ggml_tensor *in_q = pa->in_q;
    struct ggml_tensor *res = pa->res;

    ex_tp1 = std::chrono::high_resolution_clock::now();

    // Build the GEMV message (weight block, dimensions, quantized input, fp16 table) and send it.
    msg_block_des msg_gemv;
    printf("table_f32_f16_pim_ptr.dpu_addr = %d\n", table_f32_f16_pim_ptr.dpu_addr);
    msg_block_builder_op_gemv_q4_q8(&msg_gemv, w_pim_ptr, w->ne[0], w->ne[1] / NR_DPUS, in_q->ne[0], in_q->data, in_q->nb[1], table_f32_f16_pim_ptr);

    msg_buffer buffer;
    msg_buffer_init(&buffer);
    msg_buffer_clear(&buffer);
    msg_buffer_append(&buffer, &msg_gemv);
    msg_buffer_finish(&buffer);
    // msg_buffer_dump_int32(&buffer);
    msg_buffer_send(&buffer, dpu_set);

    ex_tp2 = std::chrono::high_resolution_clock::now();

    dur = ex_tp2 - ex_tp1;

    std::cout << "dpu: in_q transfer time: " << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;

    ex_tp1 = std::chrono::high_resolution_clock::now();
    dpu_set_launch(dpu_set);
    ex_tp2 = std::chrono::high_resolution_clock::now();

    dur = ex_tp2 - ex_tp1;

    std::cout << "kernel execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;

    // dpu_set_log_read(dpu_set);
    // Gather results: each DPU produced ne[1] / NR_DPUS floats of the output vector.
    float *mul_mat_res = (float *)res->data;

    void *dst_ptrs[NR_DPUS];
    for (int i = 0; i < NR_DPUS; i++) {
        dst_ptrs[i] = (void *)(mul_mat_res + i * w->ne[1] / NR_DPUS);
    }

    ex_tp1 = std::chrono::high_resolution_clock::now();
    msg_buffer_recv(dpu_set, dst_ptrs, w->ne[1] / NR_DPUS * sizeof(float));
    ex_tp2 = std::chrono::high_resolution_clock::now();

    dur = ex_tp2 - ex_tp1;

    std::cout << "result read-back time: " << std::chrono::duration_cast<std::chrono::microseconds>(dur).count() << " us" << std::endl;
    return NULL;
}
#endif

int main(int argc, char **argv) {
    // Initialise the fp16 -> fp32 lookup table and the int4 x int8 product table.
    fp_table_init();
    mul_table_int4_int8_init();

#ifdef PIM_KERNEL
    // WQ-PIM: allocate one DPU set per worker thread and load the DPU binary onto it.
    param pas[NR_THREADS];
    for (int i = 0; i < NR_THREADS; i++) {
        struct dpu_set_t &dpu_set = pas[i].dpu_set;
        DPU_ASSERT(dpu_alloc(NR_DPUS, NULL, &dpu_set));
        DPU_ASSERT(dpu_load(dpu_set, DPU_BINARY, NULL));
    }

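    // Import the previously dumped tensor files: weight a, input b (float and quantized),
    // the reference result c, and c_pim, which is reused below as the PIM output buffer.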
    const char *filenamea = "tensor-files/a.tensor";
    const char *filenameb = "tensor-files/b.tensor";
    const char *filenamebq = "tensor-files/b_quant.tensor";
    const char *filenamec = "tensor-files/c.tensor";
    const char *filenamec_p = "tensor-files/c_pim.tensor";
    struct ggml_tensor *ts_a = tensor_import(filenamea);
    struct ggml_tensor *ts_b = tensor_import(filenameb);
    struct ggml_tensor *ts_bq = tensor_import(filenamebq);
    struct ggml_tensor *ts_c = tensor_import(filenamec);
    struct ggml_tensor *ts_c_pim = tensor_import(filenamec_p);

    std::cout << "ts_a: " << std::endl;
    print_tensor(ts_a, stdout);
    std::cout << "ts_b: " << std::endl;
    print_tensor(ts_b, stdout);

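    // All worker threads share the same weight, quantized input, and output tensors;
    // only the DPU set differs per thread.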
    for (int i = 0; i < NR_THREADS; i++) {
        pas[i].w = ts_a;
        pas[i].in_q = ts_bq;
        pas[i].res = ts_c_pim;
    }

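    // Load the same weights and lookup tables into each thread's DPU set.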
    for (int i = 0; i < NR_THREADS; i++) {
        struct dpu_set_t &dpu_set = pas[i].dpu_set;
        remote_ptr table_f32_f16_pim_ptr;
        remote_ptr w_pim_ptr;
        gemv_load_weight(dpu_set, ts_a, &table_f32_f16_pim_ptr, &w_pim_ptr);
        pas[i].table_f32_f16_pim_ptr = table_f32_f16_pim_ptr;
        pas[i].w_pim_ptr = w_pim_ptr;
    }

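    // Baseline: run the NR_THREADS kernels back to back on the calling thread.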
    uint64_t start = usec();
    for (int i = 0; i < NR_THREADS; i++) {
        gemv_dpu_kernel(&(pas[i]));
    }
    uint64_t end = usec();
    std::cout << "single-thread total time: " << end - start << " us" << std::endl;

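    // Same work again, but with each DPU set driven from its own pthread.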
    start = usec();
    pthread_t pid[NR_THREADS];
    for (int i = 0; i < NR_THREADS; i++) {
        pthread_create(&(pid[i]), NULL, gemv_dpu_kernel, &(pas[i]));
    }
    for (int i = 0; i < NR_THREADS; i++) {
        pthread_join(pid[i], NULL);
    }
    end = usec();
    std::cout << "multi-thread total time: " << end - start << " us" << std::endl;

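    // Sanity checks: recompute the first output element on the host, then report
    // the element-wise error between the CPU result and the PIM result.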
    float first_res = mul_add_q4_0_q8_0(ts_a, ts_bq);
    std::cout << "first element: " << std::fixed << std::setprecision(6) << first_res << std::endl;

    std::cout << "error between c and c_pim:" << std::endl;
    compare_tensor(ts_c, ts_c_pim);

#endif
    return 0;
}