@@ -30,6 +30,8 @@ void FusedRopeKernel(const Context& dev_ctx,
3030 const paddle::optional<DenseTensor>& v,
3131 const paddle::optional<DenseTensor>& sin,
3232 const paddle::optional<DenseTensor>& cos,
33+ const paddle::optional<DenseTensor>& position_ids,
34+ bool use_neox_rotary_style,
3335 DenseTensor* out_q,
3436 DenseTensor* out_k,
3537 DenseTensor* out_v) {
@@ -59,6 +61,7 @@ void FusedRopeKernel(const Context& dev_ctx,
5961 phi::Array<T*, 3 > outs_data;
6062 phi::Array<const T*, 3 > ins_data;
6163 phi::Array<const T*, 2 > sin_cos_data;
64+ const int64_t * position_ids_data = NULL ;
6265
6366 ins_data[0 ] = q.data <T>();
6467 outs_data[0 ] = out_q->data <T>();
@@ -109,15 +112,52 @@ void FusedRopeKernel(const Context& dev_ctx,
109112 " The batch_size and num_heads of sin and cos must be 1." ));
110113 }
111114 int sin_seq_len_dim = (dims_size) == 4 ? 1 : 0 ;
112- PADDLE_ENFORCE_EQ ((sin_dims[dims_size - 1 ] == head_dim &&
113- sin_dims[sin_seq_len_dim] == seq_len),
114- true ,
115- phi::errors::InvalidArgument (
116- " The seq_len and head_dim of sin and cos "
117- " must be the same as those of q. But recieved sin's "
118- " shape is {%s}, q's shape is {%s}." ,
119- sin_dims,
120- q.dims ()));
115+
116+ if (position_ids.get_ptr ()) {
117+ PADDLE_ENFORCE_EQ (
118+ (sin_dims[dims_size - 1 ] == head_dim &&
119+ sin_dims[sin_seq_len_dim] >= seq_len),
120+ true ,
121+ phi::errors::InvalidArgument (
122+ " The seq_len of sin and cos must be greater than or equal to "
123+ " this of q. The head_dim of sin and cos must be the same as this "
124+ " of q. But recieved sin's "
125+ " shape is {%s}, q's shape is {%s}." ,
126+ sin_dims,
127+ q.dims ()));
128+
129+ auto position_ids_dims = position_ids.get_ptr ()->dims ();
130+ PADDLE_ENFORCE_EQ (position_ids_dims.size (),
131+ 2 ,
132+ phi::errors::InvalidArgument (
133+ " The dims of position_ids is expected to "
134+ " be 2, but recieved %d." ,
135+ position_ids_dims.size ()));
136+
137+ PADDLE_ENFORCE_EQ (
138+ (position_ids_dims[0 ] == batch_size &&
139+ position_ids_dims[1 ] == seq_len),
140+ true ,
141+ phi::errors::InvalidArgument (
142+ " The batch_size and seq_len of position_ids must be the same as "
143+ " those of q. But recieved position_ids's "
144+ " shape is {%s}, q's shape is {%s}." ,
145+ position_ids_dims,
146+ q.dims ()));
147+
148+ position_ids_data = position_ids->data <int64_t >();
149+ } else {
150+ PADDLE_ENFORCE_EQ (
151+ (sin_dims[dims_size - 1 ] == head_dim &&
152+ sin_dims[sin_seq_len_dim] == seq_len),
153+ true ,
154+ phi::errors::InvalidArgument (
155+ " The seq_len and head_dim of sin and cos "
156+ " must be the same as those of q. But recieved sin's "
157+ " shape is {%s}, q's shape is {%s}." ,
158+ sin_dims,
159+ q.dims ()));
160+ }
121161
122162 sin_cos_data[0 ] = sin->data <T>();
123163 sin_cos_data[1 ] = cos->data <T>();
@@ -126,18 +166,35 @@ void FusedRopeKernel(const Context& dev_ctx,
126166 }
127167
128168 int sign = 1 ;
129- VectorizedFusedRopeKernel<T, MPType, vec_size>
130- <<<grid, block, 0 , stream>>> (ins_data,
131- sin_cos_data,
132- flag_sin_cos,
133- sign,
134- batch_size,
135- seq_len,
136- num_heads,
137- head_dim,
138- outs_data,
139- num_inputs,
140- div_c);
169+ if (use_neox_rotary_style) {
170+ VectorizedFusedRopeWithRotateEveryTwoKernel<T, MPType, vec_size>
171+ <<<grid, block, 0 , stream>>> (ins_data,
172+ sin_cos_data,
173+ position_ids_data,
174+ flag_sin_cos,
175+ sign,
176+ batch_size,
177+ seq_len,
178+ num_heads,
179+ head_dim,
180+ outs_data,
181+ num_inputs,
182+ div_c);
183+ } else {
184+ VectorizedFusedRopeWithRotateHalfKernel<T, MPType, vec_size>
185+ <<<grid, block, 0 , stream>>> (ins_data,
186+ sin_cos_data,
187+ position_ids_data,
188+ flag_sin_cos,
189+ sign,
190+ batch_size,
191+ seq_len,
192+ num_heads,
193+ head_dim,
194+ outs_data,
195+ num_inputs,
196+ div_c);
197+ }
141198}
142199} // namespace fusion
143200} // namespace phi
0 commit comments