Skip to content

Commit 20a2dbd

Browse files
Merge pull request #478 from InfiniTensor/issue/477
issue/477 - Cambricon MLU NeoX
2 parents 6b903fd + 6af2e42 commit 20a2dbd

File tree

3 files changed

+68
-34
lines changed

3 files changed

+68
-34
lines changed

src/infiniop/ops/rope/bang/rope_bang.mlu

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,6 @@ infiniStatus_t Descriptor::create(
2121
auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo);
2222
CHECK_RESULT(info);
2323

24-
if (algo != INFINIOP_ROPE_ALGO_GPT_J) {
25-
return INFINI_STATUS_NOT_IMPLEMENTED;
26-
}
27-
2824
// Create descriptor
2925
*desc_ptr = new Descriptor(
3026
info.take(),
@@ -62,7 +58,8 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info,
6258
y, x, pos_ids, sin_table, cos_table,
6359
dimx, dimy, table_dim,
6460
info.y_stride_seqlen, info.y_stride_nhead,
65-
info.x_stride_seqlen, info.x_stride_nhead);
61+
info.x_stride_seqlen, info.x_stride_nhead,
62+
info.algo);
6663

6764
cnrtQueueSync(queue);
6865

src/infiniop/ops/rope/bang/rope_bang_kernel.mlu

Lines changed: 64 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include "../../../devices/bang/common_bang.h"
2+
#include "rope_bang.h"
23

34
__nram__ char nram_buffer[NRAM_MAX_SIZE];
45

