Skip to content

Performance tracing and tuning code #941

@wangkuiyi

Description

@wangkuiyi

Hello, team.

We created a performance-tracing utility, shown below, that prints how long each function in our MLX application takes, helping us identify which components need improvement.

image

Do you think this code is worth adding to the mlx_example repository? If so, I'd be glad to submit a pull request. Thanks.

The file time_mlx.py:

import functools
import time
from typing import Callable, Dict, List, Optional

import mlx.core as mx
import mlx.nn
from tabulate import tabulate


class _Record:
    def __init__(self, msg: str, indentation: int) -> None:
        self.msg = msg
        self.indentation = indentation
        self.timing: List[float] = []
        self.parent: Optional[_Record] = None


class Ledger:
    """Collects the timing records produced by the ``function`` decorator.

    ``records`` keeps records in first-seen order (used when printing), while
    ``records_dict`` maps each record's nesting key to the same ``_Record``.
    """

    def __init__(self) -> None:
        self.records: List[_Record] = []
        self.records_dict: Dict[str, _Record] = {}
        # Starts at -1 so the outermost timed call, after incrementing, is 0.
        self.indentation = -1
        # Concatenated msgs of the timed calls currently in progress.
        self.key = ""

    def reset(self):
        """Discard all records and return to the freshly-constructed state."""
        self.records = []
        self.records_dict = {}
        self.indentation = -1
        self.key = ""

    def _table_rows(self) -> List[list]:
        """Return one ``[label, ms per run, ms total, % of parent]`` row per record.

        A record with no completed run (the wrapped call raised, or is still
        in progress) gets ``None`` as its per-run latency instead of raising
        ZeroDivisionError. Likewise, the ratio is ``None`` when the parent has
        no accumulated time yet. Top-level records report 100.
        """
        rows = []
        for r in self.records:
            total = sum(r.timing)
            per_run = total / len(r.timing) if r.timing else None
            if r.parent is None:
                ratio = 100
            else:
                parent_total = sum(r.parent.timing)
                ratio = total / parent_total * 100 if parent_total else None
            rows.append(["-" * r.indentation + "> " + r.msg, per_run, total, ratio])
        return rows

    def print_table(self):
        """Print every record as a psql-style table, indented by nesting depth."""
        print(
            tabulate(
                self._table_rows(),
                headers=[
                    "function",
                    "latency per run (ms)",
                    "latency in total (ms)",
                    "Latency Ratio (%)",
                ],
                tablefmt="psql",
            )
        )

    def print_summary(self):
        """Print each record's label and total accumulated time in ms."""
        for r in self.records:
            print(f"{r.msg} {sum(r.timing):.3f} (ms)")


# Module-wide ledger that every call wrapped by ``function`` records into;
# call ``ledger.reset()`` to start a fresh measurement.
ledger = Ledger()


def function(msg: str):
    """Decorator that times the execution of a function that calls MLX.

    MLX evaluates lazily, so the wrapper forces evaluation of every positional
    and keyword argument *before* starting the clock, and of the return value
    *before* stopping it. The measured time therefore covers the computation
    triggered by the wrapped call. Each completed call appends its elapsed
    time in milliseconds to a record in the module-level ``ledger``. The
    record's key is the concatenation of the ``msg`` of every enclosing timed
    call followed by this one.

    If the wrapped function raises, no timing is recorded, but the ledger's
    nesting state is still restored so later measurements stay correct.

    Args:
        msg: Label shown for this function in the ledger's reports.
    """

    def eval_arg(arg):
        # nn.Module subclasses dict, so it must be checked before the generic
        # container test or this branch would never be taken.
        if isinstance(arg, mlx.nn.Module):
            mx.eval(arg.parameters())
        elif isinstance(arg, (mx.array, list, tuple, dict)):
            mx.eval(arg)
        return arg

    def decorator(g: Callable):
        @functools.wraps(g)
        def g_wrapped(*args, **kwargs):
            # Make sure inputs are materialized so their cost is not charged
            # to g.
            for arg in args:
                eval_arg(arg)
            for value in kwargs.values():
                eval_arg(value)

            prev_key = ledger.key
            ledger.indentation += 1
            # NOTE(review): keys are concatenated without a separator, so
            # different nestings can collide (e.g. "ab"+"c" vs "a"+"bc").
            # Kept as-is because callers index records_dict by these keys.
            ledger.key = prev_key + msg
            record = ledger.records_dict.get(ledger.key)
            if record is None:
                record = _Record(msg, ledger.indentation)
                record.parent = ledger.records_dict[prev_key] if prev_key else None
                ledger.records.append(record)
                ledger.records_dict[ledger.key] = record

            try:
                tic = time.perf_counter()
                result = g(*args, **kwargs)
                # Force lazy outputs so their computation is inside the timing.
                eval_arg(result)
                record.timing.append(1e3 * (time.perf_counter() - tic))
            finally:
                # Always unwind the nesting state, even if g raised.
                ledger.indentation -= 1
                ledger.key = prev_key

            return result

        return g_wrapped

    return decorator

The unit test below also serves as a usage example.

import math

import mlx
import mlx.core as mx

from . import time_mlx


@time_mlx.function("two_projs")
def two_projs():
    """Build two bias-free DIM->DIM linear layers and run an input through both.

    Both phases are timed as well, so the ledger records ``create_projs`` and
    ``run_projs`` as children of ``two_projs``.
    """
    DIM = 1024

    @time_mlx.function("create_projs")
    def create_projs():
        p1 = mlx.nn.Linear(DIM, DIM, bias=False)
        p2 = mlx.nn.Linear(DIM, DIM, bias=False)
        return p1, p2

    @time_mlx.function("run_projs")
    def run_projs(p1, p2):
        # The last axis must equal the layers' input size, so derive it from
        # DIM instead of repeating the literal.
        x = mx.ones((1024, DIM))
        return p2(p1(x))

    p1, p2 = create_projs()
    run_projs(p1, p2)


def test_time_mlx_two_projs():
    """Nested timings should add up to just under the enclosing timing."""
    for _ in range(10):
        tracker = time_mlx.ledger
        tracker.reset()
        two_projs()

        by_key = tracker.records_dict
        children_ms = (
            by_key["two_projscreate_projs"].timing[0]
            + by_key["two_projsrun_projs"].timing[0]
        )
        parent_ms = by_key["two_projs"].timing[0]

        # The children run inside the parent, so together they are faster...
        assert children_ms < parent_ms
        # ...but the remaining gap (untimed work in the parent) is under 1 ms.
        assert math.fabs(children_ms - parent_ms) < 1  # ms

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions