Skip to content

Commit 513f768

Browse files
authored
Modular fix: remove the model name in find_file_type (#39897)
* remove the model name in the class name * add comment
1 parent 743bb5f commit 513f768

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

utils/modular_model_converter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,13 +1086,15 @@ def replace_class_node(
10861086
}
10871087

10881088

1089-
def find_file_type(class_name: str) -> str:
1089+
def find_file_type(class_name: str, model_name: str) -> str:
10901090
"""Based on a class name, find the file type corresponding to the class.
10911091
If the class name is `LlamaConfig` it will return `configuration`.
10921092
The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling`
10931093
"""
10941094
match_pattern = "|".join(TYPE_TO_FILE_TYPE.keys())
1095-
match = re.search(rf"({match_pattern})$", class_name)
1095+
# We remove the model name to avoid ambiguity, e.g. for `Sam2VideoProcessor`,
1096+
# removing `Sam2Video` ensures we match `Processor` instead of `VideoProcessor`.
1097+
match = re.search(rf"({match_pattern})$", class_name.replace(get_cased_name(model_name), ""))
10961098
if match:
10971099
file_type = TYPE_TO_FILE_TYPE[match.group(1)]
10981100
else:
@@ -1175,7 +1177,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
11751177
return usual_import_nodes + protected_import_nodes
11761178

11771179

1178-
def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
1180+
def split_all_assignment(node: cst.CSTNode, model_name: str) -> dict[str, cst.CSTNode]:
11791181
"""Split the `__all__` assignment found in the modular between each corresponding files."""
11801182
all_all_per_file = {}
11811183
assign_node = node.body[0]
@@ -1186,7 +1188,7 @@ def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
11861188
if isinstance(element.value, cst.SimpleString):
11871189
# Remove quotes and add the string to the elements list
11881190
class_name = element.value.value
1189-
file = find_file_type(element.value.evaluated_value)
1191+
file = find_file_type(element.value.evaluated_value, model_name)
11901192
all_all_to_add[file] += [class_name]
11911193
for file, new_alls in all_all_to_add.items():
11921194
new_node = assign_node.with_changes(
@@ -1275,7 +1277,7 @@ def visit_SimpleStatementLine(self, node):
12751277
assigned_variable = node.body[0].targets[0].target.value
12761278
# __all__ is treated differently and not added to general assignments
12771279
if assigned_variable == "__all__":
1278-
self.all_all_to_add = split_all_assignment(node)
1280+
self.all_all_to_add = split_all_assignment(node, self.model_name)
12791281
else:
12801282
self.current_assignment = assigned_variable
12811283
self.assignments[assigned_variable] = node
@@ -1531,7 +1533,7 @@ class NewNameModel(LlamaModel):
15311533
corrected_dependencies = new_dependencies.copy()
15321534
new_imports = {}
15331535
for class_name in class_dependencies:
1534-
class_file_type = find_file_type(class_name)
1536+
class_file_type = find_file_type(class_name, new_name)
15351537
# In this case, we need to remove it from the dependencies and create a new import instead
15361538
if class_file_type != file_type:
15371539
corrected_dependencies.remove(class_name)
@@ -1554,7 +1556,7 @@ class node based on the inherited classes if needed. Also returns any new import
15541556
]
15551557
super_class = model_specific_bases[0] if len(model_specific_bases) == 1 else None
15561558

1557-
file_type = find_file_type(class_name)
1559+
file_type = find_file_type(class_name, modular_mapper.model_name)
15581560
file_to_update = files[file_type]
15591561
model_name = modular_mapper.model_name
15601562

0 commit comments

Comments
 (0)