@@ -11,29 +12,44 @@ __mlu_device__ void calculateRope(
1112
Tdata *input_0, Tdata *input_1, Tdata *input_cache,
1213
int theta_index, int out_index, int in_index,
1314
int chunk_size, int half_chunk_size, int data_segsize,
14-
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride) {
15+
int src_load_stride, int dst_load_stride, int src_write_stride, int dst_write_stride,
16+
bool is_gpt_j_style) {
17+
1518
// Load sin/cos data
1619
__memcpy(sin_cache, sin_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
1720
__memcpy(cos_cache, cos_table + theta_index, half_chunk_size * sizeof(Tdata), GDRAM2NRAM);
1821

1922
// Load input data
2023
__memcpy(input_cache, in + in_index, chunk_size * sizeof(Tdata), GDRAM2NRAM);
2124

22-
// Split input into even and odd positions
23-
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
24-
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
25+
if (is_gpt_j_style) {
26+
// GPT-J: (x0, x1), (x2, x3), ...
27+
// Split input into even and odd positions
28+
__memcpy(input_0, input_cache, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
29+
__memcpy(input_1, input_cache + 1, data_segsize, NRAM2NRAM, dst_load_stride, src_load_stride, half_chunk_size - 1);
30+
} else {
31+
// GPT-NeoX: (x0...xd/2-1), (xd/2...xd-1)
32+
__memcpy(input_0, input_cache, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
33+
__memcpy(input_1, input_cache + half_chunk_size, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
34+
}
2535

26-
// Compute even positions: y0 = x0 * cos - x1 * sin and y1 = x0 * sin + x1 * cos
36+
// Compute rotations
2737
__bang_mul(x0cos, input_0, cos_cache, half_chunk_size);
2838
__bang_mul(x1sin, input_1, sin_cache, half_chunk_size);
2939
__bang_mul(x0sin, input_0, sin_cache, half_chunk_size);
3040
__bang_mul(x1cos, input_1, cos_cache, half_chunk_size);
3141
__bang_sub(input_0, x0cos, x1sin, half_chunk_size);
3242
__bang_add(input_1, x0sin, x1cos, half_chunk_size);
3343

34-
// Interleave results back into output buffer
35-
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
36-
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
44+
if (is_gpt_j_style) {
45+
// GPT-J
46+
__memcpy(input_cache, input_0, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
47+
__memcpy(input_cache + 1, input_1, data_segsize, NRAM2NRAM, dst_write_stride, src_write_stride, half_chunk_size - 1);
48+
} else {
49+
// GPT-NeoX
50+
__memcpy(input_cache, input_0, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
51+
__memcpy(input_cache + half_chunk_size, input_1, half_chunk_size * sizeof(Tdata), NRAM2NRAM);
52+
}
3753

3854
// Write back results
3955
__memcpy(out + out_index, input_cache, chunk_size * sizeof(Tdata), NRAM2GDRAM);
@@ -52,22 +68,42 @@ __mlu_global__ void ropeKernel(
5268
ptrdiff_t y_stride_seqlen,
5369
ptrdiff_t y_stride_nhead,
5470
ptrdiff_t x_stride_seqlen,
55-
ptrdiff_t x_stride_nhead) {
71+
ptrdiff_t x_stride_nhead,
72+
infiniopRoPEAlgo_t algo) {
73+
74+
const bool is_gpt_j_style = (algo == INFINIOP_ROPE_ALGO_GPT_J);
5675

5776
// Calculate available NRAM space after alignment
58-
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9); // 9 buffers need alignment
77+
const size_t nram_usable = NRAM_MAX_SIZE - (ALIGN_SIZE * 9);
5978
const size_t max_chunk_elements = nram_usable / (9 * sizeof(Tdata));
6079

6180
// Key variables that determine execution path
6281
const bool use_pos_ids_buffer = (seqlen * sizeof(Tindex) <= (nram_usable / 2));
63-
const int half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
6482

65-
// Common stride configurations
66-
const int data_segsize = sizeof(Tdata);
67-
const int src_load_stride = 2 * sizeof(Tdata);
68-
const int dst_load_stride = 1 * sizeof(Tdata);
69-
const int src_write_stride = 1 * sizeof(Tdata);
70-
const int dst_write_stride = 2 * sizeof(Tdata);
83+
int half_chunk_size;
84+
if (is_gpt_j_style) {
85+
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
86+
} else {
87+
half_chunk_size = std::min((int)(max_chunk_elements / 2), (int)table_dim);
88+
}
89+
90+
int data_segsize, src_load_stride, dst_load_stride, src_write_stride, dst_write_stride;
91+
92+
if (is_gpt_j_style) {
93+
// GPT-J
94+
data_segsize = sizeof(Tdata);
95+
src_load_stride = 2 * sizeof(Tdata);
96+
dst_load_stride = 1 * sizeof(Tdata);
97+
src_write_stride = 1 * sizeof(Tdata);
98+
dst_write_stride = 2 * sizeof(Tdata);
99+
} else {
100+
// GPT-NeoX
101+
data_segsize = half_chunk_size * sizeof(Tdata);
102+
src_load_stride = 1 * sizeof(Tdata);
103+
dst_load_stride = 1 * sizeof(Tdata);
104+
src_write_stride = 1 * sizeof(Tdata);
105+
dst_write_stride = 1 * sizeof(Tdata);
106+
}
71107

72108
// Task distribution
73109
const int batch_volume = seqlen * nhead;
@@ -100,29 +136,29 @@ __mlu_global__ void ropeKernel(
100136

101137
// Main processing loop
102138
for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) {
103-
// Calculate output and input indices
104139
int seq_idx = i / nhead;
105140
int head_idx = i % nhead;
106141

107-
// Output indices (y)
108142
int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead;
109-
110-
// Input indices (x)
111143
int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead;
112144

113-
// Get position index
114145
Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx];
115146
int rot_offset = pos_idx * table_dim;
116147

117-
// Process in chunks that fit in NRAM
118148
int processed = 0;
119149
while (processed < table_dim) {
120-
// Calculate current chunk size
121150
int current_half_chunk = std::min<uint32_t>(half_chunk_size, table_dim - processed);
122151
int current_chunk_size = 2 * current_half_chunk;
123152
int theta_offset = rot_offset + processed;
124-
int dst_offset = out_offset + processed * 2;
125-
int src_offset = in_offset + processed * 2;
153+
154+
int dst_offset, src_offset;
155+
if (is_gpt_j_style) {
156+
dst_offset = out_offset + processed * 2;
157+
src_offset = in_offset + processed * 2;
158+
} else {
159+
dst_offset = out_offset + processed;
160+
src_offset = in_offset + processed;
161+
}
126162

127163
// Set up NRAM buffers for this chunk
128164
char *chunk_base = aligned_nram;
@@ -143,7 +179,8 @@ __mlu_global__ void ropeKernel(
143179
theta_offset, dst_offset, src_offset,
144180
current_chunk_size, current_half_chunk,
145181
data_segsize,
146-
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride);
182+
src_load_stride, dst_load_stride, src_write_stride, dst_write_stride,
183+
is_gpt_j_style);
147184

148185
processed += current_half_chunk;
149186
}

test/infiniop/rope.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -97,7 +97,6 @@ def _torch_rope(sin, cos, t1, t2):
9797

9898
return t_out_1, t_out_2
9999

100-
101100
dh = t.shape[-1]
102101
dt = t.dtype
103102
assert dh % 2 == 0, "Embedding dimension must be even."
@@ -111,7 +110,7 @@ def _torch_rope(sin, cos, t1, t2):
111110
ans[..., 0::2] = t_out_even.to(dt)
112111
ans[..., 1::2] = t_out_odd.to(dt)
113112
else:
114-
half_dim = dh // 2
113+
half_dim = dh // 2
115114
t_first = t[..., :half_dim]
116115
t_second = t[..., half_dim:]
117116

@@ -232,6 +231,7 @@ def lib_rope():
232231
sin_table.torch_tensor(),
233232
cos_table.torch_tensor(),
234233
device,
234+
algo,
235235
),
236236
device,
237237
NUM_PRERUN,

0 commit comments

Comments
 (0)