Skip to content

Commit e05054f

Browse files
committed
Add fast-path overflow operations with batched profiler counting
1 parent 69ab1ef commit e05054f

File tree

2 files changed

+98
-7
lines changed

2 files changed

+98
-7
lines changed

rpython/jit/codewriter/genextension.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,10 @@ def generate(self):
157157
self.jitcode._genext_source = "\n".join(allcode)
158158
# Import rop for opnum constants used in type-specialized recording
159159
from rpython.jit.metainterp.resoperation import rop
160+
from rpython.rlib.rarithmetic import ovfcheck
160161
d = {"ConstInt": ConstInt, "ConstPtr": ConstPtr, "ConstFloat": ConstFloat, "JitCode": JitCode, "ChangeFrame": ChangeFrame,
161162
"lltype": lltype, "rstr": rstr, 'llmemory': llmemory, 'OBJECTPTR': OBJECTPTR, 'support': support,
162-
'rop': rop}
163+
'rop': rop, 'ovfcheck': ovfcheck}
163164
d.update(self.globals)
164165
source = py.code.Source(self.jitcode._genext_source)
165166
exec source.compile() in d
@@ -2137,6 +2138,62 @@ def emit_unspecialized_goto_if_not_int_ne(self):
21372138
def emit_unspecialized_goto_if_not_int_eq(self):
    """Emit the fast-path comparison code for the INT_EQ conditional jump."""
    emitted = self._emit_goto_if_not_int_comparison_fast("INT_EQ", "==")
    return emitted
21392140

2141+
def _emit_int_ovf_fast(self, rop_name, py_op):
    """Generate source lines for a fast-path integer op with overflow check.

    rop_name is the resoperation name recorded in the trace (e.g.
    "INT_ADD_OVF"); py_op is the Python operator text used to compute the
    result in the generated code ("+", "-" or "*").  Returns the list of
    generated source lines.
    """
    lines = []
    # insn layout here: (opcode, jump-label, arg0, arg1, <unused>, result)
    _, label, arg0, arg1, _, result = self.insn

    target_pc = self.get_target_pc(label)

    # NOTE(review): presumably _emit_n_ary_if emits the "arguments not
    # suitable for the fast path" test whose body is the two lines
    # appended below (re-dispatch to a specialized insn) — confirm.
    self._emit_n_ary_if([arg0, arg1], lines)
    specializer = self.work_list.specialize_insn(
        self.insn, self.constant_registers.union({arg0, arg1}), self.orig_pc)
    lines.append(" pc = %d" % (specializer.get_pc(),))
    lines.append(" continue")

    # Fast-path: compute with overflow check, record directly, skip heapcache
    self._emit_sync_registers(lines)
    box0 = self._get_as_box_after_sync(arg0)
    box1 = self._get_as_box_after_sync(arg1)
    lines.append("_v0 = %s" % self._get_as_unboxed_after_sync(arg0))
    lines.append("_v1 = %s" % self._get_as_unboxed_after_sync(arg1))

    # Compute with ovfcheck; on overflow, raise the metainterp's ovf_flag
    # and use 0 as a placeholder result (handle_possible_overflow_error
    # below reacts to the flag).
    lines.append("self.metainterp.ovf_flag = False")
    lines.append("try:")
    lines.append(" _res = ovfcheck(_v0 %s _v1)" % py_op)
    lines.append("except OverflowError:")
    lines.append(" self.metainterp.ovf_flag = True")
    lines.append(" _res = 0")

    lines.append("# fast-path: record overflow op directly, skip heapcache")
    lines.append("_op = self.metainterp.history.record2_int(rop.%s, %s, %s, _res)" % (
        rop_name, box0, box1))
    lines.append("self.registers_i[%d] = _op" % result.index)
    lines.append("i%d = _res" % result.index)

    # After handling a possible overflow, self.pc is either the jump
    # target (overflow taken) or the fall-through pc; emit a two-way
    # dispatch and specialize both continuations.  The result register
    # is no longer a known constant in either continuation.
    lines.append("self.handle_possible_overflow_error(%d, %d, _op)" % (target_pc, self.orig_pc))
    lines.append("pc = self.pc")
    lines.append("if pc == %s:" % (target_pc,))
    specializer = self.work_list.specialize_pc(
        self.constant_registers - {result}, target_pc)
    lines.append(" pc = %s" % (specializer.spec_pc,))
    lines.append("else:")
    next_pc = self.work_list.pc_to_nextpc[self.orig_pc]
    specializer = self.work_list.specialize_pc(
        self.constant_registers - {result}, next_pc)
    lines.append(" assert self.pc == %s" % (specializer.orig_pc,))
    lines.append(" pc = %s" % (specializer.spec_pc,))
    lines.append("continue")
    return lines
2187+
2188+
def emit_unspecialized_int_add_jump_if_ovf(self):
    """Emit fast-path code for INT_ADD_OVF (addition with overflow jump)."""
    emitted = self._emit_int_ovf_fast("INT_ADD_OVF", "+")
    return emitted
2190+
2191+
def emit_unspecialized_int_sub_jump_if_ovf(self):
    """Emit fast-path code for INT_SUB_OVF (subtraction with overflow jump)."""
    emitted = self._emit_int_ovf_fast("INT_SUB_OVF", "-")
    return emitted
2193+
2194+
def emit_unspecialized_int_mul_jump_if_ovf(self):
    """Emit fast-path code for INT_MUL_OVF (multiplication with overflow jump)."""
    emitted = self._emit_int_ovf_fast("INT_MUL_OVF", "*")
    return emitted
2196+
21402197
def emit_unspecialized_switch(self):
21412198
lines = []
21422199
arg0, descr = self._get_args()

