Skip to content

Commit fc892a6

Browse files
authored
Single pass Forward TX for AVX optimization (AcademySoftwareFoundation#101)
* function pointer updates
  Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
* single pass Fwd Tx C-code update
  Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
* single pass Fwd Tx AVX-code update
  Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
* remove unused code (listed after the ARM_Neon update below)
* single pass Fwd Tx ARM_Neon-code update
  Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
* remove unused code
  Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
---------
Signed-off-by: subhrajitm20 <2003subhrajit@gmail.com>
1 parent 4de2c9b commit fc892a6

File tree

4 files changed

+59
-21
lines changed

4 files changed

+59
-21
lines changed

src/avx/oapv_tq_avx.c

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
_mm256_set_m128i(_mm_loadu_si128(hiaddr), _mm_loadu_si128(loaddr))
4444
#endif // !_mm256_loadu2_m128i
4545

46-
static void oapv_tx_part_avx(s16 *src, s16 *dst, int shift, int line)
46+
static void oapv_tx_avx(s16 *src, int shift1, int shift2, int line)
4747
{
4848
__m256i v0, v1, v2, v3, v4, v5, v6, v7;
49-
__m256i d0, d1, d2, d3;
49+
__m256i d0, d1, d2, d3, d4, d5;
5050
__m256i coeff[8];
5151
coeff[0] = _mm256_set1_epi16(64);
5252
coeff[1] = _mm256_set_epi16(64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64, 64, -64, -64, 64);
@@ -56,7 +56,8 @@ static void oapv_tx_part_avx(s16 *src, s16 *dst, int shift, int line)
5656
coeff[5] = _mm256_set_epi16(-75, 18, 89, 50, -50, -89, -18, 75, -75, 18, 89, 50, -50, -89, -18, 75);
5757
coeff[6] = _mm256_set_epi16(-50, 89, -18, -75, 75, 18, -89, 50, -50, 89, -18, -75, 75, 18, -89, 50);
5858
coeff[7] = _mm256_set_epi16(-18, 50, -75, 89, -89, 75, -50, 18, -18, 50, -75, 89, -89, 75, -50, 18);
59-
__m256i add = _mm256_set1_epi32(1 << (shift - 1));
59+
__m256i add1 = _mm256_set1_epi32(1 << (shift1 - 1));
60+
__m256i add2 = _mm256_set1_epi32(1 << (shift2 - 1));
6061

6162
__m256i s0, s1, s2, s3;
6263

@@ -67,38 +68,63 @@ static void oapv_tx_part_avx(s16 *src, s16 *dst, int shift, int line)
6768

6869
CALCU_2x8(coeff[0], coeff[4], d0, d1);
6970
CALCU_2x8(coeff[2], coeff[5], d2, d3);
70-
CALCU_2x8_ADD_SHIFT(d0, d1, d2, d3, add, shift)
71+
CALCU_2x8_ADD_SHIFT(d0, d1, d2, d3, add1, shift1);
7172

72-
d0 = _mm256_packs_epi32(d0, d1);
73+
d0 = _mm256_packs_epi32(d0, d1);
74+
d1 = _mm256_packs_epi32(d2, d3);
75+
76+
d0 = _mm256_permute4x64_epi64(d0, 0xd8);
77+
d1 = _mm256_permute4x64_epi64(d1, 0xd8);
78+
79+
CALCU_2x8(coeff[1], coeff[6], d2, d3);
80+
CALCU_2x8(coeff[3], coeff[7], d4, d5);
81+
CALCU_2x8_ADD_SHIFT(d2, d3, d4, d5, add1, shift1);
82+
83+
d2 = _mm256_packs_epi32(d2, d3);
84+
d3 = _mm256_packs_epi32(d4, d5);
85+
86+
d2 = _mm256_permute4x64_epi64(d2, 0xd8);
87+
d3 = _mm256_permute4x64_epi64(d3, 0xd8);
88+
89+
s0 = _mm256_setr_m128i(_mm256_castsi256_si128(d0), _mm256_castsi256_si128(d2));
90+
s1 = _mm256_setr_m128i(_mm256_extracti128_si256(d0, 1), _mm256_extracti128_si256(d2, 1));
91+
s2 = _mm256_setr_m128i(_mm256_castsi256_si128(d1), _mm256_castsi256_si128(d3));
92+
s3 = _mm256_setr_m128i(_mm256_extracti128_si256(d1, 1), _mm256_extracti128_si256(d3, 1));
93+
94+
CALCU_2x8(coeff[0], coeff[4], d0, d1);
95+
CALCU_2x8(coeff[2], coeff[5], d2, d3);
96+
CALCU_2x8_ADD_SHIFT(d0, d1, d2, d3, add2, shift2)
97+
98+
d0 = _mm256_packs_epi32(d0, d1);
7399
d1 = _mm256_packs_epi32(d2, d3);
74100

75101
d0 = _mm256_permute4x64_epi64(d0, 0xd8);
76102
d1 = _mm256_permute4x64_epi64(d1, 0xd8);
77103

78-
_mm_store_si128((__m128i *)dst, _mm256_castsi256_si128(d0));
79-
_mm_store_si128((__m128i *)(dst + 1 * line), _mm256_extracti128_si256(d0, 1));
80-
_mm_store_si128((__m128i *)(dst + 2 * line), _mm256_castsi256_si128(d1));
81-
_mm_store_si128((__m128i *)(dst + 3 * line), _mm256_extracti128_si256(d1, 1));
104+
_mm_store_si128((__m128i *)src, _mm256_castsi256_si128(d0));
105+
_mm_store_si128((__m128i *)(src + 1 * line), _mm256_extracti128_si256(d0, 1));
106+
_mm_store_si128((__m128i *)(src + 2 * line), _mm256_castsi256_si128(d1));
107+
_mm_store_si128((__m128i *)(src + 3 * line), _mm256_extracti128_si256(d1, 1));
82108

83109
CALCU_2x8(coeff[1], coeff[6], d0, d1);
84110
CALCU_2x8(coeff[3], coeff[7], d2, d3);
85-
CALCU_2x8_ADD_SHIFT(d0, d1, d2, d3, add, shift);
111+
CALCU_2x8_ADD_SHIFT(d0, d1, d2, d3, add2, shift2);
86112

87113
d0 = _mm256_packs_epi32(d0, d1);
88114
d1 = _mm256_packs_epi32(d2, d3);
89115

90116
d0 = _mm256_permute4x64_epi64(d0, 0xd8);
91117
d1 = _mm256_permute4x64_epi64(d1, 0xd8);
92118

93-
_mm_store_si128((__m128i *)(dst + 4 * line), _mm256_castsi256_si128(d0));
94-
_mm_store_si128((__m128i *)(dst + 5 * line), _mm256_extracti128_si256(d0, 1));
95-
_mm_store_si128((__m128i *)(dst + 6 * line), _mm256_castsi256_si128(d1));
96-
_mm_store_si128((__m128i *)(dst + 7 * line), _mm256_extracti128_si256(d1, 1));
119+
_mm_store_si128((__m128i *)(src + 4 * line), _mm256_castsi256_si128(d0));
120+
_mm_store_si128((__m128i *)(src + 5 * line), _mm256_extracti128_si256(d0, 1));
121+
_mm_store_si128((__m128i *)(src + 6 * line), _mm256_castsi256_si128(d1));
122+
_mm_store_si128((__m128i *)(src + 7 * line), _mm256_extracti128_si256(d1, 1));
97123
}
98124

99125
const oapv_fn_tx_t oapv_tbl_fn_txb_avx[2] =
100126
{
101-
oapv_tx_part_avx,
127+
oapv_tx_avx,
102128
NULL
103129
};
104130

src/neon/oapv_tq_neon.c

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const s32 oapv_coeff[8][4] =
5252
high = vmulq_s32(part2, coeff); \
5353
res = vcombine_s32(vpadd_s32(vget_low_s32(low), vget_high_s32(low)), vpadd_s32(vget_low_s32(high), vget_high_s32(high))); \
5454

55-
static void oapv_tx_pb8b_neon(s16 *src, s16 *dst, const int shift, int line)
55+
static void oapv_tx_pb8b_part_neon(s16 *src, s16 *dst, const int shift, int line)
5656
{
5757
s16 i;
5858
s16 *tempSrc = src;
@@ -186,6 +186,13 @@ static void oapv_tx_pb8b_neon(s16 *src, s16 *dst, const int shift, int line)
186186
}
187187
}
188188

