47
47
48
48
49
49
# Public API of this module.
__all__ = [
    "match_targets",
    "get_terminal_layers",
    "get_prunable_layers",
    "get_quantizable_layers",
    "qat_active",
    "get_matching_layer",
    "get_no_split_params",
]

# Sentinel target name — presumably selects every quantizable layer when
# passed where a concrete layer name is expected; confirm at call sites.
ALL_QUANTIZABLE_TARGET = "__ALL_QUANTIZABLE__"
64
def match_targets(name: str, targets: Union[str, List[str]]) -> Tuple[bool, int]:
    """
    Check whether ``name`` matches any entry in ``targets``.

    Entries prefixed with ``"re:"`` are treated as regular expressions and
    matched with ``re.match`` (anchored at the start of ``name``); all other
    entries must equal ``name`` exactly.

    :param name: layer/parameter name to test
    :param targets: a single target string or a list of target strings
    :return: tuple of (matched, index of the first matching target);
        ``(False, -1)`` when nothing matches
    """
    if isinstance(targets, str):
        targets = [targets]

    for index, target in enumerate(targets):
        if target.startswith("re:"):
            # strip the "re:" marker; the remainder is the regex pattern
            if re.match(target[3:], name):
                return True, index
        elif name == target:
            return True, index

    return False, -1
77
-
78
-
79
60
def match_class (layer : Module , targets : Union [str , List [str ]]) -> Tuple [bool , int ]:
80
61
if isinstance (targets , str ):
81
62
targets = [targets ]
@@ -87,51 +68,6 @@ def match_class(layer: Module, targets: Union[str, List[str]]) -> Tuple[bool, in
87
68
return False , - 1
88
69
89
70
90
def get_terminal_layers(module: Module) -> Dict[str, Module]:
    """
    Collect the leaf submodules of ``module`` — those with no children.

    :param module: root module to scan (recursively, via ``named_modules``)
    :return: mapping of qualified submodule name -> leaf module
    """
    # A module is terminal iff it has no children. Peeking at the first
    # child with next() avoids materializing the full recursive
    # named_modules() list per layer as the original did.
    return {
        name: layer
        for name, layer in module.named_modules()
        if next(layer.children(), None) is None
    }
100
-
101
-
102
def get_prunable_layers(module: Module) -> Dict[str, Module]:
    """
    Collect submodules that support pruning: Linear and Conv layers,
    including their QAT and transformer Conv1D variants when available.

    The QAT/TransformerConv1D classes come from conditional imports
    elsewhere in this file and may be ``None`` on torch versions without
    them; those entries are skipped.

    :param module: root module to scan (recursively, via ``named_modules``)
    :return: mapping of qualified submodule name -> prunable layer
    """
    # isinstance accepts a tuple of types, so build the tuple once —
    # dropping any optional class that failed to import — instead of
    # re-evaluating a six-way or-chain for every submodule.
    prunable_types = tuple(
        cls
        for cls in (
            Linear,
            _ConvNd,
            QATLinear,
            QATConv2d,
            QATConv3d,
            TransformerConv1D,
        )
        if cls is not None
    )

    return {
        name: layer
        for name, layer in module.named_modules()
        if isinstance(layer, prunable_types)
    }
117
-
118
-
119
def get_quantizable_layers(module: Module) -> Dict[str, Module]:
    """
    Collect submodules that can be quantized: Linear and Conv layers.

    :param module: root module to scan (recursively, via ``named_modules``)
    :return: mapping of qualified submodule name -> quantizable layer
    :raises ImportError: if the installed torch has no QAT support
        (signalled by ``QATLinear`` being ``None``, set by a conditional
        import elsewhere in this file)
    """
    if QATLinear is None:
        raise ImportError(
            "PyTorch version is not setup for Quantization. "
            "Please install a QAT compatible version of PyTorch"
        )

    # isinstance with a type tuple replaces the or-ed pair of checks
    return {
        name: layer
        for name, layer in module.named_modules()
        if isinstance(layer, (Linear, _ConvNd))
    }
133
-
134
-
135
71
def qat_active (module : Module ) -> bool :
136
72
"""
137
73
Determines if any layers in the model have quantization enabled by checking for
0 commit comments