Skip to content

Commit 548acbb

Browse files
authored
Update get_block_names func (#1047)
1 parent 247a9a7 commit 548acbb

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

auto_round/utils/model.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -631,18 +631,13 @@ def _get_vlm_block_names(model, quant_vision=False):
631631
block_names = []
632632
target_modules = []
633633
vision_blocks_tuple = ("vision", "visual", "image", "img")
634-
last_block_name = ""
635-
for n, m in model.named_modules():
636-
if hasattr(type(m), "__name__") and "ModuleList" in type(m).__name__:
637-
if quant_vision or all(key not in n.lower() for key in (vision_blocks_tuple)):
638-
if last_block_name and last_block_name in n:
639-
continue
640-
target_modules.append((n, m))
641-
last_block_name = n
634+
target_modules = _search_block("", model)
635+
642636
for i, target_m in enumerate(target_modules):
643-
block_names.append([])
644-
for n, m in target_m[1].named_children():
645-
block_names[i].append(target_m[0] + "." + n)
637+
if quant_vision or all(key not in target_m[0].lower() for key in (vision_blocks_tuple)):
638+
block_names.append([])
639+
for n, m in target_m[1].named_children():
640+
block_names[-1].append(target_m[0] + "." + n)
646641
return block_names
647642

648643
if quant_vision or not is_pure_text_model(model):

test/test_cuda/test_get_block_name.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def tearDownClass(self):
3030
shutil.rmtree("runs", ignore_errors=True)
3131

3232
def check_block_names(self, block_names, prefixs=[], n_layers=[]):
33+
assert len(block_names) == len(prefixs) == len(n_layers)
3334
for i, block_name in enumerate(block_names):
3435
prefix = prefixs[i]
3536
n_layer = n_layers[i]
@@ -196,6 +197,9 @@ def test_flux(self):
196197
self.check_block_names(block_names, ["transformer_blocks", "single_transformer_blocks"], [19, 38])
197198
self.assertTrue(any(["context_embedder" not in n for n in block_names]))
198199

200+
block_names = get_block_names(model, quant_vision=True)
201+
self.check_block_names(block_names, ["transformer_blocks", "single_transformer_blocks"], [19, 38])
202+
199203

200204
if __name__ == "__main__":
201205
unittest.main()

0 commit comments

Comments
 (0)