@@ -594,46 +594,46 @@ bool validate_flash_attention_args(
     const Tensor& key,
     const Tensor& value,
     const optional<Tensor>& attn_mask) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(query.dim() == 4, "query must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(key.dim() == 4, "key must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(value.dim() == 4, "value must be a 4D tensor");

   // Sizes
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (query.scalar_type() == ScalarType::Float), "Query must be Float type");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (query.scalar_type() == key.scalar_type()) &&
           (query.scalar_type() == value.scalar_type()),
       "Key and Value must have the same data type as Query");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() || attn_mask.value().dim() == 2,
       "Attention mask must be a 2D tensor");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       !attn_mask.has_value() ||
           attn_mask.value().scalar_type() == query.scalar_type(),
       "Attention mask must have the same data type as query");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
       "query must be in contiguous dim order");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(key.dim_order().data(), key.dim()),
       "key must be in contiguous dim order");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(value.dim_order().data(), value.dim()),
       "value must be in contiguous dim order");

   if (attn_mask.has_value()) {
-    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+    ET_CHECK_OR_RETURN_FALSE(
         is_contiguous_dim_order(
             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
         "attn_mask must be in contiguous dim order");
@@ -647,21 +647,19 @@ bool validate_cache_params(
     const Tensor& v_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      k_cache.dim() == 4, "k_cache must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(k_cache.dim() == 4, "k_cache must be a 4D tensor");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
-      v_cache.dim() == 4, "v_cache must be a 4D tensor");
+  ET_CHECK_OR_RETURN_FALSE(v_cache.dim() == 4, "v_cache must be a 4D tensor");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       start_pos < k_cache.size(1),
       "start_pos must be less than key cache size at dim 1");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       start_pos < v_cache.size(1),
       "start_pos must be less than value cache size at dim 1");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (start_pos + seq_length) <= k_cache.size(1),
       "start_pos + seq_length must not exceed max seq length supported by key cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -671,7 +669,7 @@ bool validate_cache_params(
       seq_length,
       k_cache.size(1));

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       (start_pos + seq_length) <= v_cache.size(1),
       "start_pos + seq_length must not exceed max seq length supported by value cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -682,11 +680,11 @@ bool validate_cache_params(
       v_cache.size(1));

   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
       "key cache must be in contiguous dim order");

-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_CHECK_OR_RETURN_FALSE(
       is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
       "value cache must be in contiguous dim order");

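Note: the definition of ET_CHECK_OR_RETURN_FALSE is not part of this diff. As a minimal sketch of the contract the call sites above rely on (log a printf-style message, then return false from the enclosing bool-returning function when the condition fails), a macro of this shape could look like the following; the logger name ET_LOG and the exact macro hygiene are assumptions here, not the actual ExecuTorch definition:

// Sketch only: models the observed contract of ET_CHECK_OR_RETURN_FALSE
// (log the failure, then return false). Not the real definition.
#define CHECK_OR_RETURN_FALSE_SKETCH(cond, message, ...) \
  do {                                                   \
    if (!(cond)) {                                       \
      ET_LOG(Error, message, ##__VA_ARGS__);             \
      return false;                                      \
    }                                                    \
  } while (0)

Because the failure path returns false rather than aborting, a caller of validate_cache_params can surface the error to the runtime. For example, with a key cache of size 128 at dim 1, start_pos = 100 and seq_length = 29 fails the (start_pos + seq_length) <= k_cache.size(1) check, since 100 + 29 = 129 > 128.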