@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
524524  EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
525525}
526526
527- TEST (OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
528-   TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
529- 
530-   executorch::aten::Tensor query = tfFloat.make (
531-       {1 , 1 , 4 , 4 },
532-       {0.8823 ,
533-        0.9150 ,
534-        0.3829 ,
535-        0.9593 ,
536-        0.3904 ,
537-        0.6009 ,
538-        0.2566 ,
539-        0.7936 ,
540-        0.9408 ,
541-        0.1332 ,
542-        0.9346 ,
543-        0.5936 ,
544-        0.8694 ,
545-        0.5677 ,
546-        0.7411 ,
547-        0.4294 });
548-   executorch::aten::Tensor key = tfFloat.make (
549-       {1 , 1 , 4 , 4 },
550-       {0.8854 ,
551-        0.5739 ,
552-        0.2666 ,
553-        0.6274 ,
554-        0.2696 ,
555-        0.4414 ,
556-        0.2969 ,
557-        0.8317 ,
558-        0.1053 ,
559-        0.2695 ,
560-        0.3588 ,
561-        0.1994 ,
562-        0.5472 ,
563-        0.0062 ,
564-        0.9516 ,
565-        0.0753 });
566-   executorch::aten::Tensor value = tfFloat.make (
567-       {1 , 1 , 4 , 4 },
568-       {0.8860 ,
569-        0.5832 ,
570-        0.3376 ,
571-        0.8090 ,
572-        0.5779 ,
573-        0.9040 ,
574-        0.5547 ,
575-        0.3423 ,
576-        0.6343 ,
577-        0.3644 ,
578-        0.7104 ,
579-        0.9464 ,
580-        0.7890 ,
581-        0.2814 ,
582-        0.7886 ,
583-        0.5895 });
584-   executorch::aten::Tensor attn_mask = tfFloat.make ({1 , 1 }, {0 });
585-   executorch::aten::Tensor key_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
586-   executorch::aten::Tensor value_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
587-   executorch::aten::Tensor key_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
588-   executorch::aten::Tensor value_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
589-   executorch::aten::Tensor key_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
590-   executorch::aten::Tensor value_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
591-   double  dropout_p = 0 ;
592-   bool  is_causal = false ;
593-   executorch::aten::optional<double > scale;
594- 
595-   //  start pos: 0 layer id 0
596-   executorch::aten::Tensor ret_expected_0 = tfFloat.make (
597-       {1 , 1 , 4 , 4 },
598-       {0.8860 ,
599-        0.5832 ,
600-        0.3376 ,
601-        0.8090 ,
602-        0.5779 ,
603-        0.9040 ,
604-        0.5547 ,
605-        0.3423 ,
606-        0.6343 ,
607-        0.3644 ,
608-        0.7104 ,
609-        0.9464 ,
610-        0.7890 ,
611-        0.2814 ,
612-        0.7886 ,
613-        0.5895 });
614- 
615-   std::vector<int32_t > out_size = {1 , 1 , 4 , 4 };
616-   executorch::aten::Tensor out = tfFloat.zeros (out_size);
617-   executorch::aten::Tensor ret = op_sdpa_with_kv_cache (
618-       query,
619-       key,
620-       value,
621-       key_cache_0,
622-       value_cache_0,
623-       0 ,
624-       1 ,
625-       attn_mask,
626-       dropout_p,
627-       is_causal,
628-       scale,
629-       out);
630-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_0, 1e-4 , 1e-4 );
631- 
632-   //  start pos: 0 layer id 2
633-   executorch::aten::Tensor ret_expected_1 = tfFloat.make (
634-       {1 , 1 , 4 , 4 },
635-       {0.8860 ,
636-        0.5832 ,
637-        0.3376 ,
638-        0.8090 ,
639-        0.5779 ,
640-        0.9040 ,
641-        0.5547 ,
642-        0.3423 ,
643-        0.6343 ,
644-        0.3644 ,
645-        0.7104 ,
646-        0.9464 ,
647-        0.7890 ,
648-        0.2814 ,
649-        0.7886 ,
650-        0.5895 });
651-   out = tfFloat.zeros (out_size);
652-   ret = op_sdpa_with_kv_cache (
653-       query,
654-       key,
655-       value,
656-       key_cache_2,
657-       value_cache_2,
658-       0 ,
659-       1 ,
660-       attn_mask,
661-       dropout_p,
662-       is_causal,
663-       scale,
664-       out);
665-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_1, 1e-4 , 1e-4 );
666- 
667-   attn_mask = tfFloat.make ({1 , 2 }, {0 , 0 });
668-   //  start pos: 1 layer id 0
669-   executorch::aten::Tensor ret_expected_2 = tfFloat.make (
670-       {1 , 1 , 4 , 4 },
671-       {0.8860 ,
672-        0.5832 ,
673-        0.3376 ,
674-        0.8090 ,
675-        0.5779 ,
676-        0.9040 ,
677-        0.5547 ,
678-        0.3423 ,
679-        0.6343 ,
680-        0.3644 ,
681-        0.7104 ,
682-        0.9464 ,
683-        0.7890 ,
684-        0.2814 ,
685-        0.7886 ,
686-        0.5895 });
687-   out = tfFloat.zeros (out_size);
688-   ret = op_sdpa_with_kv_cache (
689-       query,
690-       key,
691-       value,
692-       key_cache_0,
693-       value_cache_0,
694-       1 ,
695-       1 ,
696-       attn_mask,
697-       dropout_p,
698-       is_causal,
699-       scale,
700-       out);
701-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_2, 1e-4 , 1e-4 );
702- 
703-   //  start pos: 1 layer id 1
704-   executorch::aten::Tensor ret_expected_3 = tfFloat.make (
705-       {1 , 1 , 4 , 4 },
706-       {0.6486 ,
707-        0.4270 ,
708-        0.2472 ,
709-        0.5922 ,
710-        0.3669 ,
711-        0.5740 ,
712-        0.3522 ,
713-        0.2173 ,
714-        0.3635 ,
715-        0.2088 ,
716-        0.4071 ,
717-        0.5423 ,
718-        0.5110 ,
719-        0.1822 ,
720-        0.5107 ,
721-        0.3817 });
722-   out = tfFloat.zeros (out_size);
723-   ret = op_sdpa_with_kv_cache (
724-       query,
725-       key,
726-       value,
727-       key_cache_1,
728-       value_cache_1,
729-       1 ,
730-       1 ,
731-       attn_mask,
732-       dropout_p,
733-       is_causal,
734-       scale,
735-       out);
736-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_3, 1e-4 , 1e-4 );
737- 
738-   attn_mask = tfFloat.make ({1 , 3 }, {0 , 0 , 0 });
739-   //  start pos: 2 layer id 1
740-   executorch::aten::Tensor ret_expected_4 = tfFloat.make (
741-       {1 , 1 , 4 , 4 },
742-       {0.7490 ,
743-        0.4930 ,
744-        0.2854 ,
745-        0.6838 ,
746-        0.4489 ,
747-        0.7021 ,
748-        0.4308 ,
749-        0.2659 ,
750-        0.4622 ,
751-        0.2655 ,
752-        0.5176 ,
753-        0.6895 ,
754-        0.6202 ,
755-        0.2212 ,
756-        0.6199 ,
757-        0.4634 });
758-   out = tfFloat.zeros (out_size);
759-   ret = op_sdpa_with_kv_cache (
760-       query,
761-       key,
762-       value,
763-       key_cache_1,
764-       value_cache_1,
765-       2 ,
766-       1 ,
767-       attn_mask,
768-       dropout_p,
769-       is_causal,
770-       scale,
771-       out);
772-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_4, 1e-4 , 1e-4 );
773- 
774-   //  start pos: 2 layer id 2
775-   executorch::aten::Tensor ret_expected_5 = tfFloat.make (
776-       {1 , 1 , 4 , 4 },
777-       {0.7490 ,
778-        0.4930 ,
779-        0.2854 ,
780-        0.6838 ,
781-        0.4489 ,
782-        0.7021 ,
783-        0.4308 ,
784-        0.2659 ,
785-        0.4622 ,
786-        0.2655 ,
787-        0.5176 ,
788-        0.6895 ,
789-        0.6202 ,
790-        0.2212 ,
791-        0.6199 ,
792-        0.4634 });
793-   out = tfFloat.zeros (out_size);
794-   ret = op_sdpa_with_kv_cache (
795-       query,
796-       key,
797-       value,
798-       key_cache_2,
799-       value_cache_2,
800-       2 ,
801-       1 ,
802-       attn_mask,
803-       dropout_p,
804-       is_causal,
805-       scale,
806-       out);
807-   EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
808- }
809- 
810527TEST (OpScaledDotProductAttentionTest, SequenceTest) {
811528  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
812529
0 commit comments