diff --git a/Python/optimizer.c b/Python/optimizer.c
--- a/Python/optimizer.c
+++ b/Python/optimizer.c
@@ -923,16 +923,25 @@ translate_bytecode_to_trace(
 #define BIT_IS_SET(array, bit) (array[(bit)>>5] & (1<<((bit)&31)))
 
-/* Count the number of unused uops and exits
-*/
+/* Count the _EXIT_TRACE uops in buffer[0..length).
+ *
+ * buffer: uop instructions to scan; only the opcode field is read.
+ * length: number of entries to examine; a value <= 0 yields 0.
+ *
+ * Returns the number of entries whose opcode is _EXIT_TRACE.
+ *
+ * Deliberately a plain per-element loop: the entries are structs read
+ * through .opcode, so the array must not be reinterpreted as packed
+ * 32-bit words (e.g. by hand-written SIMD loads).
+ */
 static int
 count_exits(_PyUOpInstruction *buffer, int length)
 {
     int exit_count = 0;
     for (int i = 0; i < length; i++) {
         int opcode = buffer[i].opcode;
         if (opcode == _EXIT_TRACE) {
             exit_count++;
         }
     }
     return exit_count;
 }