Commit 6bd1ab5

Whatcookie authored and elad335 committed
RSX: Use AVX-512-ICL code in get_vertex_program_ucode_hash and in vertex_program_compare

- Code is about 4 times as fast on my Zen 4 machine
- Should be twice as fast on Zen 5 machines with full-width AVX-512
1 parent 8e6272b commit 6bd1ab5
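
The change relies on a standard dispatch pattern: the affected functions are compiled with AVX-512 code generation enabled through a target attribute on GCC/Clang (MSVC needs no attribute, and the ARM64 build defines the macro away), while the wide path is selected by a runtime CPU check. A minimal sketch of that pattern, paraphrased from the diff below; `has_avx512_icl` here stands in for RPCS3's `utils::has_avx512_icl()`, and the shortened target string is an assumption for brevity:

#include <immintrin.h>
#include <cstdint>

// Sketch only, not RPCS3 code: GCC/Clang reject AVX-512 intrinsics in a
// baseline x86-64 translation unit unless the function carries a matching
// target attribute, hence the AVX512_ICL_FUNC macro on whole functions.
#ifdef _MSC_VER
#define AVX512_ICL_FUNC // MSVC exposes the intrinsics unconditionally
#else
#define AVX512_ICL_FUNC __attribute__((__target__("avx512f,avx512dq,avx512vbmi")))
#endif

static bool has_avx512_icl()
{
#ifdef _MSC_VER
    return false; // a real build would query CPUID, as utils::has_avx512_icl() does
#else
    // Ice-Lake-level AVX-512 implies VBMI/VBMI2, which the byte permute below needs
    return __builtin_cpu_supports("avx512vbmi") && __builtin_cpu_supports("avx512vbmi2");
#endif
}

AVX512_ICL_FUNC uint64_t sum_u64x8(const uint64_t* data)
{
    if (has_avx512_icl())
    {
        // Wide path: one 512-bit load plus a horizontal add across 8 lanes
        const __m512i v = _mm512_loadu_si512(data);
        return _mm512_reduce_add_epi64(v);
    }

    // Portable fallback with identical results
    uint64_t acc = 0;
    for (int i = 0; i < 8; i++)
        acc += data[i];
    return acc;
}

One subtlety of this shape (shared by the commit): the fallback body also sits inside the attributed function, so the compiler may in principle emit AVX-512 instructions for it too; hoisting the guarded call into a plain wrapper avoids that when it matters.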

File tree

1 file changed: +186 -19 lines changed

rpcs3/Emu/RSX/Program/ProgramStateCache.cpp

Lines changed: 186 additions & 19 deletions
@@ -1,6 +1,7 @@
 #include "stdafx.h"
 #include "ProgramStateCache.h"
 #include "Emu/system_config.h"
+#include "util/sysinfo.hpp"
 
 #include <stack>
 
@@ -21,31 +22,119 @@
 #endif
 #endif
 
+#ifdef ARCH_ARM64
+#define AVX512_ICL_FUNC
+#endif
+
+#ifdef _MSC_VER
+#define AVX512_ICL_FUNC
+#else
+#define AVX512_ICL_FUNC __attribute__((__target__("avx512f,avx512bw,avx512dq,avx512cd,avx512vl,avx512bitalg,avx512ifma,avx512vbmi,avx512vbmi2,avx512vnni,avx512vpopcntdq")))
+#endif
+
+
 using namespace program_hash_util;
 
-usz vertex_program_utils::get_vertex_program_ucode_hash(const RSXVertexProgram &program)
+AVX512_ICL_FUNC usz vertex_program_utils::get_vertex_program_ucode_hash(const RSXVertexProgram &program)
 {
-    // Checksum as hash with rotated data
-    const void* instbuffer = program.data.data();
-    u32 instIndex = 0;
-    usz acc0 = 0;
-    usz acc1 = 0;
-
-    do
+#ifdef ARCH_X64
+    if (utils::has_avx512_icl())
     {
-        if (program.instruction_mask[instIndex])
+        // Load all elements of the instruction_mask bitset
+        const __m512i* instMask512 = reinterpret_cast<const __m512i*>(&program.instruction_mask);
+        const __m128i* instMask128 = reinterpret_cast<const __m128i*>(&program.instruction_mask);
+
+        const __m512i lowerMask = _mm512_loadu_si512(instMask512);
+        const __m128i upper128 = _mm_loadu_si128(instMask128 + 4);
+        const __m512i upperMask = _mm512_zextsi128_si512(upper128);
+
+        __m512i maskIndex = _mm512_setzero_si512();
+        const __m512i negativeOnes = _mm512_set1_epi64(-1);
+
+        // Special masks to test against bitset
+        const __m512i testMask0 = _mm512_set_epi64(
+            0x0808080808080808,
+            0x0808080808080808,
+            0x0404040404040404,
+            0x0404040404040404,
+            0x0202020202020202,
+            0x0202020202020202,
+            0x0101010101010101,
+            0x0101010101010101);
+
+        const __m512i testMask1 = _mm512_set_epi64(
+            0x8080808080808080,
+            0x8080808080808080,
+            0x4040404040404040,
+            0x4040404040404040,
+            0x2020202020202020,
+            0x2020202020202020,
+            0x1010101010101010,
+            0x1010101010101010);
+
+        const __m512i* instBuffer = reinterpret_cast<const __m512i*>(program.data.data());
+        __m512i acc0 = _mm512_setzero_si512();
+        __m512i acc1 = _mm512_setzero_si512();
+
+        __m512i rotMask0 = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
+        __m512i rotMask1 = _mm512_set_epi64(15, 14, 13, 12, 11, 10, 9, 8);
+        __m512i rotMaskAdd = _mm512_set_epi64(16, 16, 16, 16, 16, 16, 16, 16);
+
+        u32 instIndex = 0;
+
+        // If there is a remainder, add an extra (masked) iteration
+        u32 extraIteration = (program.data.size() % 32 != 0) ? 1 : 0;
+        u32 length = (program.data.size() / 32) + extraIteration;
+
+        // The instruction mask will prevent us from reading out of bounds; we do not need a separate masked loop
+        // for the remainder, or a scalar loop.
+        while (instIndex < (length))
         {
-            const auto inst = v128::loadu(instbuffer, instIndex);
-            usz tmp0 = std::rotr(inst._u64[0], instIndex * 2);
-            acc0 += tmp0;
-            usz tmp1 = std::rotr(inst._u64[1], (instIndex * 2) + 1);
-            acc1 += tmp1;
+            const __m512i masks = _mm512_permutex2var_epi8(lowerMask, maskIndex, upperMask);
+            const __mmask8 result0 = _mm512_test_epi64_mask(masks, testMask0);
+            const __mmask8 result1 = _mm512_test_epi64_mask(masks, testMask1);
+            const __m512i load0 = _mm512_maskz_loadu_epi64(result0, (instBuffer + instIndex * 2));
+            const __m512i load1 = _mm512_maskz_loadu_epi64(result1, (instBuffer + (instIndex * 2) + 1));
+
+            const __m512i rotated0 = _mm512_rorv_epi64(load0, rotMask0);
+            const __m512i rotated1 = _mm512_rorv_epi64(load1, rotMask1);
+
+            acc0 = _mm512_add_epi64(acc0, rotated0);
+            acc1 = _mm512_add_epi64(acc1, rotated1);
+
+            rotMask0 = _mm512_add_epi64(rotMask0, rotMaskAdd);
+            rotMask1 = _mm512_add_epi64(rotMask1, rotMaskAdd);
+            maskIndex = _mm512_sub_epi8(maskIndex, negativeOnes);
+
+            instIndex++;
         }
 
-        instIndex++;
-    } while (instIndex < (program.data.size() / 4));
+        const __m512i result = _mm512_add_epi64(acc0, acc1);
+        return _mm512_reduce_add_epi64(result);
+    }
+#endif
+
+    // Checksum as hash with rotated data
+    const void* instbuffer = program.data.data();
+    u32 instIndex = 0;
+    usz acc0 = 0;
+    usz acc1 = 0;
+
+    do
+    {
+        if (program.instruction_mask[instIndex])
+        {
+            const auto inst = v128::loadu(instbuffer, instIndex);
+            usz tmp0 = std::rotr(inst._u64[0], instIndex * 2);
+            acc0 += tmp0;
+            usz tmp1 = std::rotr(inst._u64[1], (instIndex * 2) + 1);
+            acc1 += tmp1;
+        }
+
+        instIndex++;
+    } while (instIndex < (program.data.size() / 4));
     return acc0 + acc1;
-}
+}
 
 vertex_program_utils::vertex_program_metadata vertex_program_utils::analyse_vertex_program(const u32* data, u32 entry, RSXVertexProgram& dst_prog)
 {
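
The heart of both new paths is how the `instruction_mask` bitset becomes per-lane load masks. On iteration `i`, every byte of `maskIndex` holds `i` (subtracting `negativeOnes` increments all 64 bytes each pass), so `_mm512_permutex2var_epi8` broadcasts byte `i` of the 128-byte `lowerMask`/`upperMask` table to all positions. Testing that broadcast byte against `testMask0`/`testMask1` then expands each of its 8 bits into two adjacent 64-bit-lane mask bits, since each 128-bit instruction occupies two lanes. A scalar model of that expansion (my paraphrase, not code from the commit):

#include <cstdint>

// What _mm512_test_epi64_mask(masks, testMask0/1) computes once byte i of
// the instruction_mask bitset has been broadcast to every lane: bit k of
// the byte says instruction 8*i + k is live, and a live instruction needs
// two 64-bit lanes of the masked load enabled.
static void expand_mask_byte(uint8_t mask_byte, uint8_t& load_mask0, uint8_t& load_mask1)
{
    load_mask0 = 0; // k-mask for the first zmm load (instructions 8i+0 .. 8i+3)
    load_mask1 = 0; // k-mask for the second zmm load (instructions 8i+4 .. 8i+7)

    for (int k = 0; k < 4; k++)
    {
        if (mask_byte & (1u << k))
            load_mask0 |= 0b11u << (2 * k); // two u64 lanes per instruction
        if (mask_byte & (1u << (k + 4)))
            load_mask1 |= 0b11u << (2 * k);
    }
}

Lanes that are masked off load as zero, rotate to zero, and add nothing to the accumulators, so dead instructions cannot perturb the hash; and because masked-out lanes of `_mm512_maskz_loadu_epi64` never fault, the final partial chunk needs no tail loop. The variable rotate `_mm512_rorv_epi64` reproduces the scalar `std::rotr(..., instIndex * 2)` / `+ 1` scheme: `rotMask0`/`rotMask1` start at 0..7 and 8..15 and advance by 16 per iteration, one rotate count per lane of the 8 instructions just processed.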
@@ -350,7 +439,7 @@ usz vertex_program_storage_hash::operator()(const RSXVertexProgram &program) con
     return rpcs3::hash64(ucode_hash, metadata_hash);
 }
 
-bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const RSXVertexProgram &binary2) const
+AVX512_ICL_FUNC bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const RSXVertexProgram &binary2) const
 {
     if (binary1.output_mask != binary2.output_mask)
         return false;
@@ -363,10 +452,88 @@ bool vertex_program_compare::operator()(const RSXVertexProgram &binary1, const R
     if (binary1.jump_table != binary2.jump_table)
         return false;
 
+#ifdef ARCH_X64
+    if (utils::has_avx512_icl())
+    {
+        // Load all elements of the instruction_mask bitset
+        const __m512i* instMask512 = reinterpret_cast<const __m512i*>(&binary1.instruction_mask);
+        const __m128i* instMask128 = reinterpret_cast<const __m128i*>(&binary1.instruction_mask);
+
+        const __m512i lowerMask = _mm512_loadu_si512(instMask512);
+        const __m128i upper128 = _mm_loadu_si128(instMask128 + 4);
+        const __m512i upperMask = _mm512_zextsi128_si512(upper128);
+
+        __m512i maskIndex = _mm512_setzero_si512();
+        const __m512i negativeOnes = _mm512_set1_epi64(-1);
+
+        // Special masks to test against bitset
+        const __m512i testMask0 = _mm512_set_epi64(
+            0x0808080808080808,
+            0x0808080808080808,
+            0x0404040404040404,
+            0x0404040404040404,
+            0x0202020202020202,
+            0x0202020202020202,
+            0x0101010101010101,
+            0x0101010101010101);
+
+        const __m512i testMask1 = _mm512_set_epi64(
+            0x8080808080808080,
+            0x8080808080808080,
+            0x4040404040404040,
+            0x4040404040404040,
+            0x2020202020202020,
+            0x2020202020202020,
+            0x1010101010101010,
+            0x1010101010101010);
+
+        const __m512i* instBuffer1 = reinterpret_cast<const __m512i*>(binary1.data.data());
+        const __m512i* instBuffer2 = reinterpret_cast<const __m512i*>(binary2.data.data());
+
+        // If there is a remainder, add an extra (masked) iteration
+        u32 extraIteration = (binary1.data.size() % 32 != 0) ? 1 : 0;
+        u32 length = (binary1.data.size() / 32) + extraIteration;
+
+        u32 instIndex = 0;
+
+        // The instruction mask will prevent us from reading out of bounds; we do not need a separate masked loop
+        // for the remainder, or a scalar loop.
+        while (instIndex < (length))
+        {
+            const __m512i masks = _mm512_permutex2var_epi8(lowerMask, maskIndex, upperMask);
+
+            const __mmask8 result0 = _mm512_test_epi64_mask(masks, testMask0);
+            const __mmask8 result1 = _mm512_test_epi64_mask(masks, testMask1);
+
+            const __m512i load0 = _mm512_maskz_loadu_epi64(result0, (instBuffer1 + (instIndex * 2)));
+            const __m512i load1 = _mm512_maskz_loadu_epi64(result0, (instBuffer2 + (instIndex * 2)));
+            const __m512i load2 = _mm512_maskz_loadu_epi64(result1, (instBuffer1 + (instIndex * 2) + 1));
+            const __m512i load3 = _mm512_maskz_loadu_epi64(result1, (instBuffer2 + (instIndex * 2) + 1));
+
+            const __mmask8 res0 = _mm512_cmpneq_epi64_mask(load0, load1);
+            const __mmask8 res1 = _mm512_cmpneq_epi64_mask(load2, load3);
+
+            const u8 result = _kortestz_mask8_u8(res0, res1);
+
+            // kortestz sets result to 1 only when all bits are zero (no lane differed), so invert the check
+            if (!result)
+            {
+                return false;
+            }
+
+            maskIndex = _mm512_sub_epi8(maskIndex, negativeOnes);
+
+            instIndex++;
+        }
+
+        return true;
+    }
+#endif
+
     const void* instBuffer1 = binary1.data.data();
     const void* instBuffer2 = binary2.data.data();
     usz instIndex = 0;
-    for (unsigned i = 0; i < binary1.data.size() / 4; i++)
+    while (instIndex < (binary1.data.size() / 4))
     {
         if (binary1.instruction_mask[instIndex])
         {
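
The comparison path reuses the same mask expansion, then only needs a "did any lane differ" verdict per chunk. `_mm512_cmpneq_epi64_mask` produces one bit per mismatching 64-bit lane, and `_kortestz_mask8_u8` ORs two such k-masks and returns 1 only when the result is all zero, giving a cheap early exit. A self-contained sketch of just that step, simplified to plain 64-byte chunks with no instruction mask, so `bytes` is assumed to be a multiple of 64 (compile with AVX-512F/DQ enabled):

#include <immintrin.h>
#include <cstddef>

// Sketch, not commit code: early-exit equality check over 64-byte chunks
// using the same cmpneq + kortest pattern as vertex_program_compare.
static bool blocks_equal_avx512(const void* a, const void* b, std::size_t bytes)
{
    const __m512i* pa = static_cast<const __m512i*>(a);
    const __m512i* pb = static_cast<const __m512i*>(b);

    for (std::size_t i = 0; i < bytes / 64; i++)
    {
        const __m512i va = _mm512_loadu_si512(pa + i);
        const __m512i vb = _mm512_loadu_si512(pb + i);

        // One mask bit per 64-bit lane that differs
        const __mmask8 neq = _mm512_cmpneq_epi64_mask(va, vb);

        // _kortestz_mask8_u8 returns 1 iff the OR of its arguments is zero
        if (!_kortestz_mask8_u8(neq, neq))
            return false; // first mismatching chunk ends the comparison
    }

    return true;
}

In the commit, the masked loads feed this pattern, so lanes belonging to masked-off instructions compare as zero against zero and can never produce a false mismatch.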
