Skip to content

Commit 53c3d23

Browse files
Copilotcrcrpar
andauthored
Add custom_op executor to the default executors to avoid hidden registration through thunder.torch.custom_op._register_custom_op (#2714)
Co-authored-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent 387c791 commit 53c3d23

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

thunder/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@
199199
apex_executor: None | extend.Executor = extend.get_executor("apex")
200200
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
201201
pytorch_executor: None | extend.Executor = extend.get_executor("torch")
202+
custom_op_executor: None | extend.Executor = extend.get_executor("custom_op")
202203

203-
# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python]
204+
# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> custom_op -> torch -> python]
204205
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
206+
add_default_executor(custom_op_executor)
205207
if nvfuser_executor:
206208
add_default_executor(nvfuser_executor)
207209

thunder/extend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def get_all_executors() -> tuple[Executor, ...]:
533533
apexex,
534534
cudnn_layernormex,
535535
cudnnex,
536+
custom_op_ex,
536537
nvfuserex,
537538
pythonex,
538539
sdpaex,

thunder/tests/test_recipes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515

1616

1717
def get_expected_executors():
18-
return [ex for ex in thunder.get_default_executors() if ex.name not in {"cudnn", "sdpa", "torchcompile_xentropy"}]
18+
return [
19+
ex
20+
for ex in thunder.get_default_executors()
21+
if ex.name not in {"cudnn", "sdpa", "torchcompile_xentropy", "custom_op"}
22+
]
1923

2024

2125
@pytest.mark.skipif(not nvfuser_available(), reason="nvFuser is not available")

thunder/torch/custom_op.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,6 @@ def _register_custom_op(custom_op: CustomOpDef) -> Symbol:
332332
.. note::
333333
This feature is experimental and subject to change.
334334
"""
335-
from thunder.extend import add_executor_lists
336-
from thunder.extend import get_default_executors
337-
from thunder.extend import set_default_executors
338335
from thunder.executors.torchex import _always_executable
339336
from thunder.executors.custom_op_ex import custom_op_ex
340337
from thunder.torch import register_function
@@ -412,12 +409,6 @@ def _register_custom_op(custom_op: CustomOpDef) -> Symbol:
412409
backward_op = custom_op_ex.register_operator(bwd_fn_name, meta=backward_meta, fn=backward_impl)
413410
register_backward(symbol.id)(backward_op)
414411

415-
# NOTE: `thunder.extend.add_default_executor` basically does `lst.insert(ex, 0)`.
416-
if custom_op_ex not in get_default_executors():
417-
default_executors = get_default_executors()
418-
new_default_executors = add_executor_lists(default_executors, [custom_op_ex])
419-
set_default_executors(new_default_executors)
420-
421412
_CUSTOM_OP_TO_TORCHFN_AND_SYMBOL[custom_op] = ((torch_opoverload, torch_opoverload_packet), symbol)
422413

423414
return symbol

0 commit comments

Comments
 (0)