@@ -29,6 +29,36 @@ static __device__ void padding(
     *y = Ta(rms * x * w);
 }
 
+// Padding variant for 3-D tensors: one thread per element, one block per
+// (batch, seq) row, with blockDim.x equal to the normalized dimension.
+template <unsigned int BLOCK_SIZE, class Tw, class Ta>
+static __device__ void padding_3d(
+    Ta *__restrict__ y_,
+    int const stride_y_batch,
+    int const stride_y_seq,
+    Ta const *__restrict__ x_,
+    int const stride_x_batch,
+    int const stride_x_seq,
+    Tw const *__restrict__ w_,
+    float const epsilon) {
+
+    // blockIdx.x = batch index, blockIdx.y = seq index
+    auto y = y_ + blockIdx.x * stride_y_batch + blockIdx.y * stride_y_seq + threadIdx.x;
+    float const x = x_[blockIdx.x * stride_x_batch + blockIdx.y * stride_x_seq + threadIdx.x];
+    float const w = w_[threadIdx.x];
+
+    using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+    __shared__ typename BlockOp::TempStorage temp_storage;
+    auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());
+
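+    // The block-wide sum is valid only in thread 0, which computes the
+    // normalization factor and broadcasts it through shared memory.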
+    __shared__ float rms;
+    if (threadIdx.x == 0) {
+        rms = rsqrtf(acc / float(blockDim.x) + epsilon);
+    }
+    __syncthreads();
+
+    *y = Ta(rms * x * w);
+}
+
 template <unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD, class Tw, class Ta>
 static __device__ void folding(
     Ta *__restrict__ y,
@@ -79,3 +109,59 @@ static __device__ void folding(
         BlockOp(temp_storage).Store(y, data, items_size);
     }
 }
+
+// Folding variant for 3-D tensors: each thread handles NUM_ITEMS_THREAD
+// elements, so the normalized dimension may exceed the block size.
+template <unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD, class Tw, class Ta>
+static __device__ void folding_3d(
+    Ta *__restrict__ y,
+    int const stride_y_batch,
+    int const stride_y_seq,
+    Ta const *__restrict__ x,
+    int const stride_x_batch,
+    int const stride_x_seq,
+    Tw const *__restrict__ w,
+    float const epsilon,
+    unsigned int const items_size) {
+
+    // blockIdx.x = batch index, blockIdx.y = seq index
+    y += blockIdx.x * stride_y_batch + blockIdx.y * stride_y_seq;
+    x += blockIdx.x * stride_x_batch + blockIdx.y * stride_x_seq;
+
+    float data[NUM_ITEMS_THREAD], weight[NUM_ITEMS_THREAD];
+    {
+        using BlockOp = cub::BlockLoad<float, BLOCK_SIZE, NUM_ITEMS_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Load(x, data, items_size, 0.f);
+        // Barrier required before reusing temp_storage for a second collective.
+        __syncthreads();
+        BlockOp(temp_storage).Load(w, weight, items_size, 0.f);
+    }
+
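+    // Out-of-range items were zero-filled by the loads above, so they
+    // contribute nothing to the sum of squares.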
+    float squared = 0;
+#pragma unroll
+    for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) {
+        squared += data[i] * data[i];
+    }
+
+    float acc;
+    {
+        using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
+    }
+
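+    // As in padding_3d, only thread 0 holds the reduced sum; it derives the
+    // rms factor and broadcasts it through shared memory.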
+    __shared__ float rms;
+    if (threadIdx.x == 0) {
+        rms = rsqrtf(acc / float(items_size) + epsilon);
+    }
+    __syncthreads();
+
+#pragma unroll
+    for (unsigned int i = 0; i < NUM_ITEMS_THREAD; ++i) {
+        data[i] = rms * data[i] * weight[i];
+    }
+
+    {
+        using BlockOp = cub::BlockStore<float, BLOCK_SIZE, NUM_ITEMS_THREAD>;
+        __shared__ typename BlockOp::TempStorage temp_storage;
+        BlockOp(temp_storage).Store(y, data, items_size);
+    }
+}
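
For context, below is a minimal sketch of how these device functions might be driven from __global__ kernels and launched. The wrapper names, the hidden size d, and the launch shapes are illustrative assumptions, not part of this commit.

// Hypothetical wrappers: one block per (batch, seq) row.
template <unsigned int BLOCK_SIZE, class Tw, class Ta>
static __global__ void rms_norm_padding_3d(
    Ta *y, int sy_batch, int sy_seq,
    Ta const *x, int sx_batch, int sx_seq,
    Tw const *w, float epsilon) {
    padding_3d<BLOCK_SIZE>(y, sy_batch, sy_seq, x, sx_batch, sx_seq, w, epsilon);
}

template <unsigned int BLOCK_SIZE, unsigned int NUM_ITEMS_THREAD, class Tw, class Ta>
static __global__ void rms_norm_folding_3d(
    Ta *y, int sy_batch, int sy_seq,
    Ta const *x, int sx_batch, int sx_seq,
    Tw const *w, float epsilon, unsigned int d) {
    folding_3d<BLOCK_SIZE, NUM_ITEMS_THREAD>(y, sy_batch, sy_seq, x, sx_batch, sx_seq, w, epsilon, d);
}

// Launch sketch for a [batch, seq, d] tensor with BLOCK_SIZE = 1024:
//   dim3 grid(batch, seq);
//   if (d <= 1024)  // one thread per element, blockDim.x == d
//       rms_norm_padding_3d<1024><<<grid, d>>>(y, sy_b, sy_s, x, sx_b, sx_s, w, 1e-5f);
//   else            // requires d <= 1024 * NUM_ITEMS_THREAD
//       rms_norm_folding_3d<1024, 4><<<grid, 1024>>>(y, sy_b, sy_s, x, sx_b, sx_s, w, 1e-5f, d);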