15
15
import logging
16
16
import re
17
17
from collections .abc import Generator
18
- from typing import Iterable , Mapping , Optional , Tuple
18
+ from typing import Iterable , List , Mapping , Optional , Tuple
19
19
20
20
import torch
21
21
from compressed_tensors .utils .internal import InternalModule
@@ -57,10 +57,10 @@ def match_named_modules(
57
57
unmatched_targets = set (targets )
58
58
for name , module in model .named_modules ():
59
59
for target in targets :
60
- if is_match (name , module , target , fused ):
60
+ if is_match (name , module , target , fused = fused ):
61
61
unmatched_targets -= {target }
62
62
63
- if not any ( is_match (name , module , ign , fused ) for ign in ignore ):
63
+ if not is_match (name , module , ignore , fused = fused ):
64
64
yield name , module
65
65
66
66
if warn_on_fail :
@@ -155,9 +155,7 @@ def match_modules_set(
155
155
for name , module in model .named_modules ():
156
156
# match until we get a full set
157
157
for target in targets :
158
- if is_match (name , module , target ) and not any (
159
- is_match (name , module , ign ) for ign in ignore
160
- ):
158
+ if is_match (name , module , target , ignore ):
161
159
if matches [target ] is not None :
162
160
raise ValueError (f"Matched a { target } twice before completing set" )
163
161
matches [target ] = module
@@ -176,7 +174,8 @@ def match_modules_set(
176
174
def is_match (
177
175
name : str ,
178
176
module : torch .nn .Module ,
179
- target : str ,
177
+ targets : str | Iterable [str ],
178
+ ignore : str | Iterable [str ] = tuple (),
180
179
fused : Optional [FusedMappping ] = None ,
181
180
) -> bool :
182
181
"""
@@ -198,8 +197,17 @@ def is_match(
198
197
:fused: optional mapping from suffixes of fused modules to the suffixes of their
199
198
corresponding shards
200
199
"""
200
+ targets = [targets ] if isinstance (targets , str ) else targets
201
+ ignore = [ignore ] if isinstance (ignore , str ) else ignore
202
+
201
203
return not isinstance (module , InternalModule ) and (
202
- _match_name (name , target , fused ) or _match_class (module , target )
204
+ any (
205
+ _match_name (name , target , fused ) or _match_class (module , target )
206
+ for target in targets
207
+ )
208
+ and not any (
209
+ _match_name (name , ign , fused ) or _match_class (module , ign ) for ign in ignore
210
+ )
203
211
)
204
212
205
213
0 commit comments