1111from torch .testing import assert_close
1212from thunder .recipes import HFTransformers
1313from thunder .executors import nvfuser_available
14- from thunder .executors .cudnnex import cudnn_available
1514from thunder .tests .framework import IS_WINDOWS
1615
1716
18- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
17+ def get_expected_executors ():
18+ return [ex for ex in thunder .get_default_executors () if ex .name not in {"cudnn" , "sdpa" , "torchcompile_xentropy" }]
19+
20+
1921@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
2022@pytest .mark .skipif (IS_WINDOWS , reason = "slow on Windows" )
2123def test_default_recipe_basic_bert ():
@@ -33,7 +35,6 @@ def test_default_recipe_basic_bert():
3335 assert_close (actual , expected )
3436
3537
36- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
3738@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
3839@pytest .mark .skipif (IS_WINDOWS , reason = "slow on Windows" )
3940def test_recipe_basic_bert ():
@@ -65,7 +66,6 @@ def test_recipe_basic_bert():
6566 deregister_executor ("sdpa_mask_transform_ex" )
6667
6768
68- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
6969@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
7070def test_recipe_basic_bert_fx ():
7171 bert = transformers .BertForSequenceClassification (transformers .BertConfig ())
@@ -88,7 +88,6 @@ def test_recipe_basic_bert_fx():
8888 deregister_executor ("sdpa_mask_transform_ex" )
8989
9090
91- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
9291@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
9392@pytest .mark .parametrize (
9493 "model_cls, config_cls" ,
@@ -186,7 +185,6 @@ def __init__(self):
186185 deregister_executor ("sdpa_mask_transform_ex" )
187186
188187
189- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
190188@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
191189def test_plugins_basics ():
192190 model = torch .nn .Sequential (torch .nn .Linear (2048 , 4096 ), torch .nn .ReLU (), torch .nn .Linear (4096 , 64 ))
@@ -198,12 +196,11 @@ def test_plugins_basics():
198196 _ = thunder_model (x )
199197 cd = get_compile_data (thunder_model )
200198 assert cd is not None
201- for ex in thunder . get_default_executors ():
199+ for ex in get_expected_executors ():
202200 assert ex .name in [el .name for el in cd .executors_list ]
203201
204202
205203# test skipped if nvfuser isn't available because providing plugins calls BaseRecipe
206- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
207204@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
208205@pytest .mark .skipif (IS_WINDOWS , reason = "libuv error with PT build on windows" )
209206def test_plugins_composition (monkeypatch ):
@@ -215,21 +212,21 @@ def test_plugins_composition(monkeypatch):
215212 _ = thunder .compile (model , plugins = "fp8" )
216213 call_args = mock_jit .call_args
217214 assert "transformer_engine_v1" in [el .name for el in call_args .kwargs ["executors" ]]
218- for ex in thunder . get_default_executors ():
215+ for ex in get_expected_executors ():
219216 assert ex .name in [el .name for el in call_args .kwargs ["executors" ]]
220217
221218 _ = thunder .compile (model , plugins = ["fp8" ])
222219 call_args = mock_jit .call_args
223220 assert "transformer_engine_v1" in [el .name for el in call_args .kwargs ["executors" ]]
224- for ex in thunder . get_default_executors ():
221+ for ex in get_expected_executors ():
225222 assert ex .name in [el .name for el in call_args .kwargs ["executors" ]]
226223
227224 from thunder .plugins import FP8
228225
229226 _ = thunder .compile (model , plugins = [FP8 ()])
230227 call_args = mock_jit .call_args
231228 assert "transformer_engine_v1" in [el .name for el in call_args .kwargs ["executors" ]]
232- for ex in thunder . get_default_executors ():
229+ for ex in get_expected_executors ():
233230 assert ex .name in [el .name for el in call_args .kwargs ["executors" ]]
234231
235232 if not torch .distributed .is_initialized ():
@@ -259,7 +256,6 @@ def test_plugins_composition(monkeypatch):
259256 assert "transformer_engine_v1" in [el .name for el in call_args .kwargs ["executors" ]]
260257
261258
262- @pytest .mark .skipif (not cudnn_available (), reason = "cuDNN is not available" )
263259@pytest .mark .skipif (not nvfuser_available (), reason = "nvFuser is not available" )
264260@pytest .mark .skipif (IS_WINDOWS , reason = "libuv error with PT build on windows" )
265261def test_plugins_hybrid_ddpfsdp (monkeypatch ):
0 commit comments