@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
524524 EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
525525}
526526
527- TEST (OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
528- TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
529-
530- executorch::aten::Tensor query = tfFloat.make (
531- {1 , 1 , 4 , 4 },
532- {0.8823 ,
533- 0.9150 ,
534- 0.3829 ,
535- 0.9593 ,
536- 0.3904 ,
537- 0.6009 ,
538- 0.2566 ,
539- 0.7936 ,
540- 0.9408 ,
541- 0.1332 ,
542- 0.9346 ,
543- 0.5936 ,
544- 0.8694 ,
545- 0.5677 ,
546- 0.7411 ,
547- 0.4294 });
548- executorch::aten::Tensor key = tfFloat.make (
549- {1 , 1 , 4 , 4 },
550- {0.8854 ,
551- 0.5739 ,
552- 0.2666 ,
553- 0.6274 ,
554- 0.2696 ,
555- 0.4414 ,
556- 0.2969 ,
557- 0.8317 ,
558- 0.1053 ,
559- 0.2695 ,
560- 0.3588 ,
561- 0.1994 ,
562- 0.5472 ,
563- 0.0062 ,
564- 0.9516 ,
565- 0.0753 });
566- executorch::aten::Tensor value = tfFloat.make (
567- {1 , 1 , 4 , 4 },
568- {0.8860 ,
569- 0.5832 ,
570- 0.3376 ,
571- 0.8090 ,
572- 0.5779 ,
573- 0.9040 ,
574- 0.5547 ,
575- 0.3423 ,
576- 0.6343 ,
577- 0.3644 ,
578- 0.7104 ,
579- 0.9464 ,
580- 0.7890 ,
581- 0.2814 ,
582- 0.7886 ,
583- 0.5895 });
584- executorch::aten::Tensor attn_mask = tfFloat.make ({1 , 1 }, {0 });
585- executorch::aten::Tensor key_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
586- executorch::aten::Tensor value_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
587- executorch::aten::Tensor key_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
588- executorch::aten::Tensor value_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
589- executorch::aten::Tensor key_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
590- executorch::aten::Tensor value_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
591- double dropout_p = 0 ;
592- bool is_causal = false ;
593- executorch::aten::optional<double > scale;
594-
595- // start pos: 0 layer id 0
596- executorch::aten::Tensor ret_expected_0 = tfFloat.make (
597- {1 , 1 , 4 , 4 },
598- {0.8860 ,
599- 0.5832 ,
600- 0.3376 ,
601- 0.8090 ,
602- 0.5779 ,
603- 0.9040 ,
604- 0.5547 ,
605- 0.3423 ,
606- 0.6343 ,
607- 0.3644 ,
608- 0.7104 ,
609- 0.9464 ,
610- 0.7890 ,
611- 0.2814 ,
612- 0.7886 ,
613- 0.5895 });
614-
615- std::vector<int32_t > out_size = {1 , 1 , 4 , 4 };
616- executorch::aten::Tensor out = tfFloat.zeros (out_size);
617- executorch::aten::Tensor ret = op_sdpa_with_kv_cache (
618- query,
619- key,
620- value,
621- key_cache_0,
622- value_cache_0,
623- 0 ,
624- 1 ,
625- attn_mask,
626- dropout_p,
627- is_causal,
628- scale,
629- out);
630- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_0, 1e-4 , 1e-4 );
631-
632- // start pos: 0 layer id 2
633- executorch::aten::Tensor ret_expected_1 = tfFloat.make (
634- {1 , 1 , 4 , 4 },
635- {0.8860 ,
636- 0.5832 ,
637- 0.3376 ,
638- 0.8090 ,
639- 0.5779 ,
640- 0.9040 ,
641- 0.5547 ,
642- 0.3423 ,
643- 0.6343 ,
644- 0.3644 ,
645- 0.7104 ,
646- 0.9464 ,
647- 0.7890 ,
648- 0.2814 ,
649- 0.7886 ,
650- 0.5895 });
651- out = tfFloat.zeros (out_size);
652- ret = op_sdpa_with_kv_cache (
653- query,
654- key,
655- value,
656- key_cache_2,
657- value_cache_2,
658- 0 ,
659- 1 ,
660- attn_mask,
661- dropout_p,
662- is_causal,
663- scale,
664- out);
665- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_1, 1e-4 , 1e-4 );
666-
667- attn_mask = tfFloat.make ({1 , 2 }, {0 , 0 });
668- // start pos: 1 layer id 0
669- executorch::aten::Tensor ret_expected_2 = tfFloat.make (
670- {1 , 1 , 4 , 4 },
671- {0.8860 ,
672- 0.5832 ,
673- 0.3376 ,
674- 0.8090 ,
675- 0.5779 ,
676- 0.9040 ,
677- 0.5547 ,
678- 0.3423 ,
679- 0.6343 ,
680- 0.3644 ,
681- 0.7104 ,
682- 0.9464 ,
683- 0.7890 ,
684- 0.2814 ,
685- 0.7886 ,
686- 0.5895 });
687- out = tfFloat.zeros (out_size);
688- ret = op_sdpa_with_kv_cache (
689- query,
690- key,
691- value,
692- key_cache_0,
693- value_cache_0,
694- 1 ,
695- 1 ,
696- attn_mask,
697- dropout_p,
698- is_causal,
699- scale,
700- out);
701- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_2, 1e-4 , 1e-4 );
702-
703- // start pos: 1 layer id 1
704- executorch::aten::Tensor ret_expected_3 = tfFloat.make (
705- {1 , 1 , 4 , 4 },
706- {0.6486 ,
707- 0.4270 ,
708- 0.2472 ,
709- 0.5922 ,
710- 0.3669 ,
711- 0.5740 ,
712- 0.3522 ,
713- 0.2173 ,
714- 0.3635 ,
715- 0.2088 ,
716- 0.4071 ,
717- 0.5423 ,
718- 0.5110 ,
719- 0.1822 ,
720- 0.5107 ,
721- 0.3817 });
722- out = tfFloat.zeros (out_size);
723- ret = op_sdpa_with_kv_cache (
724- query,
725- key,
726- value,
727- key_cache_1,
728- value_cache_1,
729- 1 ,
730- 1 ,
731- attn_mask,
732- dropout_p,
733- is_causal,
734- scale,
735- out);
736- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_3, 1e-4 , 1e-4 );
737-
738- attn_mask = tfFloat.make ({1 , 3 }, {0 , 0 , 0 });
739- // start pos: 2 layer id 1
740- executorch::aten::Tensor ret_expected_4 = tfFloat.make (
741- {1 , 1 , 4 , 4 },
742- {0.7490 ,
743- 0.4930 ,
744- 0.2854 ,
745- 0.6838 ,
746- 0.4489 ,
747- 0.7021 ,
748- 0.4308 ,
749- 0.2659 ,
750- 0.4622 ,
751- 0.2655 ,
752- 0.5176 ,
753- 0.6895 ,
754- 0.6202 ,
755- 0.2212 ,
756- 0.6199 ,
757- 0.4634 });
758- out = tfFloat.zeros (out_size);
759- ret = op_sdpa_with_kv_cache (
760- query,
761- key,
762- value,
763- key_cache_1,
764- value_cache_1,
765- 2 ,
766- 1 ,
767- attn_mask,
768- dropout_p,
769- is_causal,
770- scale,
771- out);
772- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_4, 1e-4 , 1e-4 );
773-
774- // start pos: 2 layer id 2
775- executorch::aten::Tensor ret_expected_5 = tfFloat.make (
776- {1 , 1 , 4 , 4 },
777- {0.7490 ,
778- 0.4930 ,
779- 0.2854 ,
780- 0.6838 ,
781- 0.4489 ,
782- 0.7021 ,
783- 0.4308 ,
784- 0.2659 ,
785- 0.4622 ,
786- 0.2655 ,
787- 0.5176 ,
788- 0.6895 ,
789- 0.6202 ,
790- 0.2212 ,
791- 0.6199 ,
792- 0.4634 });
793- out = tfFloat.zeros (out_size);
794- ret = op_sdpa_with_kv_cache (
795- query,
796- key,
797- value,
798- key_cache_2,
799- value_cache_2,
800- 2 ,
801- 1 ,
802- attn_mask,
803- dropout_p,
804- is_causal,
805- scale,
806- out);
807- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
808- }
809-
810527TEST (OpScaledDotProductAttentionTest, SequenceTest) {
811528 TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
812529
0 commit comments