@@ -570,10 +570,9 @@ _PyJIT_translate_single_bytecode_to_trace(
570570{
571571
572572 int is_first_instr = tstate -> interp -> jit_state .jit_tracer_initial_instr == this_instr ;
573- bool progress_needed = (tstate -> interp -> jit_state .jit_tracer_initial_chain_depth % MAX_CHAIN_DEPTH ) == 0 && is_first_instr ;;
573+ bool progress_needed = (tstate -> interp -> jit_state .jit_tracer_initial_chain_depth % MAX_CHAIN_DEPTH ) == 0 ;;
574574 _PyBloomFilter * dependencies = & tstate -> interp -> jit_state .jit_tracer_dependencies ;
575575 _Py_BloomFilter_Add (dependencies , old_code );
576- _Py_CODEUNIT * target_instr = this_instr ;
577576 int trace_length = tstate -> interp -> jit_state .jit_tracer_code_curr_size ;
578577 _PyUOpInstruction * trace = tstate -> interp -> jit_state .jit_tracer_code_buffer ;
579578 int max_length = tstate -> interp -> jit_state .jit_tracer_code_max_size ;
@@ -585,11 +584,26 @@ _PyJIT_translate_single_bytecode_to_trace(
585584 lltrace = * python_lltrace - '0' ; // TODO: Parse an int and all that
586585 }
587586#endif
588-
587+ _Py_CODEUNIT * target_instr = this_instr ;
589588 uint32_t target = 0 ;
590589
591590 target = INSTR_IP (target_instr , old_code );
592591
592+ // Rewind EXTENDED_ARG so that we see the whole thing.
593+ // We must point to the first EXTENDED_ARG when deopting.
594+ int rewind_oparg = oparg ;
595+ while (rewind_oparg > 255 ) {
596+ rewind_oparg >>= 8 ;
597+ target -- ;
598+ }
599+ #ifdef Py_DEBUG
600+ if (oparg > 255 ) {
601+ assert (_Py_GetBaseCodeUnit (old_code , target ).op .code == EXTENDED_ARG );
602+ }
603+ #endif
604+
605+ DPRINTF (2 , "%p %d: %s(%d) %d\n" , old_code , target , _PyOpcode_OpName [opcode ], oparg , progress_needed );
606+
593607 bool needs_guard_ip = _PyOpcode_NeedsGuardIp [opcode ] &&
594608 !(opcode == FOR_ITER_RANGE || opcode == FOR_ITER_LIST || opcode == FOR_ITER_TUPLE ) &&
595609 !(opcode == JUMP_BACKWARD_NO_INTERRUPT || opcode == JUMP_BACKWARD || opcode == JUMP_BACKWARD_JIT ) &&
@@ -600,8 +614,7 @@ _PyJIT_translate_single_bytecode_to_trace(
600614 // This happens when a recursive call happens that we can't trace. Such as Python -> C -> Python calls
601615 // If we haven't guarded the IP, then it's untraceable.
602616 (frame != tstate -> interp -> jit_state .jit_tracer_current_frame && !needs_guard_ip ) ||
603- // TODO handle extended args.
604- oparg > 255 || opcode == EXTENDED_ARG ||
617+ (oparg > 0xFFFF ) ||
605618 // TODO handle BINARY_OP_INPLACE_ADD_UNICODE
606619 opcode == BINARY_OP_INPLACE_ADD_UNICODE ||
607620 // TODO (gh-140277): The constituent uops are invalid.
@@ -633,8 +646,6 @@ _PyJIT_translate_single_bytecode_to_trace(
633646
634647 tstate -> interp -> jit_state .jit_tracer_current_frame = frame ;
635648
636- DPRINTF (2 , "%p %d: %s(%d)\n" , old_code , target , _PyOpcode_OpName [opcode ], oparg );
637-
638649 if (opcode == NOP ) {
639650 return 1 ;
640651 }
@@ -643,6 +654,10 @@ _PyJIT_translate_single_bytecode_to_trace(
643654 return 1 ;
644655 }
645656
657+ if (opcode == EXTENDED_ARG ) {
658+ return 1 ;
659+ }
660+
646661 // One for possible _DEOPT, one because _CHECK_VALIDITY itself might _DEOPT
647662 max_length -= 2 ;
648663
@@ -663,7 +678,7 @@ _PyJIT_translate_single_bytecode_to_trace(
663678
664679 /* Special case the first instruction,
665680 * so that we can guarantee forward progress */
666- if (progress_needed && is_first_instr ) {
681+ if (progress_needed && tstate -> interp -> jit_state . jit_tracer_code_curr_size <= 2 ) {
667682 if (OPCODE_HAS_EXIT (opcode ) || OPCODE_HAS_DEOPT (opcode )) {
668683 opcode = _PyOpcode_Deopt [opcode ];
669684 }
@@ -695,12 +710,13 @@ _PyJIT_translate_single_bytecode_to_trace(
695710 case POP_JUMP_IF_FALSE :
696711 case POP_JUMP_IF_TRUE :
697712 {
698- RESERVE (1 );
699- _Py_CODEUNIT * computed_next_instr = target_instr + 1 + _PyOpcode_Caches [_PyOpcode_Deopt [opcode ]];
700- _Py_CODEUNIT * computed_jump_instr = computed_next_instr + oparg ;
701- int jump_likely = computed_jump_instr == next_instr ;
702- uint32_t uopcode = BRANCH_TO_GUARD [opcode - POP_JUMP_IF_FALSE ][jump_likely ];
703- ADD_TO_TRACE (uopcode , 0 , 0 , INSTR_IP (jump_likely ? computed_next_instr : computed_jump_instr , old_code ));
713+ _Py_CODEUNIT * computed_next_instr_without_modifiers = target_instr + 1 + _PyOpcode_Caches [_PyOpcode_Deopt [opcode ]];
714+ _Py_CODEUNIT * computed_next_instr = computed_next_instr_without_modifiers + (computed_next_instr_without_modifiers -> op .code == NOT_TAKEN );
715+ _Py_CODEUNIT * computed_jump_instr = computed_next_instr_without_modifiers + oparg ;
716+ assert (next_instr == computed_next_instr || next_instr == computed_jump_instr );
717+ int jump_happened = computed_jump_instr == next_instr ;
718+ uint32_t uopcode = BRANCH_TO_GUARD [opcode - POP_JUMP_IF_FALSE ][jump_happened ];
719+ ADD_TO_TRACE (uopcode , 0 , 0 , INSTR_IP (jump_happened ? computed_next_instr : computed_jump_instr , old_code ));
704720 break ;
705721 }
706722 case JUMP_BACKWARD_JIT :
@@ -731,8 +747,10 @@ _PyJIT_translate_single_bytecode_to_trace(
731747 assert (nuops > 0 );
732748 RESERVE (nuops + 1 ); /* One extra for exit */
733749 uint32_t orig_oparg = oparg ; // For OPARG_TOP/BOTTOM
750+ uint32_t orig_target = target ;
734751 for (int i = 0 ; i < nuops ; i ++ ) {
735752 oparg = orig_oparg ;
753+ target = orig_target ;
736754 uint32_t uop = expansion -> uops [i ].uop ;
737755 uint64_t operand = 0 ;
738756 // Add one to account for the actual opcode/oparg pair:
@@ -751,9 +769,11 @@ _PyJIT_translate_single_bytecode_to_trace(
751769 operand = read_u64 (& this_instr [offset ].cache );
752770 break ;
753771 case OPARG_TOP : // First half of super-instr
772+ assert (orig_oparg <= 255 );
754773 oparg = orig_oparg >> 4 ;
755774 break ;
756775 case OPARG_BOTTOM : // Second half of super-instr
776+ assert (orig_oparg <= 255 );
757777 oparg = orig_oparg & 0xF ;
758778 break ;
759779 case OPARG_SAVE_RETURN_OFFSET : // op=_SAVE_RETURN_OFFSET; oparg=return_offset
@@ -768,13 +788,15 @@ _PyJIT_translate_single_bytecode_to_trace(
768788 if (uop == _TIER2_RESUME_CHECK ) {
769789 target = next_inst ;
770790 }
771- #ifdef Py_DEBUG
772791 else if (uop != _FOR_ITER_TIER_TWO ) {
773- uint32_t jump_target = next_inst + oparg ;
792+ int extended_arg = orig_oparg > 255 ;
793+ uint32_t jump_target = next_inst + orig_oparg + extended_arg ;
774794 assert (_Py_GetBaseCodeUnit (old_code , jump_target ).op .code == END_FOR );
775795 assert (_Py_GetBaseCodeUnit (old_code , jump_target + 1 ).op .code == POP_ITER );
796+ if (is_for_iter_test [uop ]) {
797+ target = jump_target + 1 ;
798+ }
776799 }
777- #endif
778800 break ;
779801 case OPERAND1_1 :
780802 assert (trace [trace_length - 1 ].opcode == uop );
@@ -859,11 +881,12 @@ _PyJIT_InitializeTracing(PyThreadState *tstate, _PyInterpreterFrame *frame, _Py_
859881 lltrace = * python_lltrace - '0' ; // TODO: Parse an int and all that
860882 }
861883 DPRINTF (2 ,
862- "Tracing %s (%s:%d) at byte offset %d\n" ,
884+ "Tracing %s (%s:%d) at byte offset %d at chain depth %d \n" ,
863885 PyUnicode_AsUTF8 (code -> co_qualname ),
864886 PyUnicode_AsUTF8 (code -> co_filename ),
865887 code -> co_firstlineno ,
866- 2 * INSTR_IP (next_instr , code ));
888+ 2 * INSTR_IP (next_instr , code ),
889+ chain_depth );
867890#endif
868891 add_to_trace (tstate -> interp -> jit_state .jit_tracer_code_buffer , 0 , _START_EXECUTOR , 0 , (uintptr_t )next_instr , INSTR_IP (next_instr , code ));
869892 add_to_trace (tstate -> interp -> jit_state .jit_tracer_code_buffer , 1 , _MAKE_WARM , 0 , 0 , 0 );
@@ -976,12 +999,6 @@ prepare_for_execution(_PyUOpInstruction *buffer, int length)
976999 exit_op = _DYNAMIC_EXIT ;
9771000 unique_target = true;
9781001 }
979- if (is_for_iter_test [opcode ]) {
980- /* Target the POP_TOP immediately after the END_FOR,
981- * leaving only the iterator on the stack. */
982- int32_t next_inst = target + 1 + INLINE_CACHE_ENTRIES_FOR_ITER ;
983- jump_target = next_inst + inst -> oparg + 1 ;
984- }
9851002 if (unique_target || jump_target != current_jump_target || current_exit_op != exit_op ) {
9861003 make_exit (& buffer [next_spare ], exit_op , jump_target );
9871004 current_exit_op = exit_op ;
0 commit comments