@@ -754,6 +754,152 @@ void update_cache(
   }
 }
 
+/*
+  Input params
+  @param[in] q_projected Projected query with query weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] k_projected Projected key with key weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] v_projected Projected value with value weights.
+  Format [n_layers, batch size, seq_len, num heads, head dim]
+  @param[in] key_cache Cache of previous k_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  @param[in] value_cache Cache of previous v_projected.
+  Format [n_layers, batch size, max_seq_len, num heads, head dim]
+  ....
+  @param[in] start_pos Sequence position of the first token in q_projected.
+  @param[in] seq_len Sequence length, i.e. the seq_len dim of q_projected.
+*/
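+// Usage (sketch): sdpa_with_kv_cache_out below first appends the new k/v
+// projections via update_cache(), then delegates here with the full caches:
+//   custom_sdpa_out(ctx, q_projected, key_cache, value_cache, start_pos,
+//                   seq_len, attn_mask, dropout_p, is_causal, scale, output);
+// i.e. the k and v arguments received here are the caches themselves, which
+// are sliced to the valid prefix [0, start_pos + seq_len) internally.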
+Tensor& custom_sdpa_out(
+    RuntimeContext& ctx,
+    const Tensor& q,
+    const Tensor& k,
+    const Tensor& v,
+    const int64_t start_pos,
+    const int64_t seq_len,
+    const optional<Tensor>& attn_mask,
+    const double dropout_p,
+    const bool is_causal,
+    // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
+    const optional<double> scale,
+    Tensor& output) {
+  ET_KERNEL_CHECK_MSG(
+      ctx,
+      !attn_mask.has_value() || !is_causal,
+      InvalidArgument,
+      output,
+      "attn_mask and is_causal cannot be set at the same time");
+
+  ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
+
+  auto q_seq_len = q.size(1);
+
+  // TODO: Refactor the following into a create_view util, perhaps using
+  // TensorPtr.
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
+  sliced_key_sizes[0] = k.size(0);
+  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
+  sliced_key_sizes[2] = k.size(2);
+  sliced_key_sizes[3] = k.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
+  dim_order_to_stride_nocheck(
+      sliced_key_sizes.data(),
+      sliced_key_dim_order.data(),
+      util::kKVDim,
+      sliced_key_strides.data());
+  // Since the cache is sliced, the batch stride needs to stay the same.
+  sliced_key_strides[0] = k.strides()[0];
+  void* key_cache_data = k.mutable_data_ptr();
+  TensorImpl k_impl = TensorImpl(
+      k.scalar_type(),
+      util::kKVDim,
+      sliced_key_sizes.data(),
+      key_cache_data,
+      sliced_key_dim_order.data(),
+      sliced_key_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_key_cache(&k_impl);
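+  // Worked example (illustrative numbers): with a key cache of shape
+  // [1, max_seq_len = 1024, n_heads, head_dim], start_pos = 16 and
+  // seq_len = 4, the view reports sizes [1, 20, n_heads, head_dim] over the
+  // same storage. Pinning strides[0] to the full cache's batch stride is what
+  // keeps rows 20..1023 as unread padding; a stride recomputed from the
+  // sliced sizes would misaddress every batch after the first.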
+
+  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
+      0, 1, 2, 3};
+  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
+  sliced_value_sizes[0] = v.size(0);
+  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
+  sliced_value_sizes[2] = v.size(2);
+  sliced_value_sizes[3] = v.size(3);
+  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
+  dim_order_to_stride_nocheck(
+      sliced_value_sizes.data(),
+      sliced_value_dim_order.data(),
+      util::kKVDim,
+      sliced_value_strides.data());
+  // Since the cache is sliced, the batch stride needs to stay the same.
+  sliced_value_strides[0] = v.strides()[0];
+  void* value_cache_data = v.mutable_data_ptr();
+  TensorImpl value_impl = TensorImpl(
+      v.scalar_type(),
+      util::kKVDim,
+      sliced_value_sizes.data(),
+      value_cache_data,
+      sliced_value_dim_order.data(),
+      sliced_value_strides.data(),
+      TensorShapeDynamism::STATIC);
+  Tensor sliced_value_cache(&value_impl);
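+  // The create_view util from the TODO above could bundle the metadata
+  // arrays with the TensorImpl so nothing dangles when the helper returns
+  // (hypothetical sketch, not part of this change; TensorPtr would be an
+  // alternative way to tie the lifetimes together):
+  //
+  //   struct CacheView {
+  //     std::array<exec_aten::DimOrderType, util::kKVDim> dim_order{0, 1, 2, 3};
+  //     std::array<exec_aten::SizesType, util::kKVDim> sizes;
+  //     std::array<exec_aten::StridesType, util::kKVDim> strides;
+  //     std::optional<TensorImpl> impl;
+  //     Tensor slice(const Tensor& cache, int64_t valid_seq_len) {
+  //       sizes = {cache.size(0),
+  //                static_cast<exec_aten::SizesType>(valid_seq_len),
+  //                cache.size(2),
+  //                cache.size(3)};
+  //       dim_order_to_stride_nocheck(
+  //           sizes.data(), dim_order.data(), util::kKVDim, strides.data());
+  //       strides[0] = cache.strides()[0]; // batch stride of the full cache
+  //       impl.emplace(
+  //           cache.scalar_type(), util::kKVDim, sizes.data(),
+  //           cache.mutable_data_ptr(), dim_order.data(), strides.data(),
+  //           TensorShapeDynamism::STATIC);
+  //       return Tensor(&*impl);
+  //     }
+  //   };
+  //
+  // which would reduce each of the two blocks above to a declaration plus
+  //   Tensor sliced_key_cache = key_view.slice(k, start_pos + seq_len);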
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(output, q.sizes()) == Error::Ok,
+      InvalidArgument,
+      output);
+
+  // TODO(task): replace the template param selection logic
+  // with whatever appropriately makes more sense for the target.
+  ET_SWITCH_FLOAT_TYPES(q.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
+    // TODO: we need to re-evaluate this for ARM CPUs, and there can be many,
+    // so instead of templatizing we might consider another approach.
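+    // The first non-type template param below reads as the query block size
+    // (by analogy with ATen's cpu_flash_attention dispatch): e.g. a
+    // 1024-token prefill would take the 256 branch, while single-token
+    // decode (q_seq_len == 1) falls through to 32.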
+    if (q_seq_len >= 768) {
+      cpu_flash_attention<CTYPE, 256, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else if (q_seq_len >= 192) {
+      cpu_flash_attention<CTYPE, 64, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    } else {
+      cpu_flash_attention<CTYPE, 32, 512>(
+          output,
+          q,
+          sliced_key_cache,
+          sliced_value_cache,
+          dropout_p,
+          is_causal,
+          attn_mask,
+          scale,
+          true,
+          start_pos);
+    }
+  });
+  return output;
+}
 } // anonymous namespace
 
 Tensor& flash_attention_kernel_out(
@@ -860,129 +1006,24 @@ Tensor& sdpa_with_kv_cache_out(
       InvalidArgument,
       output);
 
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      !attn_mask.has_value() || !is_causal,
-      InvalidArgument,
-      output,
-      "attn_mask and is_causal cannot be set at the same time");
-
   ET_CHECK_MSG(q_projected.dim() == 4, "query must be a 4D tensor");
 
   update_cache(k_projected, key_cache, start_pos, seq_len);
   update_cache(v_projected, value_cache, start_pos, seq_len);
 
-  auto q_seq_len = q_projected.size(1);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_key_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_key_sizes;
-  sliced_key_sizes[0] = key_cache.size(0);
-  sliced_key_sizes[1] = start_pos + seq_len; // key_cache.size(2);
-  sliced_key_sizes[2] = key_cache.size(2);
-  sliced_key_sizes[3] = key_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_key_strides;
-  dim_order_to_stride_nocheck(
-      sliced_key_sizes.data(),
-      sliced_key_dim_order.data(),
-      util::kKVDim,
-      sliced_key_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_key_strides[0] = key_cache.strides()[0];
-  void* key_cache_data = key_cache.mutable_data_ptr();
-  TensorImpl k_impl = TensorImpl(
-      key_cache.scalar_type(),
-      util::kKVDim,
-      sliced_key_sizes.data(),
-      key_cache_data,
-      sliced_key_dim_order.data(),
-      sliced_key_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_key_cache(&k_impl);
-
-  std::array<exec_aten::DimOrderType, util::kKVDim> sliced_value_dim_order{
-      0, 1, 2, 3};
-  std::array<exec_aten::SizesType, util::kKVDim> sliced_value_sizes;
-  sliced_value_sizes[0] = value_cache.size(0);
-  sliced_value_sizes[1] = start_pos + seq_len; // value_cache.size(2);
-  sliced_value_sizes[2] = value_cache.size(2);
-  sliced_value_sizes[3] = value_cache.size(3);
-  std::array<exec_aten::StridesType, util::kKVDim> sliced_value_strides;
-  dim_order_to_stride_nocheck(
-      sliced_value_sizes.data(),
-      sliced_value_dim_order.data(),
-      util::kKVDim,
-      sliced_value_strides.data());
-  // since the cache is sliced, the batch stride needs to stay the same.
-  sliced_value_strides[0] = value_cache.strides()[0];
-  void* value_cache_data = value_cache.mutable_data_ptr();
-  TensorImpl value_impl = TensorImpl(
-      value_cache.scalar_type(),
-      util::kKVDim,
-      sliced_value_sizes.data(),
-      value_cache_data,
-      sliced_value_dim_order.data(),
-      sliced_value_strides.data(),
-      TensorShapeDynamism::STATIC);
-  Tensor sliced_value_cache(&value_impl);
-
-  // Is this true?
-  // Can't do this as is because the expectation of this kernel is
-  // that q, k, v are [B, num heads, seq length, head dim]
-  // and the cache is [B, max seq len, num heads, head dim]
-  // and q, k, v are all [B, seq length, num heads, head dim]
-
-  ET_KERNEL_CHECK(
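+  // Attention over the cached prefix plus the new tokens is now delegated
+  // to custom_sdpa_out above, which performs the same slicing of
+  // key_cache/value_cache to [0, start_pos + seq_len) internally.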
+  custom_sdpa_out(
       ctx,
-      resize_tensor(output, q_projected.sizes()) == Error::Ok,
-      InvalidArgument,
+      q_projected,
+      key_cache,
+      value_cache,
+      start_pos,
+      seq_len,
+      attn_mask,
+      dropout_p,
+      is_causal,
+      scale,
       output);
 
-  // TODO(task): replace the template param selection logic
-  // with whatever appropriately makes more sense for the target.
-  ET_SWITCH_FLOAT_TYPES(
-      q_projected.scalar_type(), ctx, "flash_attention", CTYPE, [&] {
-        // TODO: we need to re-evaluate this for ARM CPUs, and there can be
-        // many, so instead of templatizing we might consider another
-        // approach.
-        if (q_seq_len >= 768) {
-          cpu_flash_attention<CTYPE, 256, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else if (q_seq_len >= 192) {
-          cpu_flash_attention<CTYPE, 64, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        } else {
-          cpu_flash_attention<CTYPE, 32, 512>(
-              output,
-              q_projected,
-              sliced_key_cache,
-              sliced_value_cache,
-              dropout_p,
-              is_causal,
-              attn_mask,
-              scale,
-              true,
-              start_pos);
-        }
-      });
   return output;
 }
 } // namespace native