Skip to content
9 changes: 8 additions & 1 deletion thunder/benchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class InferenceBenchmarkConfig:
mode: str
disable_moe_replacement: bool
profile: bool
enable_thunder_cudagraph: bool


@dataclass
Expand Down Expand Up @@ -281,7 +282,7 @@ def __init__(self, config: InferenceBenchmarkConfig):
def _thunder_jit_options(self) -> dict[str, Any]:
# `nv_enable_linear=True` might fail with distributed run
# ref: https://github.com/NVIDIA/Fuser/issues/4507
res = {}
res = {"transforms": []}
if self.config.enable_nv_linear:
res = {"nv_enable_linear": True, "nv_enable_matmul": True}
if self.config.mode == "thunderjit":
Expand All @@ -291,6 +292,10 @@ def _thunder_jit_options(self) -> dict[str, Any]:
self._mask_transform = SDPAMaskTransform()
res["transforms"] = [self._mask_transform]
res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
if self.config.enable_thunder_cudagraph:
from thunder.transforms.cudagraph import CUDAGraphTransform

res["transforms"].append(CUDAGraphTransform())
return res

def _compile_model(self, model):
Expand Down Expand Up @@ -676,6 +681,7 @@ def parse_args() -> argparse.Namespace:

parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
parser.add_argument("--enable-thunder-cudagraph", action="store_true", help="Pass CUDAGraphTransform to Thunder")

args = parser.parse_args()
return args
Expand Down Expand Up @@ -708,6 +714,7 @@ def main():
enable_nv_linear=args.enable_nv_linear,
disable_moe_replacement=args.disable_moe_replacement,
profile=args.profile,
enable_thunder_cudagraph=args.enable_thunder_cudagraph,
)
benchmark = InferenceBenchmark(config)

Expand Down
Loading