| 
10 | 10 | from contextlib import contextmanager, nullcontext  | 
11 | 11 | from dataclasses import dataclass  | 
12 | 12 | from functools import singledispatch  | 
13 |  | -from typing import Dict, Generator, List, Mapping  | 
 | 13 | +from typing import Dict, Generator, List, Mapping, Set  | 
14 | 14 | 
 
  | 
15 | 15 | import torch  | 
16 | 16 | 
 
  | 
@@ -581,9 +581,21 @@ def lower_all_submodules_to_backend(  | 
581 | 581 |         for method_name, call_submodule_nodes in method_to_submodules_nodes.items()  | 
582 | 582 |     }  | 
583 | 583 | 
 
  | 
 | 584 | +    def _get_all_final_backend_details_subclasses(cls) -> Set[type]:  | 
 | 585 | +        subclasses = set()  | 
 | 586 | +        for subclass in cls.__subclasses__():  | 
 | 587 | +            # Check if subclass is a final class (marked as @final)  | 
 | 588 | +            if getattr(subclass, "__final__", False):  | 
 | 589 | +                subclasses.add(subclass)  | 
 | 590 | +            # Recursively check subclasses  | 
 | 591 | +            subclasses.update(_get_all_final_backend_details_subclasses(subclass))  | 
 | 592 | +        return subclasses  | 
 | 593 | + | 
584 | 594 |     backend_name_to_subclass = {  | 
585 |  | -        subclass.__name__: subclass for subclass in BackendDetails.__subclasses__()  | 
 | 595 | +        subclass.__name__: subclass  | 
 | 596 | +        for subclass in _get_all_final_backend_details_subclasses(BackendDetails)  | 
586 | 597 |     }  | 
 | 598 | + | 
587 | 599 |     if backend_id not in backend_name_to_subclass:  | 
588 | 600 |         raise NotImplementedError(f"Backend {backend_id} was not found.")  | 
589 | 601 | 
 
  | 
 | 
0 commit comments