Description
In an inference scenario, when thunderfx/thunder.jit is applied to a submodule of a model, Thunder has high per-call overhead, leading to worse performance than eager execution.
In the following benchmark script, we apply thunder.jit, thunderfx, and torch.compile to RMSNorm (and also benchmark the generated nvFuser FusionDefinition directly) and compare their performance against eager.
NOTE: In this case, Thunder creates a single FusionDefinition for the entire program, which nvFuser executes.
```python
import torch
from torch.utils.benchmark import Timer
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssRMSNorm

from thunder.dynamo import thunderfx
import thunder.examine

model_name = "openai/gpt-oss-20b"

measurements = []

# set_num_threads to 1 to avoid recompilations due to global state changes
# W1127 03:17:30.034000 16176 torch/_dynamo/convert_frame.py:1551] [0/8] last reason: 0/5: GLOBAL_STATE changed: num_threads
# W1127 03:17:30.034000 16176 torch/_dynamo/convert_frame.py:1551] [0/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
torch.set_num_threads(1)
torch._dynamo.config.recompile_limit = 64

for shape in (1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048):
    with torch.device("cuda"):
        rms_norm = GptOssRMSNorm(2880, eps=1e-05)
        rms_norm.requires_grad_(False)

    x = torch.randn(1, shape, 2880, device="cuda", dtype=torch.bfloat16)
    print(f"Shape: {x.shape}")

    tf_rms_norm = thunderfx(rms_norm)
    tc_rms_norm = torch.compile(rms_norm)

    tf_rms_norm(x)

    exec_trc = tf_rms_norm.last_traces[0]
    fusions = thunder.examine.get_fusions(exec_trc)
    assert len(fusions) == 1
    name, fd_wrapper = fusions[0]
    fd = fd_wrapper.last_used

    tj_rms_norm = thunder.jit(rms_norm)

    args = [x, rms_norm.weight]

    eager_time = Timer(stmt="rms_norm(x)", globals=globals(), label="RMSNorm", description="Eager", sub_label=f"{x.shape}").blocked_autorange()
    thunder_time = Timer(stmt="tf_rms_norm(x)", globals=globals(), label="RMSNorm", description="ThunderFX", sub_label=f"{x.shape}").blocked_autorange()
    thunder_jit_time = Timer(stmt="tj_rms_norm(x)", globals=globals(), label="RMSNorm", description="ThunderJIT", sub_label=f"{x.shape}").blocked_autorange()
    torch_compile_time = Timer(stmt="tc_rms_norm(x)", globals=globals(), label="RMSNorm", description="TorchCompile", sub_label=f"{x.shape}").blocked_autorange()
    fd_wrapper_time = Timer(stmt="fd_wrapper(*args)", globals=globals(), label="RMSNorm", description="FDWrapper", sub_label=f"{x.shape}").blocked_autorange()
    fd_time = Timer(stmt="fd.execute(args)", globals=globals(), label="RMSNorm", description="FD", sub_label=f"{x.shape}").blocked_autorange()

    measurements.extend([eager_time, thunder_time, thunder_jit_time, torch_compile_time, fd_wrapper_time, fd_time])

print(torch.utils.benchmark.Compare(measurements))
```

Results:
NOTE:
- FDWrapper corresponds to this entrypoint.
- FD corresponds to calling FusionDefinition.execute directly.
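
Concretely, the two call paths behind the FDWrapper and FD columns look like this (a minimal sketch reusing the `x`, `rms_norm`, `fd_wrapper`, and `fd` objects produced by the benchmark script; not additional code from the original report):

```python
args = [x, rms_norm.weight]          # same inputs the benchmark passes to both paths
out_via_wrapper = fd_wrapper(*args)  # FDWrapper: Thunder's wrapper around the nvFuser FusionDefinition
out_direct = fd.execute(args)        # FD: nvFuser's FusionDefinition.execute, called directly
```

Both run the same nvFuser fusion; they differ only in how much host-side machinery sits in front of it.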
RMSNorm (1 thread):

| Shape | Eager | ThunderFX | ThunderJIT | TorchCompile | FDWrapper | FD |
|---|---:|---:|---:|---:|---:|---:|
| torch.Size([1, 1, 2880]) | 34.0 | 59.3 | 50.1 | 16.3 | 7.8 | 5.0 |
| torch.Size([1, 2, 2880]) | 31.1 | 66.2 | 50.2 | 19.6 | 8.1 | 5.0 |
| torch.Size([1, 4, 2880]) | 31.2 | 67.0 | 50.6 | 16.8 | 8.0 | 5.0 |
| torch.Size([1, 8, 2880]) | 31.6 | 65.9 | 49.9 | 16.9 | 8.0 | 5.1 |
| torch.Size([1, 16, 2880]) | 31.1 | 66.1 | 49.7 | 16.8 | 8.1 | 5.0 |
| torch.Size([1, 32, 2880]) | 31.2 | 65.3 | 50.2 | 16.9 | 8.1 | 5.0 |
| torch.Size([1, 64, 2880]) | 30.8 | 65.1 | 50.5 | 16.8 | 8.1 | 5.0 |
| torch.Size([1, 128, 2880]) | 31.6 | 65.2 | 49.6 | 17.1 | 7.9 | 5.0 |
| torch.Size([1, 256, 2880]) | 30.6 | 65.1 | 49.2 | 16.8 | 8.0 | 5.0 |
| torch.Size([1, 512, 2880]) | 30.9 | 65.6 | 49.8 | 16.7 | 8.1 | 5.1 |
| torch.Size([1, 1024, 2880]) | 41.4 | 65.3 | 50.0 | 16.9 | 7.9 | 5.0 |
| torch.Size([1, 2048, 2880]) | 75.4 | 64.9 | 50.1 | 16.9 | 8.1 | 7.7 |

Times are in microseconds (us).
Generated thunder trace:

```python
# Constructed by Unwrap the actual return value
import torch
from thunder.executors.torchex import no_autocast

@torch.no_grad()
@no_autocast
def computation(hidden_states, t_weight):
  # hidden_states: "cuda:0 bf16[1, 2048, 2880]"
  # t_weight: "cuda:0 f32[2880]"
  [t16] = nvFusion0(hidden_states, t_weight)
    # t0 = prims.convert_element_type(hidden_states, dtypes.float32)  # t0: "cuda:0 f32[1, 2048, 2880]"
    # t1 = prims.pow(t0, 2.0)  # t1: "cuda:0 f32[1, 2048, 2880]"
    # t2 = prims.shallow_copy(t1)  # t2: "cuda:0 f32[1, 2048, 2880]"
    # t3 = prims.sum(t2, (2,))  # t3: "cuda:0 f32[1, 2048]"
    # t4 = prims.broadcast_in_dim(t3, [1, 2048, 1], [0, 1])  # t4: "cuda:0 f32[1, 2048, 1]"
    # t5 = prims.div(t4, 2880.0)  # t5: "cuda:0 f32[1, 2048, 1]"
    # variance = prims.shallow_copy(t5)  # variance: "cuda:0 f32[1, 2048, 1]"
    # t8 = prims.add(variance, 1e-05)  # t8: "cuda:0 f32[1, 2048, 1]"
    # t9 = prims.rsqrt(t8)  # t9: "cuda:0 f32[1, 2048, 1]"
    # t10 = prims.broadcast_in_dim(t9, (1, 2048, 2880), (0, 1, 2))  # t10: "cuda:0 f32[1, 2048, 2880]"
    # t11 = prims.mul(t0, t10)  # t11: "cuda:0 f32[1, 2048, 2880]"
    # t14 = prims.broadcast_in_dim(t_weight, (1, 2048, 2880), (2,))  # t14: "cuda:0 f32[1, 2048, 2880]"
    # t15 = prims.mul(t14, t11)  # t15: "cuda:0 f32[1, 2048, 2880]"
    # t16 = prims.convert_element_type(t15, dtypes.bfloat16)  # t16: "cuda:0 bf16[1, 2048, 2880]"
  return (t16,)
```

From the results, it can be observed that Thunder's per-call overhead dominates at small shapes, where, with minimal overhead, nvFuser could have delivered a performance benefit (compare the FDWrapper and FD columns against Eager).
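
One rough way to quantify the host-side overhead per call from the measurements already collected (a sketch, not part of the original script; it relies on `measurements` being extended in groups of six per shape in the order Eager, ThunderFX, ThunderJIT, TorchCompile, FDWrapper, FD):

```python
# Approximate Thunder's host-side overhead as ThunderFX time minus the raw
# FusionDefinition execution time for the same shape.
for eager, tfx, tjit, tc, fdw, fd_meas in zip(*[iter(measurements)] * 6):
    overhead_us = (tfx.median - fd_meas.median) * 1e6  # Measurement.median is in seconds
    print(f"{tfx.task_spec.sub_label}: ~{overhead_us:.1f} us of overhead on top of the fusion")
```

At the smallest shapes this works out to roughly 60 us per call, consistent with the gap between the ThunderFX and FD columns above.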
For reference, these are the shapes observed when running openai/gpt-oss-20b with SGLang and querying the server with `python3 -m sglang.bench_serving --backend sglang --num-prompt 10` (which uses the ShareGPT dataset for prompts).
Shapes at the prefill stage (for decode, SGLang applies CUDA Graphs):
[1, 15, 2880]
[1, 1369, 2880]
[1, 568, 2880]
Related: #2556