Skip to content

Commit 2750902

Browse files
cccclaifacebook-github-bot
authored andcommitted
Prohibit nested backends
Summary: pytorch#15528 initially wanted to subclass a backend.. It was currently already guarded by https://github.com/pytorch/executorch/blob/main/exir/backend/backend_api.py#L111-L112 meaning that subclass will not show up. However it's not super obvious so we want to guard by disallowing subclass at all Differential Revision: D87105211
1 parent a6c5921 commit 2750902

File tree

2 files changed

+32
-0
lines changed

2 files changed

+32
-0
lines changed

exir/backend/backend_details.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,22 @@ class BackendDetails(ABC):
5757
5858
"""
5959

60+
def __init_subclass__(cls, **kwargs):
61+
super().__init_subclass__(**kwargs)
62+
63+
# Allow direct subclasses of BackendDetails
64+
if cls.__bases__ == (BackendDetails,):
65+
return
66+
67+
# Forbid subclasses whose ANY parent is already a child of BackendDetails
68+
for base in cls.__bases__:
69+
if issubclass(base, BackendDetails) and base is not BackendDetails:
70+
raise TypeError(
71+
f"ExecuTorch delegate doesn't support nested backend, '{base.__name__}' "
72+
" should be a final backend implementation and should not be subclassed "
73+
f"(attempted by '{cls.__name__}')."
74+
)
75+
6076
@staticmethod
6177
# all backends need to implement this method
6278
@enforcedmethod

exir/backend/test/test_backends.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
from executorch.exir import to_edge
1414
from executorch.exir.backend.backend_api import LoweredBackendModule, to_backend
15+
from executorch.exir.backend.backend_details import BackendDetails
1516
from executorch.exir.backend.canonical_partitioners.all_node_partitioner import (
1617
AllNodePartitioner,
1718
)
@@ -1444,3 +1445,18 @@ def inputs(self):
14441445
self.assertTrue(
14451446
torch.allclose(model_outputs[0], ref_output, atol=1e-03, rtol=1e-03)
14461447
)
1448+
1449+
def test_prohibited_nested_backends(self):
1450+
class MyBackend(BackendDetails):
1451+
@staticmethod
1452+
def preprocess(edge_program, compile_specs):
1453+
return None
1454+
1455+
with self.assertRaises(TypeError) as ctx:
1456+
class MyOtherBackend(MyBackend):
1457+
pass
1458+
1459+
self.assertIn(
1460+
"'MyBackend' should be a final backend implementation and should not be subclassed (attempted by 'MyOtherBackend')",
1461+
str(ctx.exception)
1462+
)

0 commit comments

Comments
 (0)