Skip to content

Commit 46aaeba

Browse files
mryvaelanhin
authored and committed
use mul_table_int4_int8
1 parent 7004727 commit 46aaeba

File tree

5 files changed

+38
-5
lines changed

5 files changed

+38
-5
lines changed

dpu/dpu_main.c

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929

3030
__mram_ptr float *ptable_f32_f16;
3131

32+
__host int16_t mul_table_int4_int8[1<<4][1<<8];
33+
3234
inline static float lookup_fp16_to_fp32(uint16_t f) {
3335
uint16_t s;
3436
memcpy(&s, &f, sizeof(uint16_t));
@@ -238,10 +240,12 @@ int main() {
238240
for (int i = 0; i < segment_nb_size; i++) {
239241
int sumi = 0;
240242
for (int j = 0; j < qk/2; ++j) {
241-
const int v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
242-
const int v1 = (pweight_cache[i].qs[j] >> 4) - 8;
243+
const int8_t v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
244+
const int8_t v1 = (pweight_cache[i].qs[j] >> 4) - 8;
243245

244-
sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
246+
// sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
247+
sumi += mul_table_int4_int8[v0 + 8][pinput_cache[i].qs[j] - INT8_MIN] +
248+
mul_table_int4_int8[v1 + 8][pinput_cache[i].qs[j + qk/2] - INT8_MIN];
245249
}
246250

247251
int psumf_idx = l * weight_rows_cur_thread + k / SEGMENT_PER_ROW;

examples/tensor/ts.cpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
#include <iomanip>
44
#include <chrono>
55

6-
#define NR_DPUS 64
6+
#define NR_DPUS 512
77
#define NR_LAYER 2
88
#define DPU_BINARY "./dpu/gemv_dpu"
99

10+
int16_t mul_table_int4_int8[1<<4][1<<8];
11+
1012
void fp_table_init(void) {
1113
for (int i = 0; i < (1 << 16); ++i) {
1214
union {
@@ -17,13 +19,22 @@ void fp_table_init(void) {
1719
}
1820
}
1921

22+
void mul_table_int4_int8_init(void) {
23+
for(int i = 0; i < (1 << 4); ++i){
24+
for(int j = 0; j< (1 << 8); ++j){
25+
mul_table_int4_int8[i][j] = (i - 8) * (j + INT8_MIN);
26+
}
27+
}
28+
}
29+
2030
#ifdef PIM_KERNEL
2131
int gemv_dpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct ggml_tensor * in_q, struct ggml_tensor * res) {
2232
uint32_t pim_offset = 0;
2333
struct dpu_set_t dpu;
2434

2535
std::chrono::high_resolution_clock::time_point ex_tp1 = std::chrono::high_resolution_clock::now();
2636

37+
DPU_ASSERT(dpu_broadcast_to(context->dpu_set, "mul_table_int4_int8", 0, (void *)(mul_table_int4_int8), sizeof(mul_table_int4_int8), DPU_XFER_DEFAULT));
2738
//ggml_table_f32_f16 tbl is transferred to pim
2839
DPU_ASSERT(dpu_broadcast_to(context->dpu_set, DPU_MRAM_HEAP_POINTER_NAME, pim_offset, (void *)(ggml_table_f32_f16), sizeof(ggml_table_f32_f16), DPU_XFER_DEFAULT));
2940
pim_offset += sizeof(ggml_table_f32_f16);
@@ -163,6 +174,7 @@ void gemv_cpu_kernel(struct pim_context *context, struct ggml_tensor * w, struct
163174
int main(int argc, char** argv) {
164175
// init fp table for fp16 dump
165176
fp_table_init();
177+
mul_table_int4_int8_init();
166178

167179
#ifdef PIM_KERNEL
168180
// WQ-PIM allocate dpu

ggml/src/ggml.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,10 @@ static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
490490
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
491491
float ggml_table_f32_f16[1 << 16];
492492

493+
#ifdef PIM_KERNEL
494+
int16_t mul_table_int4_int8[1<<4][1<<8];
495+
#endif
496+
493497
#if defined(__ARM_ARCH)
494498
struct ggml_arm_arch_features_type {
495499
int has_neon;

include/llama.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ extern "C" {
425425

426426
#ifdef PIM_KERNEL
427427
#define NR_DPUS 512
428-
#define NR_LAYER 2
428+
#define NR_LAYER 32
429429
#define DPU_BINARY "./dpu/gemv_dpu"
430430
enum WeightId {
431431
WQ,

src/llama.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9342,6 +9342,16 @@ static struct ggml_tensor * llm_build_lora_mm(
93429342

93439343
#ifdef PIM_KERNEL
93449344
extern float ggml_table_f32_f16[1 << 16];
9345+
extern int16_t mul_table_int4_int8[1<<4][1<<8];
9346+
9347+
static void mul_table_int4_int8_init(void) {
9348+
for(int i = 0; i < (1 << 4); ++i){
9349+
for(int j = 0; j< (1 << 8); ++j){
9350+
mul_table_int4_int8[i][j] = (i - 8) * (j + INT8_MIN);
9351+
}
9352+
}
9353+
}
9354+
93459355
int load_weight2dpu(enum WeightId w_id, struct dpu_set_t dpu_set, struct llama_model *model, struct pim_meta *pim_metadata, uint32_t offset_base) {
93469356
GGML_ASSERT(w_id < WCNT);
93479357
struct dpu_set_t dpu;
@@ -9393,6 +9403,9 @@ int llama_load2dpu(struct llama_context *ctx, struct llama_model *model) {
93939403
DPU_ASSERT(dpu_alloc(NR_DPUS, NULL, &pqcontext->dpu_set));
93949404
DPU_ASSERT(dpu_load(pqcontext->dpu_set, DPU_BINARY, NULL));
93959405

9406+
mul_table_int4_int8_init();
9407+
DPU_ASSERT(dpu_broadcast_to(pqcontext->dpu_set, "mul_table_int4_int8", 0, (void *)(mul_table_int4_int8), sizeof(mul_table_int4_int8), DPU_XFER_DEFAULT));
9408+
93969409
for (int uuu=0;uuu<16;uuu++) {
93979410
printf("ggml_table_f32_f16[%d]=%f\n",uuu,ggml_table_f32_f16[uuu]);
93989411
}

0 commit comments

Comments (0)