189+
/*
 * Single-call forward transform entry point for the NEON table.
 *
 * Calls oapv_tx_pb8b_part_neon twice through a local scratch buffer, so the
 * output of the second call is written back into src.
 *
 *   src    - input block; overwritten by the second call
 *   shift1 - shift forwarded to the first call  (src -> dst)
 *   shift2 - shift forwarded to the second call (dst -> src)
 *   line   - forwarded unchanged to both calls
 *
 * NOTE(review): both calls receive the same `line` value; this presumably
 * relies on the block being square -- confirm against the caller.
 */
static void oapv_tx_pb8b_neon(s16 *src, const int shift1, const int shift2, int line)
190+
{
191+
/* scratch buffer of OAPV_BLK_D s16 entries holding the first call's output */
ALIGNED_16(s16 dst[OAPV_BLK_D]);
192+
/* first call: src -> dst using shift1 */
oapv_tx_pb8b_part_neon(src, dst, shift1, line);
193+
/* second call: dst -> src using shift2, leaving the result in src */
oapv_tx_pb8b_part_neon(dst, src, shift2, line);
194+
}
195+
189196
const oapv_fn_tx_t oapv_tbl_fn_txb_neon[2] =
190197
{
191198
oapv_tx_pb8b_neon,

src/oapv_def.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ typedef struct oapve_core oapve_core_t;
174174
*****************************************************************************/
175175
typedef void (*oapv_fn_itx_part_t)(s16 *coef, s16 *t, int shift, int line);
176176
typedef void (*oapv_fn_itx_t)(s16 *coef, int shift1, int shift2, int line);
177-
typedef void (*oapv_fn_tx_t)(s16 *coef, s16 *t, int shift, int line);
177+
typedef void (*oapv_fn_tx_t)(s16 *coef, int shift1, int shift2, int line);
178178
typedef void (*oapv_fn_itx_adj_t)(int *src, int *dst, int itrans_diff_idx, int diff_step, int shift);
179179
typedef int (*oapv_fn_quant_t)(s16 *coef, u8 qp, int q_matrix[OAPV_BLK_D], int log2_w, int log2_h, int bit_depth, int deadzone_offset);
180180
typedef void (*oapv_fn_dquant_t)(s16 *coef, s16 q_matrix[OAPV_BLK_D], int log2_w, int log2_h, s8 shift);

src/oapv_tq.c

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,15 @@ static void oapv_tx_part(s16 *src, s16 *dst, int shift, int line)
7070
}
7171
}
7272

