Skip to content

Commit 9c10f0a

Browse files
committed
add custom_op executor to the default executor list
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
1 parent 70741cc commit 9c10f0a

File tree

4 files changed

+4
-71
lines changed

4 files changed

+4
-71
lines changed

thunder/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,11 @@
199199
apex_executor: None | extend.Executor = extend.get_executor("apex")
200200
nvfuser_executor: None | extend.Executor = extend.get_executor("nvfuser")
201201
pytorch_executor: None | extend.Executor = extend.get_executor("torch")
202+
custom_op_executor: None | extend.Executor = extend.get_executor("custom_op")
202203

203-
# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> torch -> python]
204+
# Default executor list is [cudnn -> sdpa -> torchcompile_cat -> torchcompile_xentropy -> nvfuser -> custom_op -> torch -> python]
204205
# Note that add_default_executor inserts executor at start of list, hence the reverse order below.
206+
add_default_executor(custom_op_executor)
205207
if nvfuser_executor:
206208
add_default_executor(nvfuser_executor)
207209

thunder/extend/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def get_all_executors() -> tuple[Executor, ...]:
533533
apexex,
534534
cudnn_layernormex,
535535
cudnnex,
536+
custom_op_ex,
536537
nvfuserex,
537538
pythonex,
538539
sdpaex,

thunder/tests/test_torch_library_custom_op.py

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -263,58 +263,3 @@ def mul_translator(a, b, c=None, *, fd, lc_to_nv_map):
263263
if bsym.sym.name == f"{_symbol.name}_backward" and bsym.sym.executor is custom_op_ex:
264264
bsym_custom_ex_bsym_found = True
265265
assert bsym_custom_ex_bsym_found
266-
267-
268-
def test_custom_op_executor_cleanup():
269-
"""Test that custom_op executor is properly removed from default executors after deregistration.
270-
271-
This is a regression test for the issue where the custom_op executor would remain in the
272-
default executors list after all custom ops were deregistered, causing failures in tests
273-
that check the expected executor list.
274-
275-
The original issue manifested when running test_recipes.py after test_torch_library_custom_op.py.
276-
The test_recipes tests use get_expected_executors() which filters thunder.get_default_executors(),
277-
and they expect that only executors actually used by the model are present. When custom_op
278-
executor wasn't properly cleaned up, it would remain in the default executors list even though
279-
no custom ops were registered, causing assertions like:
280-
assert ex.name in [el.name for el in cd.executors_list]
281-
to fail because 'custom_op' was in get_expected_executors() but not in cd.executors_list.
282-
"""
283-
import thunder
284-
285-
# Define a test custom op
286-
@torch.library.custom_op("test_cleanup::mul", mutates_args=())
287-
def cleanup_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
288-
return a * b
289-
290-
@torch.library.register_kernel("test_cleanup::mul", "cpu")
291-
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
292-
return a * b
293-
294-
@torch.library.register_fake("test_cleanup::mul")
295-
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
296-
return torch.empty_like(a)
297-
298-
# Get initial state (custom_op should not be in default executors)
299-
initial_executors = [ex.name for ex in thunder.get_default_executors()]
300-
assert "custom_op" not in initial_executors, "custom_op should not be in default executors initially"
301-
302-
# Simulate what happens in test_torch_library_custom_op tests
303-
# Register the custom op (this adds custom_op_ex to default executors)
304-
symbol = _register_custom_op(cleanup_mul)
305-
executors_after_register = [ex.name for ex in thunder.get_default_executors()]
306-
assert "custom_op" in executors_after_register, "custom_op should be added after registration"
307-
308-
# Simulate the cleanup that happens in the autouse fixture
309-
# Without the fix, this would NOT remove custom_op_ex from default executors
310-
_deregister_custom_op(cleanup_mul)
311-
executors_after_deregister = [ex.name for ex in thunder.get_default_executors()]
312-
313-
# This is the critical assertion that would fail with the bug:
314-
# After deregistration, custom_op should be removed from default executors
315-
assert "custom_op" not in executors_after_deregister, \
316-
"custom_op should be removed from default executors when no custom ops remain"
317-
318-
# Verify we're back to the initial state, which is what test_recipes.py expects
319-
assert executors_after_deregister == initial_executors, \
320-
"Should return to initial executor state after deregistration"

thunder/torch/custom_op.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -332,9 +332,6 @@ def _register_custom_op(custom_op: CustomOpDef) -> Symbol:
332332
.. note::
333333
This feature is experimental and subject to change.
334334
"""
335-
from thunder.extend import add_executor_lists
336-
from thunder.extend import get_default_executors
337-
from thunder.extend import set_default_executors
338335
from thunder.executors.torchex import _always_executable
339336
from thunder.executors.custom_op_ex import custom_op_ex
340337
from thunder.torch import register_function
@@ -412,12 +409,6 @@ def _register_custom_op(custom_op: CustomOpDef) -> Symbol:
412409
backward_op = custom_op_ex.register_operator(bwd_fn_name, meta=backward_meta, fn=backward_impl)
413410
register_backward(symbol.id)(backward_op)
414411

415-
# NOTE: `thunder.extend.add_default_executor` basically does `lst.insert(0, ex)`.
416-
if custom_op_ex not in get_default_executors():
417-
default_executors = get_default_executors()
418-
new_default_executors = add_executor_lists(default_executors, [custom_op_ex])
419-
set_default_executors(new_default_executors)
420-
421412
_CUSTOM_OP_TO_TORCHFN_AND_SYMBOL[custom_op] = ((torch_opoverload, torch_opoverload_packet), symbol)
422413

423414
return symbol
@@ -450,12 +441,6 @@ def _deregister_custom_op(custom_op: CustomOpDef) -> None:
450441

451442
del _CUSTOM_OP_TO_TORCHFN_AND_SYMBOL[custom_op]
452443

453-
# Remove custom_op_ex from default executors if no custom ops remain
454-
if not _CUSTOM_OP_TO_TORCHFN_AND_SYMBOL:
455-
from thunder.extend import remove_default_executor
456-
457-
remove_default_executor(custom_op_ex)
458-
459444

460445
def _register_nvfuser_translator(
461446
symbol: Symbol,

0 commit comments

Comments
 (0)