|
| 1 | +import functools |
| 2 | +import time |
| 3 | + |
| 4 | +# TODO (tharitt): later move to env file or something |
| 5 | +ENABLE_BENCHMARK = True |
| 6 | + |
| 7 | +# This function is to be instrumented throughout the targeted function |
| 8 | +def mark(label): |
| 9 | + if _current_mark_func is not None: |
| 10 | + _current_mark_func(label) |
| 11 | + |
| 12 | +# Global hook - this will be re-assigned (points to) |
| 13 | +# the function defined in benchmark wrapper |
| 14 | +_current_mark_func = None |
| 15 | + |
| 16 | +def benchmark(func): |
| 17 | + """A wrapper for code injection for time measurement. |
| 18 | +
|
| 19 | + This wrapper allows users to put a call to mark() |
| 20 | + anywhere inside the wrapped function. The function mark() |
| 21 | + is defined in the global scope to be a placeholder for the targeted |
| 22 | + function to import. This wrapper will make it points to local_mark() defined |
| 23 | + in this function. Therefore, the wrapped function will be able call |
| 24 | + local_mark(). All the context for local_mark() like mark list can be |
| 25 | + hidden from users and thus provide clean interface. |
| 26 | +
|
| 27 | + Parameters |
| 28 | + ---------- |
| 29 | + func : :obj:`callable`, optional |
| 30 | + Function to be decorated. |
| 31 | + """ |
| 32 | + |
| 33 | + # Zero-overhead |
| 34 | + if not ENABLE_BENCHMARK: |
| 35 | + return func |
| 36 | + |
| 37 | + @functools.wraps(func) |
| 38 | + def wrapper(*args, **kwargs): |
| 39 | + marks = [] |
| 40 | + |
| 41 | + # currently this simply record the user-define label and record time |
| 42 | + def local_mark(label): |
| 43 | + marks.append((label, time.perf_counter())) |
| 44 | + |
| 45 | + global _current_mark_func |
| 46 | + _current_mark_func = local_mark |
| 47 | + |
| 48 | + # the mark() called in wrapped function will now call local_mark |
| 49 | + result = func(*args, **kwargs) |
| 50 | + # clean up to original state |
| 51 | + _current_mark_func = None |
| 52 | + |
| 53 | + # TODO (tharitt): maybe changing to saving results to file instead |
| 54 | + if marks: |
| 55 | + prev_label, prev_t = marks[0] |
| 56 | + print(f"[BENCH] {prev_label}: 0.000000s") |
| 57 | + for label, t in marks[1:]: |
| 58 | + print(f"[BENCH] {label}: {t - prev_t:.6f}s since '{prev_label}'") |
| 59 | + prev_label, prev_t = label, t |
| 60 | + return result |
| 61 | + |
| 62 | + return wrapper |
0 commit comments