
[inference] High thunder.jit/thunderfx overhead for small module #2775

@kshitij12345

Description


In an inference scenario, when thunderfx/thunder.jit is applied to a submodule of a model, thunder's overhead is high enough that performance is worse than eager.

In the following benchmark script, we apply thunder.jit, thunderfx (also benchmarking the generated nvFuser FusionDefinition), and torch.compile to RMSNorm and compare their performance against eager.

NOTE: In this case, thunder creates a single FusionDefinition for the entire program, which nvFuser executes.

import torch
from torch.utils.benchmark import Timer
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
from thunder.dynamo import thunderfx
import thunder.examine

model_name = "openai/gpt-oss-20b"

measurements = []

# set_num_threads to 1 to avoid recompilations due to global state changes
# W1127 03:17:30.034000 16176 torch/_dynamo/convert_frame.py:1551] [0/8]    last reason: 0/5: GLOBAL_STATE changed: num_threads 
# W1127 03:17:30.034000 16176 torch/_dynamo/convert_frame.py:1551] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
torch.set_num_threads(1)

torch._dynamo.config.recompile_limit = 64

for shape in (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048):
    with torch.device("cuda"):
        rms_norm = GptOssRMSNorm(2880, eps=1e-05)

    rms_norm.requires_grad_(False)

    x = torch.randn(1, shape, 2880, device="cuda", dtype=torch.bfloat16)
    print(f"Shape: {x.shape}")
    tf_rms_norm = thunderfx(rms_norm)
    tc_rms_norm = torch.compile(rms_norm)
    tf_rms_norm(x)  # warm up: compile and run once so last_traces and fusions are populated
    exec_trc = tf_rms_norm.last_traces[0]
    fusions = thunder.examine.get_fusions(exec_trc)
    assert len(fusions) == 1
    name, fd_wrapper = fusions[0]
    fd = fd_wrapper.last_used

    tj_rms_norm = thunder.jit(rms_norm)

    args = [x, rms_norm.weight]
    eager_time = Timer(stmt="rms_norm(x)", globals=globals(), label="RMSNorm", description=f"Eager", sub_label=f"{x.shape}").blocked_autorange()
    thunder_time = Timer(stmt="tf_rms_norm(x)", globals=globals(), label="RMSNorm", description=f"ThunderFX", sub_label=f"{x.shape}").blocked_autorange()
    thunder_jit_time = Timer(stmt="tj_rms_norm(x)", globals=globals(), label="RMSNorm", description=f"ThunderJIT", sub_label=f"{x.shape}").blocked_autorange()
    torch_compile_time = Timer(stmt="tc_rms_norm(x)", globals=globals(), label="RMSNorm", description=f"TorchCompile", sub_label=f"{x.shape}").blocked_autorange()
    fd_wrapper_time = Timer(stmt="fd_wrapper(*args)", globals=globals(), label="RMSNorm", description=f"FDWrapper", sub_label=f"{x.shape}").blocked_autorange()
    fd_time = Timer(stmt="fd.execute(args)", globals=globals(), label="RMSNorm", description=f"FD", sub_label=f"{x.shape}").blocked_autorange()
    measurements.extend([eager_time, thunder_time, thunder_jit_time, torch_compile_time, fd_wrapper_time, fd_time])
    
print(torch.utils.benchmark.Compare(measurements))

Results:

NOTE:

  1. FDWrapper corresponds to calling the fd_wrapper entrypoint obtained via thunder.examine.get_fusions in the script above.
  2. FD corresponds to calling FusionDefinition.execute directly.
[-------------------------------------------------- RMSNorm --------------------------------------------------]
                                   |  Eager  |  ThunderFX  |  ThunderJIT  |  TorchCompile  |  FDWrapper  |   FD
1 threads: ----------------------------------------------------------------------------------------------------
      torch.Size([1, 1, 2880])     |   34.0  |     59.3    |     50.1     |      16.3      |     7.8     |  5.0
      torch.Size([1, 2, 2880])     |   31.1  |     66.2    |     50.2     |      19.6      |     8.1     |  5.0
      torch.Size([1, 4, 2880])     |   31.2  |     67.0    |     50.6     |      16.8      |     8.0     |  5.0
      torch.Size([1, 8, 2880])     |   31.6  |     65.9    |     49.9     |      16.9      |     8.0     |  5.1
      torch.Size([1, 16, 2880])    |   31.1  |     66.1    |     49.7     |      16.8      |     8.1     |  5.0
      torch.Size([1, 32, 2880])    |   31.2  |     65.3    |     50.2     |      16.9      |     8.1     |  5.0
      torch.Size([1, 64, 2880])    |   30.8  |     65.1    |     50.5     |      16.8      |     8.1     |  5.0
      torch.Size([1, 128, 2880])   |   31.6  |     65.2    |     49.6     |      17.1      |     7.9     |  5.0
      torch.Size([1, 256, 2880])   |   30.6  |     65.1    |     49.2     |      16.8      |     8.0     |  5.0
      torch.Size([1, 512, 2880])   |   30.9  |     65.6    |     49.8     |      16.7      |     8.1     |  5.1
      torch.Size([1, 1024, 2880])  |   41.4  |     65.3    |     50.0     |      16.9      |     7.9     |  5.0
      torch.Size([1, 2048, 2880])  |   75.4  |     64.9    |     50.1     |      16.9      |     8.1     |  7.7

Times are in microseconds (us).

Generated thunder execution trace (exec_trc from the script above):

# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(hidden_states, t_weight):
  # hidden_states: "cuda:0 bf16[1, 2048, 2880]"
  # t_weight: "cuda:0 f32[2880]"
  [t16] = nvFusion0(hidden_states, t_weight)
    # t0 = prims.convert_element_type(hidden_states, dtypes.float32)  # t0: "cuda:0 f32[1, 2048, 2880]"
    # t1 = prims.pow(t0, 2.0)  # t1: "cuda:0 f32[1, 2048, 2880]"
    # t2 = prims.shallow_copy(t1)  # t2: "cuda:0 f32[1, 2048, 2880]"
    # t3 = prims.sum(t2, (2,))  # t3: "cuda:0 f32[1, 2048]"
    # t4 = prims.broadcast_in_dim(t3, [1, 2048, 1], [0, 1])  # t4: "cuda:0 f32[1, 2048, 1]"
    # t5 = prims.div(t4, 2880.0)  # t5: "cuda:0 f32[1, 2048, 1]"
    # variance = prims.shallow_copy(t5)  # variance: "cuda:0 f32[1, 2048, 1]"
    # t8 = prims.add(variance, 1e-05)  # t8: "cuda:0 f32[1, 2048, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[1, 2048, 1]"
    # t10 = prims.broadcast_in_dim(t9, (1, 2048, 2880), (0, 1, 2))  # t10: "cuda:0 f32[1, 2048, 2880]"
    # t11 = prims.mul(t0, t10)  # t11: "cuda:0 f32[1, 2048, 2880]"
    # t14 = prims.broadcast_in_dim(t_weight, (1, 2048, 2880), (2,))  # t14: "cuda:0 f32[1, 2048, 2880]"
    # t15 = prims.mul(t14, t11)  # t15: "cuda:0 f32[1, 2048, 2880]"
    # t16 = prims.convert_element_type(t15, dtypes.bfloat16)  # t16: "cuda:0 bf16[1, 2048, 2880]"
  return (t16,)

The results show that thunder's overhead dominates at small shapes, where the nvFuser fusion itself is fast enough that, with minimal overhead, it could have delivered a performance benefit over eager.
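To make the magnitude concrete, here is a back-of-the-envelope calculation from the table above, taking the smallest shape torch.Size([1, 1, 2880]); the numbers are specific to this run:

# Rough overhead estimate at torch.Size([1, 1, 2880]), times in microseconds,
# taken from the table above (this run only).
eager = 34.0      # eager RMSNorm
thunderfx = 59.3  # thunderfx-compiled RMSNorm
fd = 5.0          # raw FusionDefinition.execute

host_overhead = thunderfx - fd  # ~54 us spent per call outside the fusion
headroom = eager - fd           # overhead would need to stay under ~29 us to beat eager
print(f"host overhead ~{host_overhead:.1f} us, headroom vs eager ~{headroom:.1f} us")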

For reference, these are the shapes observed when running openai/gpt-oss-20b with SGLang and querying the server with python3 -m sglang.bench_serving --backend sglang --num-prompt 10, which uses the ShareGPT dataset for prompts.

Shapes (at the prefill stage) - for decode, SGLang applies CUDAGraph. A small benchmark sketch at exactly these shapes follows the list.

[1, 15, 2880]
[1, 1369, 2880]
[1, 568, 2880]
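As a follow-up check, the same Timer setup can be pointed directly at these prefill shapes. Below is a minimal sketch reusing the setup from the script above (GptOssRMSNorm, thunderfx, blocked_autorange); the sequence lengths 15, 1369, and 568 come from the SGLang run, and only eager and ThunderFX are compared for brevity:

import torch
from torch.utils.benchmark import Timer, Compare
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm
from thunder.dynamo import thunderfx

torch.set_num_threads(1)

measurements = []
for seq_len in (15, 1369, 568):  # prefill sequence lengths seen in the SGLang run
    with torch.device("cuda"):
        rms_norm = GptOssRMSNorm(2880, eps=1e-05)
    rms_norm.requires_grad_(False)

    x = torch.randn(1, seq_len, 2880, device="cuda", dtype=torch.bfloat16)
    tf_rms_norm = thunderfx(rms_norm)
    tf_rms_norm(x)  # warm up: compile once outside the timed region

    for desc, stmt in (("Eager", "rms_norm(x)"), ("ThunderFX", "tf_rms_norm(x)")):
        measurements.append(
            Timer(stmt=stmt, globals=globals(), label="RMSNorm (prefill shapes)",
                  description=desc, sub_label=f"{x.shape}").blocked_autorange()
        )

print(Compare(measurements))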

Related: #2556
