 #include "../../../devices/bang/common_bang.h"
+#include "rope_bang.h"
 
 __nram__ char nram_buffer[NRAM_MAX_SIZE];
 
@@ -11,29 +12,44 @@ __mlu_device__ void calculateRope(
     Tdata *input_0, Tdata *input_1, Tdata *input_cache,
     int theta_index, int out_index, int in_index,
     int chunk_size, int half_chunk_size, int data_segsize,
-    int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
+    int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride,
+    bool is_gpt_j_style) {
+
     // Load sin/cos data
     __memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
     __memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
 
     // Load input data
     __memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
 
-    // Split input into even and odd positions
-    __memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
-    __memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
+    if (is_gpt_j_style) {
+        // GPT-J: (x0, x1), (x2, x3), ...
+        // Split input into even and odd positions
+        __memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
+        __memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
+    } else {
+        // GPT-NeoX: half-split pairs, x_i rotated with x_{i + d/2}
+        __memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
+        __memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
+    }
 
-    // Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
+    // Apply the rotation: y0 = x0 * cos - x1 * sin, y1 = x0 * sin + x1 * cos
     __bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
     __bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
     __bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
     __bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
     __bang_sub(input_0, x0cos, x1sin, half_chunk_size);
     __bang_add(input_1, x0sin, x1cos, half_chunk_size);
 
-    // Interleave results back into output buffer
-    __memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
-    __memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
+    if (is_gpt_j_style) {
+        // GPT-J: interleave results back into even/odd positions
+        __memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
+        __memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
+    } else {
+        // GPT-NeoX: write the two rotated halves back contiguously
+        __memcpy(input_cache, input_0, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
+        __memcpy(input_cache + half_chunk_size, input_1, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
+    }
 
     // Write back results
     __memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
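
For reference, the element-wise effect of the split/rotate/merge sequence above can be checked against a plain scalar model. The sketch below is illustrative and not part of the patch; it assumes one head row of length 2 * half, float data, and sin/cos values already gathered for the row's position (the name rope_reference is ad hoc). It mirrors the two pairing schemes: GPT-J rotates interleaved pairs (x[2i], x[2i+1]), GPT-NeoX rotates (x[i], x[i + half]).

#include <cstddef>
#include <vector>

// Scalar model of one head row (illustrative; not the kernel code).
// x and y hold 2 * half elements; sin_t/cos_t hold the half rotation
// angles already gathered for this row's position.
void rope_reference(const std::vector<float> &x, std::vector<float> &y,
                    const std::vector<float> &sin_t, const std::vector<float> &cos_t,
                    bool gpt_j_style) {
    const std::size_t half = sin_t.size();
    for (std::size_t i = 0; i < half; ++i) {
        // Choose the pair according to the layout.
        const std::size_t a = gpt_j_style ? 2 * i : i;            // first element of the pair
        const std::size_t b = gpt_j_style ? 2 * i + 1 : i + half; // second element of the pair
        const float x0 = x[a], x1 = x[b];
        // Same math as the __bang_mul/__bang_sub/__bang_add sequence above.
        y[a] = x0 * cos_t[i] - x1 * sin_t[i];
        y[b] = x0 * sin_t[i] + x1 * cos_t[i];
    }
}
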
@@ -52,22 +68,42 @@ __mlu_global__ void ropeKernel(
     ptrdiff_t y_stride_seqlen,
     ptrdiff_t y_stride_nhead,
     ptrdiff_t x_stride_seqlen,
-    ptrdiff_t x_stride_nhead) {
+    ptrdiff_t x_stride_nhead,
+    infiniopRoPEAlgo_t algo) {
+
+    const bool is_gpt_j_style = (algo == INFINIOP_ROPE_ALGO_GPT_J);
 
     // Calculate available NRAM space after alignment
-    const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
+    const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9);
     const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
 
     // Key variables that determine execution path
     const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
-    const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
 
-    // Common stride configurations
-    const int data_segsize = sizeof(Tdata);
-    const int src_load_stride = 2 * sizeof(Tdata);
-    const int dst_load_stride = 1 * sizeof(Tdata);
-    const int src_write_stride = 1 * sizeof(Tdata);
-    const int dst_write_stride = 2 * sizeof(Tdata);
+    // Chunk capacity is the same for both layouts: each chunk holds up to
+    // half_chunk_size rotation pairs (2 * half_chunk_size input elements).
+    const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
+
+    int data_segsize, src_load_stride, dst_load_stride, src_write_stride, dst_write_stride;
+
+    if (is_gpt_j_style) {
+        // GPT-J: element-wise strided copies de-interleave even/odd positions
+        data_segsize = sizeof(Tdata);
+        src_load_stride = 2 * sizeof(Tdata);
+        dst_load_stride = 1 * sizeof(Tdata);
+        src_write_stride = 1 * sizeof(Tdata);
+        dst_write_stride = 2 * sizeof(Tdata);
+    } else {
+        // GPT-NeoX: halves are copied contiguously, so the strides are not used on this path
+        data_segsize = half_chunk_size * sizeof(Tdata);
+        src_load_stride = 1 * sizeof(Tdata);
+        dst_load_stride = 1 * sizeof(Tdata);
+        src_write_stride = 1 * sizeof(Tdata);
+        dst_write_stride = 1 * sizeof(Tdata);
+    }
 
     // Task distribution
     const int batch_volume = seqlen * nhead;
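
The stride table above only matters on the GPT-J path, where calculateRope uses the strided __memcpy overload to de-interleave even/odd elements into input_0/input_1; the GPT-NeoX path issues two contiguous copies, so its stride values are effectively placeholders. A rough host-side C++ model of the load step, under the assumption that the strided copy transfers (count + 1) segments of data_segsize bytes and advances source/destination by the given byte strides between segments, would look like this (the helper names are illustrative, not the BANG API):

#include <cstddef>
#include <cstring>

// Illustrative model only: assumes the strided copy transfers (count + 1)
// segments of seg_bytes, advancing src/dst by the given byte strides.
static void strided_copy(char *dst, const char *src, std::size_t seg_bytes,
                         std::size_t dst_stride, std::size_t src_stride, int count) {
    for (int i = 0; i <= count; ++i) {
        std::memcpy(dst + i * dst_stride, src + i * src_stride, seg_bytes);
    }
}

template <typename Tdata>
void load_pairs(Tdata *input_0, Tdata *input_1, const Tdata *input_cache,
                int half_chunk_size, bool gpt_j_style) {
    if (gpt_j_style) {
        // De-interleave: even elements -> input_0, odd elements -> input_1.
        strided_copy((char *)input_0, (const char *)input_cache, sizeof(Tdata),
                     sizeof(Tdata), 2 * sizeof(Tdata), half_chunk_size - 1);
        strided_copy((char *)input_1, (const char *)(input_cache + 1), sizeof(Tdata),
                     sizeof(Tdata), 2 * sizeof(Tdata), half_chunk_size - 1);
    } else {
        // Contiguous halves: first half -> input_0, second half -> input_1.
        std::memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata));
        std::memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata));
    }
}
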
@@ -100,29 +136,29 @@ __mlu_global__ void ropeKernel(
 
     // Main processing loop
     for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
-        // Calculate output and input indices
         int seq_idx = i / nhead;
         int head_idx = i % nhead;
 
-        // Output indices (y)
         int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
-
-        // Input indices (x)
         int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
 
-        // Get position index
         Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
         int rot_offset = pos_idx * table_dim;
 
-        // Process in chunks that fit in NRAM
         int processed = 0;
         while (processed < table_dim) {
-            // Calculate current chunk size
             int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
             int current_chunk_size = 2 * current_half_chunk;
             int theta_offset = rot_offset + processed;
-            int dst_offset = out_offset + processed * 2;
-            int src_offset = in_offset + processed * 2;
+
+            int dst_offset, src_offset;
+            if (is_gpt_j_style) {
+                dst_offset = out_offset + processed * 2;
+                src_offset = in_offset + processed * 2;
+            } else {
+                dst_offset = out_offset + processed;
+                src_offset = in_offset + processed;
+            }
 
             // Set up NRAM buffers for this chunk
             char *chunk_base = aligned_nram;
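
One detail worth noting in the loop above is the pos_idx lookup: it pairs with the earlier use_pos_ids_buffer check, so the position ids are staged into NRAM once (srcP) when the whole vector fits in half of the usable space, and read directly from GDRAM otherwise. A minimal C++ sketch of that "stage it if it fits, otherwise read through" decision, with illustrative names and int32_t indices standing in for the kernel's Tindex and memory spaces:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative model of the position-id lookup policy; not the MLU code.
struct PosIdSource {
    const std::int32_t *gdram_pos_ids;  // stands in for the GDRAM pos_ids pointer
    std::vector<std::int32_t> staged;   // stands in for the NRAM staging buffer (srcP)
    bool use_staged;

    PosIdSource(const std::int32_t *pos_ids, std::size_t seqlen, std::size_t staging_capacity_bytes)
        : gdram_pos_ids(pos_ids) {
        // Mirrors: use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= nram_usable / 2)
        use_staged = seqlen * sizeof(std::int32_t) <= staging_capacity_bytes;
        if (use_staged) {
            staged.assign(pos_ids, pos_ids + seqlen);  // one bulk copy up front
        }
    }

    std::int32_t at(std::size_t seq_idx) const {
        // Mirrors: use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx]
        return use_staged ? staged[seq_idx] : gdram_pos_ids[seq_idx];
    }
};
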
@@ -143,7 +179,8 @@ __mlu_global__ void ropeKernel(
                 theta_offset, dst_offset, src_offset,
                 current_chunk_size, current_half_chunk,
                 data_segsize,
-                src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
+                src_load_stride, dst_load_stride, src_write_stride, dst_write_stride,
+                is_gpt_j_style);
 
             processed += current_half_chunk;
         }