rpython/jit/metainterp/pyjitpl.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2545,6 +2545,12 @@ def __init__(self, staticdata, jitdriver_sd, force_finish_trace=False):
25452545

25462546
self.box_names_memo = {}
25472547

2548+
# Batched profiler counting for fast-path operations
2549+
# Instead of calling profiler.count_ops() for every operation,
2550+
# we batch counts and flush them at strategic points (guards, end of trace)
2551+
self._batched_ops_count = 0
2552+
self._batched_recorded_ops_count = 0
2553+
25482554
self.aborted_tracing_jitdriver = None
25492555
self.aborted_tracing_greenkey = None
25502556

@@ -2554,6 +2560,30 @@ def __init__(self, staticdata, jitdriver_sd, force_finish_trace=False):
25542560
self.force_finish_trace = force_finish_trace
25552561
self.trace_length_at_last_tco = -1
25562562

2563+
@always_inline
def batch_op_count(self):
    """Accumulate one operation into the batched profiler counters.

    Fast-path recording methods call this instead of
    profiler.count_ops() so the profiler is touched only when the
    batches are flushed (before guards and at end of trace).

    NOTE(review): this bumps both the OPS and RECORDED_OPS batches,
    whereas the call it replaces passed only Counters.RECORDED_OPS —
    confirm the OPS count is intended here.
    """
    self._batched_ops_count = self._batched_ops_count + 1
    self._batched_recorded_ops_count = self._batched_recorded_ops_count + 1
2574+
def flush_batched_counts(self):
    """Push any pending batched operation counts into the profiler.

    Called before emitting a guard and at end of tracing so the
    profiler's totals stay accurate despite the batching done by
    batch_op_count().  Both batch counters are reset to zero.
    """
    profiler = self.staticdata.profiler
    pending_ops = self._batched_ops_count
    if pending_ops > 0:
        profiler.count(Counters.OPS, pending_ops)
        self._batched_ops_count = 0
    pending_recorded = self._batched_recorded_ops_count
    if pending_recorded > 0:
        profiler.count(Counters.RECORDED_OPS, pending_recorded)
        self._batched_recorded_ops_count = 0
2586+
25572587
def retrace_needed(self, trace, exported_state):
25582588
self.partial_trace = trace
25592589
self.retracing_from = self.potential_retrace_position
@@ -2699,6 +2729,8 @@ def check_recursion_invariant(self):
26992729
raise AssertionError
27002730

27012731
def generate_guard(self, opnum, box=None, extraarg=None, resumepc=-1):
2732+
# Flush batched counts before guard to ensure accurate profiling
2733+
self.flush_batched_counts()
27022734
if isinstance(box, Const): # no need for a guard
27032735
return
27042736
if opnum == rop.GUARD_EXCEPTION:
@@ -2848,8 +2880,8 @@ def _record_helper(self, opnum, resvalue, descr, *argboxes):
28482880
def _record_int_binop(self, opnum, resvalue, b1, b2):
28492881
if not we_are_translated():
28502882
PyjitplCounters._record_int_binop_calls += 1
2851-
profiler = self.staticdata.profiler
2852-
profiler.count_ops(opnum, Counters.RECORDED_OPS)
2883+
# Use batched counting instead of calling profiler directly
2884+
self.batch_op_count()
28532885
if self.framestack:
28542886
self.framestack[-1].jitcode.traced_operations += 1
28552887
op = self.history.record2_int(opnum, b1, b2, resvalue)
@@ -2859,8 +2891,8 @@ def _record_int_binop(self, opnum, resvalue, b1, b2):
28592891
def _record_int_unop(self, opnum, resvalue, b1):
28602892
if not we_are_translated():
28612893
PyjitplCounters._record_int_binop_calls += 1
2862-
profiler = self.staticdata.profiler
2863-
profiler.count_ops(opnum, Counters.RECORDED_OPS)
2894+
# Use batched counting instead of calling profiler directly
2895+
self.batch_op_count()
28642896
if self.framestack:
28652897
self.framestack[-1].jitcode.traced_operations += 1
28662898
op = self.history.record1_int(opnum, b1, resvalue)
@@ -2870,8 +2902,8 @@ def _record_int_unop(self, opnum, resvalue, b1):
28702902
def _record_float_binop(self, opnum, resvalue, b1, b2):
28712903
if not we_are_translated():
28722904
PyjitplCounters._record_int_binop_calls += 1
2873-
profiler = self.staticdata.profiler
2874-
profiler.count_ops(opnum, Counters.RECORDED_OPS)
2905+
# Use batched counting instead of calling profiler directly
2906+
self.batch_op_count()
28752907
if self.framestack:
28762908
self.framestack[-1].jitcode.traced_operations += 1
28772909
op = self.history.record2_float(opnum, b1, b2, resvalue)
@@ -3060,6 +3092,7 @@ def compile_and_run_once(self, jitdriver_sd, *args):
30603092
original_boxes = self.initialize_original_boxes(jitdriver_sd, *args)
30613093
return self._compile_and_run_once(original_boxes)
30623094
finally:
3095+
self.flush_batched_counts() # Flush before end_tracing
30633096
self.staticdata.profiler.end_tracing()
30643097
debug_stop('jit-tracing')
30653098

@@ -3101,6 +3134,7 @@ def handle_guard_failure(self, resumedescr, deadframe):
31013134
self.run_blackhole_interp_to_cancel_tracing(stb)
31023135
finally:
31033136
self.resumekey_original_loop_token = None
3137+
self.flush_batched_counts() # Flush before end_tracing
31043138
self.staticdata.profiler.end_tracing()
31053139
debug_stop('jit-tracing')
31063140

0 commit comments

Comments
 (0)