@@ -594,46 +594,46 @@ bool validate_flash_attention_args(
     const Tensor& key,
     const Tensor& value,
     const optional<Tensor>& attn_mask) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(query.dim() == 4, "query must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(key.dim() == 4, "key must be a 4D tensor");
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(value.dim() == 4, "value must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(query.dim() == 4, "query must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(key.dim() == 4, "key must be a 4D tensor");
+  ET_LOG_MSG_AND_RETURN_UNLESS(value.dim() == 4, "value must be a 4D tensor");
 
   // Sizes
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.size(3) == value.size(3)) && (key.size(3) == value.size(3)),
       "scaled_dot_product_attention_flash_attention: Q/K/V should have the same head size");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == ScalarType::Float), "Query must be Float type");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (query.scalar_type() == key.scalar_type()) &&
           (query.scalar_type() == value.scalar_type()),
       "Key and Value must have the same data type as Query");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() || attn_mask.value().dim() == 2,
       "Attention mask must be a 2D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       !attn_mask.has_value() ||
           attn_mask.value().scalar_type() == query.scalar_type(),
       "Attention mask must have the same data type as Query");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(query.dim_order().data(), query.dim()),
       "query must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(key.dim_order().data(), key.dim()),
       "key must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(value.dim_order().data(), value.dim()),
       "value must be in contiguous dim order");
 
   if (attn_mask.has_value()) {
-    ET_LOG_MSG_AND_RETURN_IF_FALSE(
+    ET_LOG_MSG_AND_RETURN_UNLESS(
         is_contiguous_dim_order(
             attn_mask.value().dim_order().data(), attn_mask.value().dim()),
         "attn_mask must be in contiguous dim order");
@@ -647,21 +647,21 @@ bool validate_cache_params(
     const Tensor& v_cache,
     int64_t start_pos,
     int64_t seq_length) {
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       k_cache.dim() == 4, "k_cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       v_cache.dim() == 4, "v_cache must be a 4D tensor");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < k_cache.size(1),
       "start_pos must be less than key cache size at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       start_pos < v_cache.size(1),
       "start_pos must be less than value cache size at dim 1");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= k_cache.size(1),
       "start_pos + seq_length must not exceed the max seq length supported by the key cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -671,7 +671,7 @@ bool validate_cache_params(
       seq_length,
       k_cache.size(1));
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       (start_pos + seq_length) <= v_cache.size(1),
       "start_pos + seq_length must not exceed the max seq length supported by the value cache."
       " start pos: %" PRId64 ", seq_length: %" PRId64
@@ -682,11 +682,11 @@ bool validate_cache_params(
       v_cache.size(1));
 
   // Make sure they are in contiguous dim order
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(k_cache.dim_order().data(), k_cache.dim()),
       "key cache must be in contiguous dim order");
 
-  ET_LOG_MSG_AND_RETURN_IF_FALSE(
+  ET_LOG_MSG_AND_RETURN_UNLESS(
       is_contiguous_dim_order(v_cache.dim_order().data(), v_cache.dim()),
       "value cache must be in contiguous dim order");
 
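For context on the rename: both the old and new macros follow the same log-and-return pattern, where a failed condition logs the formatted message and immediately returns `false` from the enclosing validator. Below is a minimal, self-contained sketch of how such a macro can be written. The real `ET_LOG_MSG_AND_RETURN_UNLESS` definition lives in ExecuTorch's internal headers and logs through its own logging facility; the `fprintf` fallback and the `validate_dims` helper here are assumptions for illustration only.

```cpp
#include <cstdio>

// Hypothetical stand-in for ExecuTorch's ET_LOG_MSG_AND_RETURN_UNLESS.
// Assumption: the real macro logs via ExecuTorch's logging API; plain
// fprintf is used here only to keep the sketch self-contained.
#define ET_LOG_MSG_AND_RETURN_UNLESS(cond, fmt, ...)                       \
  do {                                                                     \
    if (!(cond)) {                                                         \
      std::fprintf(                                                        \
          stderr, "Check failed (%s): " fmt "\n", #cond, ##__VA_ARGS__);   \
      return false;                                                        \
    }                                                                      \
  } while (false)

// Usage mirroring the validators above: the function returns false
// (after logging) as soon as any precondition fails.
bool validate_dims(int query_dim, int key_dim) {
  ET_LOG_MSG_AND_RETURN_UNLESS(query_dim == 4, "query must be a 4D tensor");
  ET_LOG_MSG_AND_RETURN_UNLESS(key_dim == 4, "key must be a 4D tensor");
  return true;
}

int main() {
  return validate_dims(4, 3) ? 0 : 1; // logs and fails on the key-dim check
}
```

The `UNLESS` spelling reads more naturally at the call site than `IF_FALSE`: the macro returns unless the stated condition holds, which matches how each validation above is phrased.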