|  | 
| 1 | 1 | import functools | 
| 2 | 2 | import time | 
|  | 3 | +from typing import Callable, Optional | 
|  | 4 | +from mpi4py import MPI | 
|  | 5 | + | 
| 3 | 6 | 
 | 
# TODO (tharitt): later move to env file or something
ENABLE_BENCHMARK = True


def mark(label):
    """Record a labelled time point inside the innermost benchmarked region.

    The call is forwarded to the hook on top of ``_mark_func_stack``, i.e. the
    hook registered by the most recently entered benchmarked function that has
    not returned yet.

    Parameters
    ----------
    label : :obj:`str`
        Name attached to the recorded time point.

    Raises
    ------
    RuntimeError
        If benchmarking is enabled and no benchmarked region is active.

    """
    # With benchmarking switched off nothing ever registers a hook, so an
    # instrumented function must keep running instead of raising below.
    if not ENABLE_BENCHMARK:
        return
    if not _mark_func_stack:
        raise RuntimeError("mark() called outside of a benchmarked region")
    _mark_func_stack[-1](label)


# Stack of active mark hooks, innermost last; a stack (rather than a single
# global hook) is what allows benchmarked regions to nest.
_mark_func_stack = []
| 17 | 20 | 
 | 
| 18 | 21 | 
 | 
def benchmark(func: Optional[Callable] = None,
              description="",
              save_file=False,
              file_path='benchmark.log'
              ):
    """A wrapper for code injection for time measurement.

    This wrapper measures the start-to-end time of the wrapped function. It
    can be applied bare (``@benchmark``) or with arguments
    (``@benchmark(description=...)``).
    It also allows users to put a call to mark() anywhere inside the wrapped
    function for fine-grained timing. The wrapper defines a local_mark() and
    pushes it onto the _mark_func_stack for isolation in case of nested calls.
    The user-facing mark() always calls the function at the top of the
    _mark_func_stack.

    Parameters
    ----------
    func : :obj:`callable`, optional
        Function to be decorated. Defaults to ``None``.
    description : :obj:`str`, optional
        Description for the output text; the function name is used when
        empty. Defaults to ``''``.
    save_file : :obj:`bool`, optional
        Flag for appending the result to a file on disk. Otherwise, the
        result is printed to stdout. Defaults to ``False``
    file_path : :obj:`str`, optional
        File path for saving the output. Defaults to ``benchmark.log``

    Returns
    -------
    :obj:`callable`
        The wrapped function when ``func`` is given, otherwise a decorator
        to be applied to the function.

    """

    # Zero-overhead: hand the function back untouched. In the parameterized
    # form there is no function yet, so return an identity decorator rather
    # than None (which would fail at decoration time).
    if not ENABLE_BENCHMARK:
        if func is not None:
            return func
        return lambda wrapped: wrapped

    def decorator(func):
        # wraps() belongs on the wrapper so the decorated function keeps the
        # name and docstring of the function it replaces.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            rank = MPI.COMM_WORLD.Get_rank()

            # Each invocation gets its own list; nested benchmarked calls
            # therefore never write into each other's marks.
            marks = []

            def local_mark(label):
                marks.append((label, time.perf_counter()))

            _mark_func_stack.append(local_mark)
            try:
                start_time = time.perf_counter()
                # the mark() called in wrapped function will now call local_mark
                result = func(*args, **kwargs)
                end_time = time.perf_counter()
            finally:
                # Always unregister the hook, even if func raised; otherwise
                # a stale hook would stay on the stack and skew nesting.
                _mark_func_stack.pop()

            # start-to-end time
            elapsed = end_time - start_time

            # TODO (tharitt): Both MPI + NCCL collective calls have implicit synchronization inside the stream
            # So, output only from rank=0 should suffice. We can add per-rank output later on if makes sense.
            if rank == 0:
                # Hooks still on the stack belong to enclosing benchmarked
                # calls, so the stack size is this call's nesting depth.
                level = len(_mark_func_stack)
                indent = '---' * level
                name = description or getattr(func, "__name__", repr(func))
                output = [f"{indent}{name} {elapsed:.6f} seconds (rank = {rank})\n"]
                # One line per consecutive pair of marks.
                for (prev_label, prev_t), (label, t) in zip(marks, marks[1:]):
                    output.append(f"{indent}[{prev_label} ---> {label}]: {t - prev_t:.6f}s \n")

                report = "".join(output)
                if save_file:
                    with open(file_path, "a") as f:
                        f.write(report)
                else:
                    print(report)
            return result
        return wrapper

    if func is not None:
        return decorator(func)

    return decorator
0 commit comments