73+
/*
 * Single-call forward transform for the C function table.
 *
 * Calls oapv_tx_part twice through a local scratch buffer, so the output of
 * the second call is written back into src.
 *
 *   src    - input block; overwritten by the second call
 *   shift1 - shift forwarded to the first call  (src -> dst)
 *   shift2 - shift forwarded to the second call (dst -> src)
 *   line   - forwarded unchanged to both calls
 *
 * NOTE(review): the calls removed from oapv_trans passed 1 << log2_h to the
 * first pass and 1 << log2_w to the second; here one `line` value serves
 * both. Presumably fine for square blocks only -- confirm.
 */
static void oapv_tx(s16 *src, int shift1, int shift2, int line)
74+
{
75+
/* scratch buffer of OAPV_BLK_D s16 entries holding the first call's output */
ALIGNED_16(s16 dst[OAPV_BLK_D]);
76+
/* first call: src -> dst using shift1 */
oapv_tx_part(src, dst, shift1, line);
77+
/* second call: dst -> src using shift2, leaving the result in src */
oapv_tx_part(dst, src, shift2, line);
78+
}
79+
7380
const oapv_fn_tx_t oapv_tbl_fn_tx[2] = {
74-
oapv_tx_part,
81+
oapv_tx,
7582
NULL
7683
};
7784

@@ -90,9 +97,7 @@ void oapv_trans(oapve_ctx_t *ctx, s16 *coef, int log2_w, int log2_h, int bit_dep
9097
int shift1 = get_transform_shift(log2_w, 0, bit_depth);
9198
int shift2 = get_transform_shift(log2_h, 1, bit_depth);
9299

93-
ALIGNED_16(s16 tb[OAPV_BLK_D]);
94-
(ctx->fn_txb)[0](coef, tb, shift1, 1 << log2_h);
95-
(ctx->fn_txb)[0](tb, coef, shift2, 1 << log2_w);
100+
(ctx->fn_txb)[0](coef, shift1, shift2, 1 << log2_h);
96101
}
97102

98103
static int oapv_quant(s16 *coef, u8 qp, int q_matrix[OAPV_BLK_D], int log2_w, int log2_h, int bit_depth, int deadzone_offset)

0 commit comments

Comments
 (0)