Skip to content

Commit 44a9254

Browse files
vaibhavk2fhannebicabirdme
authored
Additional AVX optimizations with build and vmaf accuracy fixes (#1452)
* Intel additional AVX optimizations * use function pointers * move include pragmas * Picture.h: Fix build issues with AVX patches Resolve the path for the picture.h file while compiling with AVX optimization patches. Signed-off-by: Vaibhav Shankar <vaibhav.shankar@intel.com> * applying fixes to adm_decouple_s123 avx2 and avx512 versions to address identified accuracy discrepancies Signed-off-by: Christopher Bird <christopher.a.bird@intel.com> * Fix accuracy issues in adm_decouple_512 (AVX-512) This change fixes accuracy mismatches in the AVX-512 implementation of adm_decouple_512. The issue was reproducible only with narrow-width video files from Netflix’s Test: - Verified AVX-512 output matches scalar implementation Signed-off-by: Vaibhav Shankar <vaibhav.shankar@intel.com> * Fix accuracy issues in adm_dwt2_8_avx2 (AVX2) Resolves accuracy mismatches in the AVX2 implementation of adm_dwt2_8_avx2. The issue was reproducible only with specific narrow-width video files. Test: - Verified AVX2 output (accuracy) matches the scalar implementation Signed-off-by: Vaibhav Shankar <vaibhav.shankar@intel.com> * Fix Windows build issues in AVX optimizations. Update AVX-related code to use data types compatible with both Linux and Windows. Tested on Windows (MSYS2) and Linux: - Build completes successfully - All tests pass Signed-off-by: Vaibhav Shankar <vaibhav.shankar@intel.com> --------- Signed-off-by: Vaibhav Shankar <vaibhav.shankar@intel.com> Signed-off-by: Christopher Bird <christopher.a.bird@intel.com> Co-authored-by: fhannebi <francois.hannebicq@intel.com> Co-authored-by: Christopher Bird <christopher.a.bird@intel.com>
1 parent 7e16db0 commit 44a9254

File tree

15 files changed

+8936
-104
lines changed

15 files changed

+8936
-104
lines changed

libvmaf/src/feature/integer_adm.c

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626

2727
#if ARCH_X86
2828
#include "x86/adm_avx2.h"
29+
#if HAVE_AVX512
30+
#include "x86/adm_avx512.h"
31+
#endif
2932
#elif ARCH_AARCH64
3033
#include "arm64/adm_neon.h"
3134
#include <arm_neon.h>
@@ -41,6 +44,31 @@ typedef struct AdmState {
4144
void (*dwt2_8)(const uint8_t *src, const adm_dwt_band_t *dst,
4245
AdmBuffer *buf, int w, int h, int src_stride,
4346
int dst_stride);
47+
void (*dwt2_16)(const uint16_t *src, const adm_dwt_band_t *dst,
48+
AdmBuffer *buf, int w, int h, int src_stride,
49+
int dst_stride, int inp_size_bits);
50+
void (*adm_decouple)(AdmBuffer *buf, int w, int h, int stride,
51+
double adm_enhn_gain_limit, int32_t* adm_div_lookup);
52+
void (*adm_decouple_s123)(AdmBuffer *buf, int w, int h, int stride,
53+
double adm_enhn_gain_limit, int32_t* adm_div_lookup);
54+
float (*adm_csf_den_scale)(const adm_dwt_band_t *src, int w, int h,
55+
int src_stride, double adm_norm_view_dist,
56+
int adm_ref_display_height);
57+
void (*adm_csf)(AdmBuffer *buf, int w, int h, int stride,
58+
double adm_norm_view_dist, int adm_ref_display_height);
59+
float (*adm_cm)(AdmBuffer *buf, int w, int h, int src_stride, int csf_a_stride,
60+
double adm_norm_view_dist, int adm_ref_display_height);
61+
void (*adm_dwt2_s123_combined)(const int32_t *i4_ref_scale, const int32_t *i4_curr_dis,
62+
AdmBuffer *buf, int w, int h, int ref_stride,
63+
int dis_stride, int dst_stride, int scale);
64+
float (*adm_csf_den_s123)(const i4_adm_dwt_band_t *src, int scale, int w, int h,
65+
int src_stride, double adm_norm_view_dist,
66+
int adm_ref_display_height);
67+
void (*i4_adm_csf)(AdmBuffer *buf, int scale, int w, int h, int stride,
68+
double adm_norm_view_dist, int adm_ref_display_height);
69+
float (*i4_adm_cm)(AdmBuffer *buf, int w, int h, int src_stride,
70+
int csf_a_stride, int scale, double adm_norm_view_dist,
71+
int adm_ref_display_height);
4472
VmafDictionary *feature_name_dict;
4573
} AdmState;
4674

@@ -657,8 +685,8 @@ static void dwt2_src_indices_filt(int **src_ind_y, int **src_ind_x, int w, int h
657685
#define MIN(x, y) (((x) < (y)) ? (x) : (y))
658686
#define MAX(x, y) (((x) > (y)) ? (x) : (y))
659687

660-
static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
661-
double adm_enhn_gain_limit)
688+
static inline void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
689+
double adm_enhn_gain_limit, int32_t* adm_div_lookup)
662690
{
663691
const float cos_1deg_sq = cos(1.0 * M_PI / 180.0) * cos(1.0 * M_PI / 180.0);
664692

@@ -684,7 +712,6 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
684712
if (bottom > h) {
685713
bottom = h;
686714
}
687-
688715
int64_t ot_dp, o_mag_sq, t_mag_sq;
689716

690717
for (int i = top; i < bottom; ++i) {
@@ -722,6 +749,7 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
722749
o_mag_sq = (int64_t)oh * oh + (int64_t)ov * ov;
723750
t_mag_sq = (int64_t)th * th + (int64_t)tv * tv;
724751

752+
725753
/**
726754
* angle_flag is calculated in floating-point by converting fixed-point variables back to
727755
* floating-point
@@ -735,16 +763,17 @@ static void adm_decouple(AdmBuffer *buf, int w, int h, int stride,
735763
*/
736764

737765
int32_t tmp_kh = (oh == 0) ?
738-
32768 : (((int64_t)div_lookup[oh + 32768] * th) + 16384) >> 15;
766+
32768 : (((int64_t)adm_div_lookup[oh + 32768] * th) + 16384) >> 15;
739767
int32_t tmp_kv = (ov == 0) ?
740-
32768 : (((int64_t)div_lookup[ov + 32768] * tv) + 16384) >> 15;
768+
32768 : (((int64_t)adm_div_lookup[ov + 32768] * tv) + 16384) >> 15;
741769
int32_t tmp_kd = (od == 0) ?
742-
32768 : (((int64_t)div_lookup[od + 32768] * td) + 16384) >> 15;
770+
32768 : (((int64_t)adm_div_lookup[od + 32768] * td) + 16384) >> 15;
743771

744772
int32_t kh = tmp_kh < 0 ? 0 : (tmp_kh > 32768 ? 32768 : tmp_kh);
745773
int32_t kv = tmp_kv < 0 ? 0 : (tmp_kv > 32768 ? 32768 : tmp_kv);
746774
int32_t kd = tmp_kd < 0 ? 0 : (tmp_kd > 32768 ? 32768 : tmp_kd);
747775

776+
748777
/**
749778
* kh,kv,kd are in Q15 type and oh,ov,od are in Q16 type hence shifted by
750779
* 15 to make result Q16
@@ -787,7 +816,7 @@ static inline uint16_t get_best15_from32(uint32_t temp, int *x)
787816
}
788817

789818
static void adm_decouple_s123(AdmBuffer *buf, int w, int h, int stride,
790-
double adm_enhn_gain_limit)
819+
double adm_enhn_gain_limit, int32_t* adm_div_lookup)
791820
{
792821
const float cos_1deg_sq = cos(1.0 * M_PI / 180.0) * cos(1.0 * M_PI / 180.0);
793822

@@ -890,11 +919,11 @@ static void adm_decouple_s123(AdmBuffer *buf, int w, int h, int stride,
890919
uint16_t kv_msb = (abs_ov < (32768) ? abs_ov : get_best15_from32(abs_ov, &kv_shift));
891920
uint16_t kd_msb = (abs_od < (32768) ? abs_od : get_best15_from32(abs_od, &kd_shift));
892921

893-
int64_t tmp_kh = (oh == 0) ? 32768 : (((int64_t)div_lookup[kh_msb + 32768] * th)*(kh_sign) +
922+
int64_t tmp_kh = (oh == 0) ? 32768 : (((int64_t)adm_div_lookup[kh_msb + 32768] * th)*(kh_sign) +
894923
(1 << (14 + kh_shift))) >> (15 + kh_shift);
895-
int64_t tmp_kv = (ov == 0) ? 32768 : (((int64_t)div_lookup[kv_msb + 32768] * tv)*(kv_sign) +
924+
int64_t tmp_kv = (ov == 0) ? 32768 : (((int64_t)adm_div_lookup[kv_msb + 32768] * tv)*(kv_sign) +
896925
(1 << (14 + kv_shift))) >> (15 + kv_shift);
897-
int64_t tmp_kd = (od == 0) ? 32768 : (((int64_t)div_lookup[kd_msb + 32768] * td)*(kd_sign) +
926+
int64_t tmp_kd = (od == 0) ? 32768 : (((int64_t)adm_div_lookup[kd_msb + 32768] * td)*(kd_sign) +
898927
(1 << (14 + kd_shift))) >> (15 + kd_shift);
899928

900929
int64_t kh = tmp_kh < 0 ? 0 : (tmp_kh > 32768 ? 32768 : tmp_kh);
@@ -903,12 +932,12 @@ static void adm_decouple_s123(AdmBuffer *buf, int w, int h, int stride,
903932

904933
rst_h = ((kh * oh) + 16384) >> 15;
905934
rst_v = ((kv * ov) + 16384) >> 15;
906-
rst_d = ((kd * od) + 16384) >> 15;
935+
rst_d = ((kd * od) + 16384) >> 15;
907936

908937
const float rst_h_f = ((float)kh / 32768) * ((float)oh / 64);
909938
const float rst_v_f = ((float)kv / 32768) * ((float)ov / 64);
910939
const float rst_d_f = ((float)kd / 32768) * ((float)od / 64);
911-
940+
912941
if (angle_flag && (rst_h_f > 0.)) rst_h = MIN((rst_h * adm_enhn_gain_limit), th);
913942
if (angle_flag && (rst_h_f < 0.)) rst_h = MAX((rst_h * adm_enhn_gain_limit), th);
914943

@@ -1416,6 +1445,7 @@ static float adm_cm(AdmBuffer *buf, int w, int h, int src_stride, int csf_a_stri
14161445
accum_inner_h = 0;
14171446
accum_inner_v = 0;
14181447
accum_inner_d = 0;
1448+
14191449
for (j = start_col; j < end_col; ++j) {
14201450
xh = src->band_h[i * src_stride + j] * i_rfactor[0];
14211451
xv = src->band_v[i * src_stride + j] * i_rfactor[1];
@@ -2429,7 +2459,6 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24292459
{
24302460
int w = ref_pic->w[0];
24312461
int h = ref_pic->h[0];
2432-
24332462
const double numden_limit = 1e-10 * (w * h) / (1920.0 * 1080.0);
24342463

24352464
size_t curr_ref_stride;
@@ -2463,9 +2492,9 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24632492
curr_dis_stride, buf_stride);
24642493
}
24652494
else {
2466-
adm_dwt2_16(ref_pic->data[0], &buf->ref_dwt2, buf, w, h,
2495+
s->dwt2_16(ref_pic->data[0], &buf->ref_dwt2, buf, w, h,
24672496
curr_ref_stride, buf_stride, ref_pic->bpc);
2468-
adm_dwt2_16(dis_pic->data[0], &buf->dis_dwt2, buf, w, h,
2497+
s->dwt2_16(dis_pic->data[0], &buf->dis_dwt2, buf, w, h,
24692498
curr_dis_stride, buf_stride, dis_pic->bpc);
24702499
}
24712500

@@ -2474,35 +2503,35 @@ void integer_compute_adm(AdmState *s, VmafPicture *ref_pic, VmafPicture *dis_pic
24742503

24752504
w = (w + 1) / 2;
24762505
h = (h + 1) / 2;
2506+
s->adm_decouple(buf, w, h, buf_stride, adm_enhn_gain_limit, div_lookup);
24772507

2478-
adm_decouple(buf, w, h, buf_stride, adm_enhn_gain_limit);
2479-
2480-
den_scale = adm_csf_den_scale(&buf->ref_dwt2, w, h, buf_stride,
2508+
den_scale = s->adm_csf_den_scale(&buf->ref_dwt2, w, h, buf_stride,
24812509
adm_norm_view_dist, adm_ref_display_height);
24822510

2483-
adm_csf(buf, w, h, buf_stride, adm_norm_view_dist, adm_ref_display_height);
2511+
s->adm_csf(buf, w, h, buf_stride, adm_norm_view_dist, adm_ref_display_height);
24842512

2485-
num_scale = adm_cm(buf, w, h, buf_stride, buf_stride,
2513+
num_scale = s->adm_cm(buf, w, h, buf_stride, buf_stride,
24862514
adm_norm_view_dist, adm_ref_display_height);
24872515
}
24882516
else {
2489-
adm_dwt2_s123_combined(i4_curr_ref_scale, i4_curr_dis_scale, buf, w, h, curr_ref_stride,
2490-
curr_dis_stride, buf_stride, scale);
2517+
s->adm_dwt2_s123_combined(i4_curr_ref_scale, i4_curr_dis_scale, buf, w, h, curr_ref_stride,
2518+
curr_dis_stride, buf_stride, scale);
24912519

24922520
w = (w + 1) / 2;
24932521
h = (h + 1) / 2;
24942522

2495-
adm_decouple_s123(buf, w, h, buf_stride, adm_enhn_gain_limit);
2523+
s->adm_decouple_s123(buf, w, h, buf_stride, adm_enhn_gain_limit, div_lookup);
24962524

2497-
den_scale = adm_csf_den_s123(
2525+
den_scale = s->adm_csf_den_s123(
24982526
&buf->i4_ref_dwt2, scale, w, h, buf_stride,
24992527
adm_norm_view_dist, adm_ref_display_height);
2528+
2529+
s->i4_adm_csf(buf, scale, w, h, buf_stride,
2530+
adm_norm_view_dist, adm_ref_display_height);
25002531

2501-
i4_adm_csf(buf, scale, w, h, buf_stride,
2502-
adm_norm_view_dist, adm_ref_display_height);
2532+
num_scale = s->i4_adm_cm(buf, w, h, buf_stride, buf_stride, scale,
2533+
adm_norm_view_dist, adm_ref_display_height);
25032534

2504-
num_scale = i4_adm_cm(buf, w, h, buf_stride, buf_stride, scale,
2505-
adm_norm_view_dist, adm_ref_display_height);
25062535
}
25072536

25082537
num += num_scale;
@@ -2593,12 +2622,47 @@ static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
25932622
}
25942623

25952624
s->dwt2_8 = adm_dwt2_8;
2625+
s->dwt2_16 = adm_dwt2_16;
2626+
s->adm_csf_den_scale = adm_csf_den_scale;
2627+
s->adm_csf = adm_csf;
2628+
s->adm_cm = adm_cm;
2629+
s->adm_dwt2_s123_combined = adm_dwt2_s123_combined;
2630+
s->adm_csf_den_s123 = adm_csf_den_s123;
2631+
s->i4_adm_csf = i4_adm_csf;
2632+
s->i4_adm_cm = i4_adm_cm;
2633+
s->adm_decouple = adm_decouple;
2634+
s->adm_decouple_s123 = adm_decouple_s123;
25962635

25972636
#if ARCH_X86
25982637
unsigned flags = vmaf_get_cpu_flags();
25992638
if (flags & VMAF_X86_CPU_FLAG_AVX2) {
26002639
if (!(w % 8)) s->dwt2_8 = adm_dwt2_8_avx2;
2640+
s->dwt2_16 = adm_dwt2_16_avx2;
2641+
s->adm_csf_den_scale = adm_csf_den_scale_avx2;
2642+
s->adm_csf = adm_csf_avx2;
2643+
s->adm_cm = adm_cm_avx2;
2644+
s->adm_csf_den_s123 = adm_csf_den_s123_avx2;
2645+
s->adm_dwt2_s123_combined = adm_dwt2_s123_combined_avx2;
2646+
s->i4_adm_csf = i4_adm_csf_avx2;
2647+
s->i4_adm_cm = i4_adm_cm_avx2;
2648+
s->adm_decouple = adm_decouple_avx2;
2649+
s->adm_decouple_s123 = adm_decouple_s123_avx2;
2650+
}
2651+
#if HAVE_AVX512
2652+
if (flags & VMAF_X86_CPU_FLAG_AVX512) {
2653+
s->dwt2_8 = adm_dwt2_8_avx512;
2654+
s->dwt2_16 = adm_dwt2_16_avx512;
2655+
s->adm_csf_den_scale = adm_csf_den_scale_avx512;
2656+
s->adm_csf = adm_csf_avx512;
2657+
s->adm_cm = adm_cm_avx512;
2658+
s->adm_csf_den_s123 = adm_csf_den_s123_avx512;
2659+
s->adm_dwt2_s123_combined = adm_dwt2_s123_combined_avx512;
2660+
s->i4_adm_csf = i4_adm_csf_avx512;
2661+
s->i4_adm_cm = i4_adm_cm_avx512;
2662+
s->adm_decouple = adm_decouple_avx512;
2663+
s->adm_decouple_s123 = adm_decouple_s123_avx512;
26012664
}
2665+
#endif
26022666
#elif ARCH_AARCH64
26032667
unsigned flags = vmaf_get_cpu_flags();
26042668
if (flags & VMAF_ARM_CPU_FLAG_NEON) {

libvmaf/src/feature/integer_motion.c

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,7 @@ static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
283283

284284
MotionState *s = fex->priv;
285285
int err = 0;
286+
unsigned flags = vmaf_get_cpu_flags();
286287

287288
s->feature_name_dict =
288289
vmaf_feature_name_dict_from_provided_features(fex->provided_features,
@@ -303,10 +304,17 @@ static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
303304
if (err) goto fail;
304305

305306
s->y_convolution = bpc == 8 ? y_convolution_8 : y_convolution_16;
306-
s->x_convolution = x_convolution_16;
307+
#if ARCH_X86
308+
if (flags & VMAF_X86_CPU_FLAG_AVX2)
309+
s->y_convolution = bpc == 8 ? y_convolution_8_avx2 : y_convolution_16_avx2;
310+
#if HAVE_AVX512
311+
if (flags & VMAF_X86_CPU_FLAG_AVX512)
312+
s->y_convolution = bpc == 8 ? y_convolution_8_avx512 : y_convolution_16_avx512;
313+
#endif
314+
#endif
307315

316+
s->x_convolution = x_convolution_16;
308317
#if ARCH_X86
309-
unsigned flags = vmaf_get_cpu_flags();
310318
if (flags & VMAF_X86_CPU_FLAG_AVX2)
311319
s->x_convolution = x_convolution_16_avx2;
312320
#if HAVE_AVX512
@@ -316,6 +324,14 @@ static int init(VmafFeatureExtractor *fex, enum VmafPixelFormat pix_fmt,
316324
#endif
317325

318326
s->sad = sad_c;
327+
#if ARCH_X86
328+
if (flags & VMAF_X86_CPU_FLAG_AVX2)
329+
s->sad = sad_avx2;
330+
#if HAVE_AVX512
331+
if (flags & VMAF_X86_CPU_FLAG_AVX512)
332+
s->sad = sad_avx512;
333+
#endif
334+
#endif
319335
s->score = 0.;
320336

321337
return 0;

libvmaf/src/feature/integer_motion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
#include <stdbool.h>
2323
#include <stdint.h>
24+
#include "cpu.h"
2425

2526
static const uint16_t filter[5] = { 3571, 16004, 26386, 16004, 3571 };
2627
static const int filter_width = sizeof(filter) / sizeof(filter[0]);

libvmaf/src/feature/integer_vif.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <stdint.h>
2323
#include <stdbool.h>
2424
#include <assert.h>
25+
#include "cpu.h"
2526

2627
/* Enhancement gain imposed on vif, must be >= 1.0, where 1.0 means the gain is completely disabled */
2728
#ifndef DEFAULT_VIF_ENHN_GAIN_LIMIT

0 commit comments

Comments
 (0)