forked from vllm-project/llm-compressor
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfactory.py
More file actions
152 lines (127 loc) · 5.52 KB
/
factory.py
File metadata and controls
152 lines (127 loc) · 5.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import importlib
import pkgutil
from llmcompressor.modifiers.modifier import Modifier
# Names exported by `from <this module> import *`: only the factory class.
__all__ = ["ModifierFactory"]
class ModifierFactory:
    """
    A factory for loading and registering modifiers

    Modifier classes are held in three class-level registries, each mapping a
    class name to the class itself: classes discovered under the main package,
    classes discovered under the experimental package, and classes added
    explicitly through :meth:`register`. All state is stored on the class and
    every method is a ``staticmethod``.
    """

    _MAIN_PACKAGE_PATH = "llmcompressor.modifiers"
    _EXPERIMENTAL_PACKAGE_PATH = "llmcompressor.modifiers.experimental"

    # Set to True once refresh() has populated the package registries.
    _loaded: bool = False
    # Name -> class, discovered by walking _MAIN_PACKAGE_PATH.
    _main_registry: dict[str, type[Modifier]] = {}
    # Name -> class, discovered by walking _EXPERIMENTAL_PACKAGE_PATH.
    _experimental_registry: dict[str, type[Modifier]] = {}
    # Name -> class, added explicitly via register().
    _registered_registry: dict[str, type[Modifier]] = {}
    # Attribute name -> exception recorded while validating that attribute
    # during discovery; create() re-raises it when that name is requested.
    _errors: dict[str, Exception] = {}

    @staticmethod
    def refresh():
        """
        A method to refresh the factory by reloading the modifiers

        Rebuilds the main and experimental registries by walking their
        packages again and marks the factory as loaded.

        Note: modifiers added through :meth:`register` and errors recorded by
        earlier loads are left untouched by this method.
        """
        ModifierFactory._main_registry = ModifierFactory.load_from_package(
            ModifierFactory._MAIN_PACKAGE_PATH
        )
        ModifierFactory._experimental_registry = ModifierFactory.load_from_package(
            ModifierFactory._EXPERIMENTAL_PACKAGE_PATH
        )
        ModifierFactory._loaded = True

    @staticmethod
    def load_from_package(package_path: str) -> dict[str, type[Modifier]]:
        """
        Import every module below a package and collect the Modifier classes.

        Only module attributes whose name ends with ``"Modifier"`` are
        considered. The first class found under a given name wins; later
        attributes with the same name are skipped. An attribute that matches
        the naming rule but is not a ``Modifier`` subclass is recorded in
        ``ModifierFactory._errors`` under its name instead of being loaded.
        A module that fails to import is reported with ``print`` and skipped.

        :param package_path: The path to the package to load modifiers from
        :return: The loaded modifiers, as a mapping of name to class
        """
        loaded: dict[str, type[Modifier]] = {}
        main_package = importlib.import_module(package_path)

        # exclude deprecated packages from registry so
        # their new location is used instead
        deprecated_packages = {
            "llmcompressor.modifiers.obcq",
            "llmcompressor.modifiers.obcq.sgpt_base",
            "llmcompressor.modifiers.quantization.gptq",
            "llmcompressor.modifiers.quantization.gptq.base",
            "llmcompressor.modifiers.quantization.gptq.gptq_quantize",
        }

        for _importer, modname, _is_pkg in pkgutil.walk_packages(
            main_package.__path__, package_path + "."
        ):
            if modname in deprecated_packages:
                continue

            try:
                module = importlib.import_module(modname)

                for attribute_name in dir(module):
                    if not attribute_name.endswith("Modifier"):
                        continue

                    try:
                        # first definition found for a name is kept
                        if attribute_name in loaded:
                            continue

                        attr = getattr(module, attribute_name)

                        if not isinstance(attr, type):
                            raise ValueError(
                                f"Attribute {attribute_name} is not a type"
                            )

                        if not issubclass(attr, Modifier):
                            raise ValueError(
                                f"Attribute {attribute_name} is not a Modifier"
                            )

                        loaded[attribute_name] = attr
                    except Exception as err:
                        # TODO: log import error
                        # remembered so create() can surface it for this name
                        ModifierFactory._errors[attribute_name] = err
            except Exception as module_err:
                # TODO: log import error
                # best-effort discovery: one broken module must not abort
                # loading of the remaining modules
                print(module_err)

        return loaded

    @staticmethod
    def create(
        type_: str,
        allow_registered: bool,
        allow_experimental: bool,
        **kwargs,
    ) -> Modifier:
        """
        Instantiate a modifier of the given type from registered modifiers.

        Lookup order: an error recorded for ``type_`` during discovery is
        re-raised first; then the registered registry (if allowed), then the
        experimental registry (if allowed), then the main registry. A
        registry that is not allowed is skipped and the lookup falls through
        to the next one.

        :raises ValueError: If no modifier of the given type is found
        :raises Exception: The error recorded for ``type_`` while loading,
            if there is one
        :param type_: The type of modifier to create
        :param allow_registered: Whether or not to allow registered modifiers
        :param allow_experimental: Whether or not to allow experimental modifiers
        :param kwargs: Additional keyword arguments to pass to the modifier
            during instantiation
        :return: The instantiated modifier
        """
        if type_ in ModifierFactory._errors:
            raise ModifierFactory._errors[type_]

        # TODO: log warning when a modifier is skipped because its registry
        # is not allowed
        if allow_registered and type_ in ModifierFactory._registered_registry:
            return ModifierFactory._registered_registry[type_](**kwargs)

        if allow_experimental and type_ in ModifierFactory._experimental_registry:
            return ModifierFactory._experimental_registry[type_](**kwargs)

        if type_ in ModifierFactory._main_registry:
            return ModifierFactory._main_registry[type_](**kwargs)

        raise ValueError(f"No modifier of type '{type_}' found.")

    @staticmethod
    def register(type_: str, modifier_class: type[Modifier]):
        """
        Register a modifier class to be used by the factory.

        An existing registration under the same ``type_`` is overwritten.

        :raises ValueError: If the provided class does not subclass the Modifier
            base class or is not a type
        :param type_: The type of modifier to register
        :param modifier_class: The class of the modifier to register, must subclass
            the Modifier base class
        """
        # The type check must come first: issubclass() raises TypeError when
        # its first argument is not a class, which would bypass the
        # documented ValueError.
        if not isinstance(modifier_class, type):
            raise ValueError("The provided class is not a type.")

        if not issubclass(modifier_class, Modifier):
            raise ValueError(
                "The provided class does not subclass the Modifier base class."
            )

        ModifierFactory._registered_registry[type_] = modifier_class