@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
30
30
const paddle::optional<DenseTensor>& v,
31
31
const paddle::optional<DenseTensor>& sin,
32
32
const paddle::optional<DenseTensor>& cos,
33
+ const paddle::optional<DenseTensor>& position_ids,
34
+ bool use_neox_rotary_style,
33
35
DenseTensor* out_q,
34
36
DenseTensor* out_k,
35
37
DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
59
61
phi::Array<T*, 3 > outs_data;
60
62
phi::Array<const T*, 3 > ins_data;
61
63
phi::Array<const T*, 2 > sin_cos_data;
64
+ const int64_t * position_ids_data = NULL ;
62
65
63
66
ins_data[0 ] = q.data <T>();
64
67
outs_data[0 ] = out_q->data <T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
109
112
" The batch_size and num_heads of sin and cos must be 1." ));
110
113
}
111
114
int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0 ;
112
- PADDLE_ENFORCE_EQ ((sin_dims[dims_size - 1 ] == head_dim &&
113
- sin_dims[sin_seq_len_dim] == seq_len),
114
- true ,
115
- phi::errors::InvalidArgument (
116
- " The seq_len and head_dim of sin and cos "
117
- " must be the same as those of q. But recieved sin's "
118
- " shape is {%s}, q's shape is {%s}." ,
119
- sin_dims,
120
- q.dims ()));
115
+
116
+ if (position_ids.get_ptr ()) {
117
+ PADDLE_ENFORCE_EQ (
118
+ (sin_dims[dims_size - 1 ] == head_dim &&
119
+ sin_dims[sin_seq_len_dim] >= seq_len),
120
+ true ,
121
+ phi::errors::InvalidArgument (
122
+ " The seq_len of sin and cos must be greater than or equal to "
123
+ " this of q. The head_dim of sin and cos must be the same as this "
124
+ " of q. But recieved sin's "
125
+ " shape is {%s}, q's shape is {%s}." ,
126
+ sin_dims,
127
+ q.dims ()));
128
+
129
+ auto position_ids_dims = position_ids.get_ptr ()->dims ();
130
+ PADDLE_ENFORCE_EQ (position_ids_dims.size (),
131
+ 2 ,
132
+ phi::errors::InvalidArgument (
133
+ " The dims of position_ids is expected to "
134
+ " be 2, but recieved %d." ,
135
+ position_ids_dims.size ()));
136
+
137
+ PADDLE_ENFORCE_EQ (
138
+ (position_ids_dims[0 ] == batch_size &&
139
+ position_ids_dims[1 ] == seq_len),
140
+ true ,
141
+ phi::errors::InvalidArgument (
142
+ " The batch_size and seq_len of position_ids must be the same as "
143
+ " those of q. But recieved position_ids's "
144
+ " shape is {%s}, q's shape is {%s}." ,
145
+ position_ids_dims,
146
+ q.dims ()));
147
+
148
+ position_ids_data = position_ids->data <int64_t >();
149
+ } else {
150
+ PADDLE_ENFORCE_EQ (
151
+ (sin_dims[dims_size - 1 ] == head_dim &&
152
+ sin_dims[sin_seq_len_dim] == seq_len),
153
+ true ,
154
+ phi::errors::InvalidArgument (
155
+ " The seq_len and head_dim of sin and cos "
156
+ " must be the same as those of q. But recieved sin's "
157
+ " shape is {%s}, q's shape is {%s}." ,
158
+ sin_dims,
159
+ q.dims ()));
160
+ }
121
161
122
162
sin_cos_data[0 ] = sin->data <T>();
123
163
sin_cos_data[1 ] = cos->data <T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
126
166
}
127
167
128
168
int sign = 1 ;
129
- VectorizedFusedRopeKernel<T, MPType, vec_size>
130
- <<<grid, block, 0 , stream>>> (ins_data,
131
- sin_cos_data,
132
- flag_sin_cos,
133
- sign,
134
- batch_size,
135
- seq_len,
136
- num_heads,
137
- head_dim,
138
- outs_data,
139
- num_inputs,
140
- div_c);
169
+ if (use_neox_rotary_style) {
170
+ VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
171
+ <<<grid, block, 0 , stream>>> (ins_data,
172
+ sin_cos_data,
173
+ position_ids_data,
174
+ flag_sin_cos,
175
+ sign,
176
+ batch_size,
177
+ seq_len,
178
+ num_heads,
179
+ head_dim,
180
+ outs_data,
181
+ num_inputs,
182
+ div_c);
183
+ } else {
184
+ VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
185
+ <<<grid, block, 0 , stream>>> (ins_data,
186
+ sin_cos_data,
187
+ position_ids_data,
188
+ flag_sin_cos,
189
+ sign,
190
+ batch_size,
191
+ seq_len,
192
+ num_heads,
193
+ head_dim,
194
+ outs_data,
195
+ num_inputs,
196
+ div_c);
197
+ }
141
198
}
142
199
} // namespace fusion
143
200
} // namespace phi
0 commit comments