Commit a589968

add profile plugin (#2683)
1 parent 530ba7f commit a589968

7 files changed: +86 additions, -24 deletions

thunder/dev_utils/profile_transform.py

Lines changed: 28 additions & 5 deletions
@@ -1,6 +1,7 @@
 import thunder
 import torch
 import contextlib
+import re
 
 from thunder.core import prims
 from thunder.core.symbol import Symbol
@@ -20,15 +21,22 @@ def bind_postprocess(debug_bsym):
 
 
 class ProfileTransform(thunder.core.transform_common.Transform):
-    def __init__(self, *, warmup_runs=3, number_runs=1, start_idx=0, end_idx=-1, backward=False):
-        self.start_idx = start_idx
-        self.end_idx = end_idx
+    def __init__(self, *, warmup_runs=3, number_runs=1, start_idx=0, end_idx=None, input_match=None, backward=False):
+        self.input_match = input_match
+        if input_match is None:
+            self.start_idx = start_idx
+            self.end_idx = end_idx if end_idx is not None else -1
+        else:
+            self.match_start_idx = start_idx
+            self.match_end_idx = end_idx if end_idx is not None else 1
+
         self.enabled = True
         self.run_counter = 0
         self.warmup_runs = warmup_runs
         self.number_runs = number_runs
         self.prof = None
         self.backward = backward
+        self.computed_enabled = False
 
     def start_profile(self):
         self.run_counter += 1
@@ -94,13 +102,27 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):
         if self.backward ^ (TraceTag.BACKWARD in computation_trace.tags):
             return computation_trace
 
+        if self.input_match is not None:
+            self.match_list = []
+            for i, bsym in enumerate(computation_trace.bound_symbols):
+                for a in bsym.args:
+                    if isinstance(a, thunder.TensorProxy) and re.match(self.input_match, a.name):
+                        self.match_list.append((i, a.name))
+            start_idx = self.match_list[self.match_start_idx][0]
+            end_idx = self.match_list[self.match_end_idx][0]
+        else:
+            start_idx = self.start_idx
+            end_idx = self.end_idx
+
         new_bound_symbols = []
 
         new_trace = thunder.core.trace.from_trace(computation_trace)
-        start_idx = self.start_idx
+
+        need_end = False
 
         for i, bsym in enumerate(computation_trace.bound_symbols[:]):
-            if i == self.end_idx or bsym.sym == prims.python_return:
+            if i == end_idx or (bsym.sym == prims.python_return and need_end):
+                need_end = False
                 new_bound_symbols.append(create_boundsymbol("end_profiling", None, self.end_profile))
             if bsym.sym in {
                 prims.unpack_trivial,
@@ -115,6 +137,7 @@ def transform_trace_post_optimization(self, computation_trace, **kwargs):
                 start_idx += 1
                 continue
             if i == start_idx:
+                need_end = True
                 new_bound_symbols.append(create_boundsymbol("start_profiling", None, self.start_profile))
                 new_bound_symbols.append(
                     create_boundsymbol(
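
The new input_match path selects the profiled region by regex-matching tensor-proxy names among the bound symbols' inputs and then taking the match_start_idx-th and match_end_idx-th hits as the start and end bound-symbol indices. A minimal standalone sketch of that selection logic, using a hypothetical list of (bound-symbol index, input name) pairs rather than Thunder's real trace objects:

import re

def select_profile_region(bsym_inputs, input_match, match_start_idx=0, match_end_idx=1):
    # keep the positions of bound symbols whose input tensor name matches the regex
    matches = [(i, name) for i, name in bsym_inputs if re.match(input_match, name)]
    # the profiled region runs from the match_start_idx-th hit to the match_end_idx-th hit
    return matches[match_start_idx][0], matches[match_end_idx][0]

# hypothetical proxy names, only for illustration
bsym_inputs = [(0, "x"), (3, "t_attn_0"), (7, "t_mlp_0"), (11, "t_attn_1")]
print(select_profile_region(bsym_inputs, r"t_attn"))  # -> (3, 11)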

thunder/plugins/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 from thunder.plugins.distributed import DDP, FSDP
 from thunder.plugins.quantization import QuantizeInt4
 from thunder.plugins.fp8 import FP8
+from thunder.plugins.profile import Profile as Profile
 from thunder.plugins.reduce_overhead import ReduceOverhead
 
 names_to_plugins = {

thunder/plugins/profile.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from thunder.core.recipe import Plugin, PluginPolicy
+from thunder.dev_utils.profile_transform import ProfileTransform
+
+
+class Profile(Plugin):
+    policy = PluginPolicy.POST
+
+    def __init__(self, input_match, from_match_idx=0, to_match_idx=1):
+        self.profile_transform = ProfileTransform(
+            input_match=input_match, start_idx=from_match_idx, end_idx=to_match_idx
+        )
+
+    def setup_transforms(self):
+        return [self.profile_transform]
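
A sketch of how the plugin could be used through thunder.compile; the toy model, input, and the proxy-name pattern "t" are assumptions for illustration, since the actual proxy names depend on the traced program:

import torch
import thunder
from thunder.plugins import Profile

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 64))
x = torch.randn(64, 64)

# input_match is a regex over tensor-proxy names in the optimized trace;
# from_match_idx/to_match_idx pick which matches delimit the profiled region.
thunder_model = thunder.compile(model, plugins=[Profile(input_match=r"t", from_match_idx=0, to_match_idx=1)])
out = thunder_model(x)

Since ProfileTransform defaults to warmup_runs=3, the profiler may only engage after a few calls of the compiled module.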

thunder/recipes/base.py

Lines changed: 5 additions & 7 deletions
@@ -78,13 +78,11 @@ def __init__(
         self.fuser = fuser
         self.executor_names = []
 
-        if torch.cuda.is_available():
-            self.executor_names = ["cudnn", "sdpa"]
-            if self.fuser == "nvfuser":
-                self.executor_names.append("torchcompile_xentropy")
-        else:
-            print("GPU not found, nvFuser not available. Setting fusing executor to torch.compile")
-            self.fuser = "torch.compile"
+        if fuser is None:
+            if torch.cuda.is_available():
+                self.fuser = "nvfuser"
+            else:
+                self.fuser = "torch.compile"
 
         self.setup_fuser()
         self.show_progress = show_progress
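
A minimal restatement of the new default-fuser selection in BaseRecipe.__init__, as a standalone helper (the helper name is illustrative, not part of the recipe API):

import torch

def default_fuser(fuser=None):
    # an explicit choice is respected; otherwise prefer nvFuser on CUDA machines
    # and fall back to torch.compile elsewhere
    if fuser is None:
        return "nvfuser" if torch.cuda.is_available() else "torch.compile"
    return fuser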

thunder/recipes/hf_transformers.py

Lines changed: 23 additions & 0 deletions
@@ -10,6 +10,19 @@
 from thunder import Recipe
 
 
+# for materializing models, we need reset_parameters, which is part of the unwritten
+# spec for idiomatic PyTorch, but not implemented everywhere
+def RotaryEmbedding_reset_parameters(self):
+    inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device)
+    with torch.no_grad():
+        self.inv_freq.copy_(inv_freq)
+
+
+def RMSNorm_reset_parameters(self):
+    with torch.no_grad():
+        self.weight.fill_(1)
+
+
 class InplaceIndexCopyTransform(thunder.Transform):
     def __init__(self):
         super().__init__()
@@ -308,6 +321,16 @@ def apply(self, model):
             transformers.PreTrainedModel: Thunder-compiled model ready
             for inference.
         """
+
+        # We need reset_parameters for initialization of buffers in materialization.
+        # This seems to work for transformers 4.5x with Llama, Llama4 and Qwen2 at least
+        for submodule in model.modules():
+            cls = submodule.__class__
+            if cls.__name__.endswith("RotaryEmbedding") and not hasattr(cls, "reset_parameters"):
+                cls.reset_parameters = RotaryEmbedding_reset_parameters
+            elif cls.__name__.endswith("RMSNorm") and not hasattr(cls, "reset_parameters"):
+                cls.reset_parameters = RMSNorm_reset_parameters
+
         thunder_model = super().apply(model)
 
         if getattr(thunder_model, "generate", None):
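
A small self-contained sketch of the class-level patching pattern used above, applied to a toy module rather than the actual transformers classes (ToyRMSNorm is purely illustrative):

import torch

class ToyRMSNorm(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(dim))

def RMSNorm_reset_parameters(self):
    with torch.no_grad():
        self.weight.fill_(1)

model = torch.nn.Sequential(ToyRMSNorm(8))
for submodule in model.modules():
    cls = submodule.__class__
    # patch the class (not the instance) so materialization can call reset_parameters
    if cls.__name__.endswith("RMSNorm") and not hasattr(cls, "reset_parameters"):
        cls.reset_parameters = RMSNorm_reset_parameters

model[0].reset_parameters()  # the weight is now initialized to ones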

thunder/tests/test_recipes.py

Lines changed: 8 additions & 12 deletions
@@ -11,11 +11,13 @@
 from torch.testing import assert_close
 from thunder.recipes import HFTransformers
 from thunder.executors import nvfuser_available
-from thunder.executors.cudnnex import cudnn_available
 from thunder.tests.framework import IS_WINDOWS
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
+def get_expected_executors():
+    return [ex for ex in thunder.get_default_executors() if ex.name not in {"cudnn", "sdpa", "torchcompile_xentropy"}]
+
+
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 @pytest.mark.skipif(IS_WINDOWS, reason="slow on Windows")
 def test_default_recipe_basic_bert():
@@ -33,7 +35,6 @@ def test_default_recipe_basic_bert():
     assert_close(actual, expected)
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 @pytest.mark.skipif(IS_WINDOWS, reason="slow on Windows")
 def test_recipe_basic_bert():
@@ -65,7 +66,6 @@ def test_recipe_basic_bert():
     deregister_executor("sdpa_mask_transform_ex")
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 def test_recipe_basic_bert_fx():
     bert = transformers.BertForSequenceClassification(transformers.BertConfig())
@@ -88,7 +88,6 @@ def test_recipe_basic_bert_fx():
     deregister_executor("sdpa_mask_transform_ex")
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 @pytest.mark.parametrize(
     "model_cls, config_cls",
@@ -186,7 +185,6 @@ def __init__(self):
     deregister_executor("sdpa_mask_transform_ex")
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 def test_plugins_basics():
     model = torch.nn.Sequential(torch.nn.Linear(2048, 4096), torch.nn.ReLU(), torch.nn.Linear(4096, 64))
@@ -198,12 +196,11 @@ def test_plugins_basics():
     _ = thunder_model(x)
     cd = get_compile_data(thunder_model)
     assert cd is not None
-    for ex in thunder.get_default_executors():
+    for ex in get_expected_executors():
         assert ex.name in [el.name for el in cd.executors_list]
 
 
 # test skipped if nvfuser isn't available because providing plugins calls BaseRecipe
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 @pytest.mark.skipif(IS_WINDOWS, reason="libuv error with PT build on windows")
 def test_plugins_composition(monkeypatch):
@@ -215,21 +212,21 @@ def test_plugins_composition(monkeypatch):
     _ = thunder.compile(model, plugins="fp8")
     call_args = mock_jit.call_args
     assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]]
-    for ex in thunder.get_default_executors():
+    for ex in get_expected_executors():
         assert ex.name in [el.name for el in call_args.kwargs["executors"]]
 
     _ = thunder.compile(model, plugins=["fp8"])
     call_args = mock_jit.call_args
     assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]]
-    for ex in thunder.get_default_executors():
+    for ex in get_expected_executors():
         assert ex.name in [el.name for el in call_args.kwargs["executors"]]
 
     from thunder.plugins import FP8
 
     _ = thunder.compile(model, plugins=[FP8()])
     call_args = mock_jit.call_args
     assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]]
-    for ex in thunder.get_default_executors():
+    for ex in get_expected_executors():
         assert ex.name in [el.name for el in call_args.kwargs["executors"]]
 
     if not torch.distributed.is_initialized():
@@ -259,7 +256,6 @@ def test_plugins_composition(monkeypatch):
     assert "transformer_engine_v1" in [el.name for el in call_args.kwargs["executors"]]
 
 
-@pytest.mark.skipif(not cudnn_available(), reason="cuDNN is not available")
 @pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")
 @pytest.mark.skipif(IS_WINDOWS, reason="libuv error with PT build on windows")
 def test_plugins_hybrid_ddpfsdp(monkeypatch):

thunder/tests/test_transformer_engine_executor.py

Lines changed: 7 additions & 0 deletions
@@ -718,7 +718,13 @@ def thunder_model(x):
 
 
 @requiresCUDA
+@pytest.mark.skipif(
+    LooseVersion(transformer_engine.__version__) < LooseVersion("2.9"),
+    reason="need TE >= 2.9 for quantizer location",
+)
 def test_te_inference_8bit():
+    from thunder.transforms.te_inference import TEInference8BitTransform
+
     with torch.device("cuda"):
         m = torch.nn.Sequential(
             torch.nn.Linear(1024, 2048),
@@ -733,6 +739,7 @@ def test_te_inference_8bit():
     a = torch.randn(16, 1024, device="cuda")
 
     quant_transform = TEInference8BitTransform()
+    te_inference_executor = quant_transform.get_executor()
     quant_transform2 = TEInference8BitTransform()
     jm = thunder.jit(
         m, transforms=[quant_transform], executors=(te_inference_executor, *thunder.get_default_executors())
