2626
2727#if ARCH_X86
2828#include "x86/adm_avx2.h"
29+ #if HAVE_AVX512
30+ #include "x86/adm_avx512.h"
31+ #endif
2932#elif ARCH_AARCH64
3033#include "arm64/adm_neon.h"
3134#include <arm_neon.h>
@@ -41,6 +44,31 @@ typedef struct AdmState {
4144 void (* dwt2_8 )(const uint8_t * src , const adm_dwt_band_t * dst ,
4245 AdmBuffer * buf , int w , int h , int src_stride ,
4346 int dst_stride );
47+ void (* dwt2_16 )(const uint16_t * src , const adm_dwt_band_t * dst ,
48+ AdmBuffer * buf , int w , int h , int src_stride ,
49+ int dst_stride , int inp_size_bits );
50+ void (* adm_decouple )(AdmBuffer * buf , int w , int h , int stride ,
51+ double adm_enhn_gain_limit , int32_t * adm_div_lookup );
52+ void (* adm_decouple_s123 )(AdmBuffer * buf , int w , int h , int stride ,
53+ double adm_enhn_gain_limit , int32_t * adm_div_lookup );
54+ float (* adm_csf_den_scale )(const adm_dwt_band_t * src , int w , int h ,
55+ int src_stride , double adm_norm_view_dist ,
56+ int adm_ref_display_height );
57+ void (* adm_csf )(AdmBuffer * buf , int w , int h , int stride ,
58+ double adm_norm_view_dist , int adm_ref_display_height );
59+ float (* adm_cm )(AdmBuffer * buf , int w , int h , int src_stride , int csf_a_stride ,
60+ double adm_norm_view_dist , int adm_ref_display_height );
61+ void (* adm_dwt2_s123_combined )(const int32_t * i4_ref_scale , const int32_t * i4_curr_dis ,
62+ AdmBuffer * buf , int w , int h , int ref_stride ,
63+ int dis_stride , int dst_stride , int scale );
64+ float (* adm_csf_den_s123 )(const i4_adm_dwt_band_t * src , int scale , int w , int h ,
65+ int src_stride , double adm_norm_view_dist ,
66+ int adm_ref_display_height );
67+ void (* i4_adm_csf )(AdmBuffer * buf , int scale , int w , int h , int stride ,
68+ double adm_norm_view_dist , int adm_ref_display_height );
69+ float (* i4_adm_cm )(AdmBuffer * buf , int w , int h , int src_stride ,
70+ int csf_a_stride , int scale , double adm_norm_view_dist ,
71+ int adm_ref_display_height );
4472 VmafDictionary * feature_name_dict ;
4573} AdmState ;
4674
@@ -657,8 +685,8 @@ static void dwt2_src_indices_filt(int **src_ind_y, int **src_ind_x, int w, int h
657685#define MIN (x , y ) (((x) < (y)) ? (x) : (y))
658686#define MAX (x , y ) (((x) > (y)) ? (x) : (y))
659687
660- static void adm_decouple (AdmBuffer * buf , int w , int h , int stride ,
661- double adm_enhn_gain_limit )
688+ static inline void adm_decouple (AdmBuffer * buf , int w , int h , int stride ,
689+ double adm_enhn_gain_limit , int32_t * adm_div_lookup )
662690{
663691 const float cos_1deg_sq = cos (1.0 * M_PI / 180.0 ) * cos (1.0 * M_PI / 180.0 );
664692
@@ -684,7 +712,6 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
684712 if (bottom > h ) {
685713 bottom = h ;
686714 }
687-
688715 int64_t ot_dp , o_mag_sq , t_mag_sq ;
689716
690717 for (int i = top ; i < bottom ; ++ i ) {
@@ -722,6 +749,7 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
722749 o_mag_sq = (int64_t )oh * oh + (int64_t )ov * ov ;
723750 t_mag_sq = (int64_t )th * th + (int64_t )tv * tv ;
724751
752+
725753 /**
726754 * angle_flag is calculated in floating-point by converting fixed-point variables back to
727755 * floating-point
@@ -735,16 +763,17 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
735763 */
736764
737765 int32_t tmp_kh = (oh == 0 ) ?
738- 32768 : (((int64_t )div_lookup [oh + 32768 ] * th ) + 16384 ) >> 15 ;
766+ 32768 : (((int64_t )adm_div_lookup [oh + 32768 ] * th ) + 16384 ) >> 15 ;
739767 int32_t tmp_kv = (ov == 0 ) ?
740- 32768 : (((int64_t )div_lookup [ov + 32768 ] * tv ) + 16384 ) >> 15 ;
768+ 32768 : (((int64_t )adm_div_lookup [ov + 32768 ] * tv ) + 16384 ) >> 15 ;
741769 int32_t tmp_kd = (od == 0 ) ?
742- 32768 : (((int64_t )div_lookup [od + 32768 ] * td ) + 16384 ) >> 15 ;
770+ 32768 : (((int64_t )adm_div_lookup [od + 32768 ] * td ) + 16384 ) >> 15 ;
743771
744772 int32_t kh = tmp_kh < 0 ? 0 : (tmp_kh > 32768 ? 32768 : tmp_kh );
745773 int32_t kv = tmp_kv < 0 ? 0 : (tmp_kv > 32768 ? 32768 : tmp_kv );
746774 int32_t kd = tmp_kd < 0 ? 0 : (tmp_kd > 32768 ? 32768 : tmp_kd );
747775
776+
748777 /**
749778 * kh,kv,kd are in Q15 type and oh,ov,od are in Q16 type hence shifted by
750779 * 15 to make result Q16
@@ -787,7 +816,7 @@ static inline uint16_t get_best15_from32(uint32_t temp, int *x)
787816}
788817
789818static void adm_decouple_s123 (AdmBuffer * buf , int w , int h , int stride ,
790- double adm_enhn_gain_limit )
819+ double adm_enhn_gain_limit , int32_t * adm_div_lookup )
791820{
792821 const float cos_1deg_sq = cos (1.0 * M_PI / 180.0 ) * cos (1.0 * M_PI / 180.0 );
793822
@@ -890,11 +919,11 @@ static void adm_decouple_s123(AdmBuffer *buf, int w, int h, int stride,
890919 uint16_t kv_msb = (abs_ov < (32768 ) ? abs_ov : get_best15_from32 (abs_ov , & kv_shift ));
891920 uint16_t kd_msb = (abs_od < (32768 ) ? abs_od : get_best15_from32 (abs_od , & kd_shift ));
892921
893- int64_t tmp_kh = (oh == 0 ) ? 32768 : (((int64_t )div_lookup [kh_msb + 32768 ] * th )* (kh_sign ) +
922+ int64_t tmp_kh = (oh == 0 ) ? 32768 : (((int64_t )adm_div_lookup [kh_msb + 32768 ] * th )* (kh_sign ) +
894923 (1 << (14 + kh_shift ))) >> (15 + kh_shift );
895- int64_t tmp_kv = (ov == 0 ) ? 32768 : (((int64_t )div_lookup [kv_msb + 32768 ] * tv )* (kv_sign ) +
924+ int64_t tmp_kv = (ov == 0 ) ? 32768 : (((int64_t )adm_div_lookup [kv_msb + 32768 ] * tv )* (kv_sign ) +
896925 (1 << (14 + kv_shift ))) >> (15 + kv_shift );
897- int64_t tmp_kd = (od == 0 ) ? 32768 : (((int64_t )div_lookup [kd_msb + 32768 ] * td )* (kd_sign ) +
926+ int64_t tmp_kd = (od == 0 ) ? 32768 : (((int64_t )adm_div_lookup [kd_msb + 32768 ] * td )* (kd_sign ) +
898927 (1 << (14 + kd_shift ))) >> (15 + kd_shift );
899928
900929 int64_t kh = tmp_kh < 0 ? 0 : (tmp_kh > 32768 ? 32768 : tmp_kh );
@@ -903,12 +932,12 @@ static void adm_decouple_s123(AdmBuffer *buf, int w, int h, int stride,
903932
904933 rst_h = ((kh * oh ) + 16384 ) >> 15 ;
905934 rst_v = ((kv * ov ) + 16384 ) >> 15 ;
906- rst_d = ((kd * od ) + 16384 ) >> 15 ;
935+ rst_d = ((kd * od ) + 16384 ) >> 15 ;
907936
908937 const float rst_h_f = ((float )kh / 32768 ) * ((float )oh / 64 );
909938 const float rst_v_f = ((float )kv / 32768 ) * ((float )ov / 64 );
910939 const float rst_d_f = ((float )kd / 32768 ) * ((float )od / 64 );
911-
940+
912941 if (angle_flag && (rst_h_f > 0. )) rst_h = MIN ((rst_h * adm_enhn_gain_limit ), th );
913942 if (angle_flag && (rst_h_f < 0. )) rst_h = MAX ((rst_h * adm_enhn_gain_limit ), th );
914943
@@ -1416,6 +1445,7 @@ static float adm_cm(AdmBuffer *buf, int w, int h, int src_stride, int csf_a_stri
14161445 accum_inner_h = 0 ;
14171446 accum_inner_v = 0 ;
14181447 accum_inner_d = 0 ;
1448+
14191449 for (j = start_col ; j < end_col ; ++ j ) {
14201450 xh = src -> band_h [i * src_stride + j ] * i_rfactor [0 ];
14211451 xv = src -> band_v [i * src_stride + j ] * i_rfactor [1 ];
@@ -2429,7 +2459,6 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24292459{
24302460 int w = ref_pic -> w [0 ];
24312461 int h = ref_pic -> h [0 ];
2432-
24332462 const double numden_limit = 1e-10 * (w * h ) / (1920.0 * 1080.0 );
24342463
24352464 size_t curr_ref_stride ;
@@ -2463,9 +2492,9 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24632492 curr_dis_stride , buf_stride );
24642493 }
24652494 else {
2466- adm_dwt2_16 (ref_pic -> data [0 ], & buf -> ref_dwt2 , buf , w , h ,
2495+ s -> dwt2_16 (ref_pic -> data [0 ], & buf -> ref_dwt2 , buf , w , h ,
24672496 curr_ref_stride , buf_stride , ref_pic -> bpc );
2468- adm_dwt2_16 (dis_pic -> data [0 ], & buf -> dis_dwt2 , buf , w , h ,
2497+ s -> dwt2_16 (dis_pic -> data [0 ], & buf -> dis_dwt2 , buf , w , h ,
24692498 curr_dis_stride , buf_stride , dis_pic -> bpc );
24702499 }
24712500
@@ -2474,35 +2503,35 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24742503
24752504 w = (w + 1 ) / 2 ;
24762505 h = (h + 1 ) / 2 ;
2506+ s -> adm_decouple (buf , w , h , buf_stride , adm_enhn_gain_limit , div_lookup );
24772507
2478- adm_decouple (buf , w , h , buf_stride , adm_enhn_gain_limit );
2479-
2480- den_scale = adm_csf_den_scale (& buf -> ref_dwt2 , w , h , buf_stride ,
2508+ den_scale = s -> adm_csf_den_scale (& buf -> ref_dwt2 , w , h , buf_stride ,
24812509 adm_norm_view_dist , adm_ref_display_height );
24822510
2483- adm_csf (buf , w , h , buf_stride , adm_norm_view_dist , adm_ref_display_height );
2511+ s -> adm_csf (buf , w , h , buf_stride , adm_norm_view_dist , adm_ref_display_height );
24842512
2485- num_scale = adm_cm (buf , w , h , buf_stride , buf_stride ,
2513+ num_scale = s -> adm_cm (buf , w , h , buf_stride , buf_stride ,
24862514 adm_norm_view_dist , adm_ref_display_height );
24872515 }
24882516 else {
2489- adm_dwt2_s123_combined (i4_curr_ref_scale , i4_curr_dis_scale , buf , w , h , curr_ref_stride ,
2490- curr_dis_stride , buf_stride , scale );
2517+ s -> adm_dwt2_s123_combined (i4_curr_ref_scale , i4_curr_dis_scale , buf , w , h , curr_ref_stride ,
2518+ curr_dis_stride , buf_stride , scale );
24912519
24922520 w = (w + 1 ) / 2 ;
24932521 h = (h + 1 ) / 2 ;
24942522
2495- adm_decouple_s123 (buf , w , h , buf_stride , adm_enhn_gain_limit );
2523+ s -> adm_decouple_s123 (buf , w , h , buf_stride , adm_enhn_gain_limit , div_lookup );
24962524
2497- den_scale = adm_csf_den_s123 (
2525+ den_scale = s -> adm_csf_den_s123 (
24982526 & buf -> i4_ref_dwt2 , scale , w , h , buf_stride ,
24992527 adm_norm_view_dist , adm_ref_display_height );
2528+
2529+ s -> i4_adm_csf (buf , scale , w , h , buf_stride ,
2530+ adm_norm_view_dist , adm_ref_display_height );
25002531
2501- i4_adm_csf (buf , scale , w , h , buf_stride ,
2502- adm_norm_view_dist , adm_ref_display_height );
2532+ num_scale = s -> i4_adm_cm (buf , w , h , buf_stride , buf_stride , scale ,
2533+ adm_norm_view_dist , adm_ref_display_height );
25032534
2504- num_scale = i4_adm_cm (buf , w , h , buf_stride , buf_stride , scale ,
2505- adm_norm_view_dist , adm_ref_display_height );
25062535 }
25072536
25082537 num += num_scale ;
@@ -2593,12 +2622,47 @@ static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
25932622 }
25942623
25952624 s -> dwt2_8 = adm_dwt2_8 ;
2625+ s -> dwt2_16 = adm_dwt2_16 ;
2626+ s -> adm_csf_den_scale = adm_csf_den_scale ;
2627+ s -> adm_csf = adm_csf ;
2628+ s -> adm_cm = adm_cm ;
2629+ s -> adm_dwt2_s123_combined = adm_dwt2_s123_combined ;
2630+ s -> adm_csf_den_s123 = adm_csf_den_s123 ;
2631+ s -> i4_adm_csf = i4_adm_csf ;
2632+ s -> i4_adm_cm = i4_adm_cm ;
2633+ s -> adm_decouple = adm_decouple ;
2634+ s -> adm_decouple_s123 = adm_decouple_s123 ;
25962635
25972636#if ARCH_X86
25982637 unsigned flags = vmaf_get_cpu_flags ();
25992638 if (flags & VMAF_X86_CPU_FLAG_AVX2 ) {
26002639 if (!(w % 8 )) s -> dwt2_8 = adm_dwt2_8_avx2 ;
2640+ s -> dwt2_16 = adm_dwt2_16_avx2 ;
2641+ s -> adm_csf_den_scale = adm_csf_den_scale_avx2 ;
2642+ s -> adm_csf = adm_csf_avx2 ;
2643+ s -> adm_cm = adm_cm_avx2 ;
2644+ s -> adm_csf_den_s123 = adm_csf_den_s123_avx2 ;
2645+ s -> adm_dwt2_s123_combined = adm_dwt2_s123_combined_avx2 ;
2646+ s -> i4_adm_csf = i4_adm_csf_avx2 ;
2647+ s -> i4_adm_cm = i4_adm_cm_avx2 ;
2648+ s -> adm_decouple = adm_decouple_avx2 ;
2649+ s -> adm_decouple_s123 = adm_decouple_s123_avx2 ;
2650+ }
2651+ #if HAVE_AVX512
2652+ if (flags & VMAF_X86_CPU_FLAG_AVX512 ) {
2653+ s -> dwt2_8 = adm_dwt2_8_avx512 ;
2654+ s -> dwt2_16 = adm_dwt2_16_avx512 ;
2655+ s -> adm_csf_den_scale = adm_csf_den_scale_avx512 ;
2656+ s -> adm_csf = adm_csf_avx512 ;
2657+ s -> adm_cm = adm_cm_avx512 ;
2658+ s -> adm_csf_den_s123 = adm_csf_den_s123_avx512 ;
2659+ s -> adm_dwt2_s123_combined = adm_dwt2_s123_combined_avx512 ;
2660+ s -> i4_adm_csf = i4_adm_csf_avx512 ;
2661+ s -> i4_adm_cm = i4_adm_cm_avx512 ;
2662+ s -> adm_decouple = adm_decouple_avx512 ;
2663+ s -> adm_decouple_s123 = adm_decouple_s123_avx512 ;
26012664 }
2665+ #endif
26022666#elif ARCH_AARCH64
26032667 unsigned flags = vmaf_get_cpu_flags ();
26042668 if (flags & VMAF_ARM_CPU_FLAG_NEON ) {
0 commit comments