@@ -1086,13 +1086,15 @@ def replace_class_node(
1086
1086
}
1087
1087
1088
1088
1089
- def find_file_type (class_name : str ) -> str :
1089
+ def find_file_type (class_name : str , model_name : str ) -> str :
1090
1090
"""Based on a class name, find the file type corresponding to the class.
1091
1091
If the class name is `LlamaConfig` it will return `configuration`.
1092
1092
The list of suffixes is in `TYPE_TO_FILE_TYPE`. If there are no match, we match by default to `modeling`
1093
1093
"""
1094
1094
match_pattern = "|" .join (TYPE_TO_FILE_TYPE .keys ())
1095
- match = re .search (rf"({ match_pattern } )$" , class_name )
1095
+ # We remove the model name to avoid ambiguity, e.g. for `Sam2VideoProcessor`,
1096
+ # removing `Sam2Video` ensures we match `Processor` instead of `VideoProcessor`.
1097
+ match = re .search (rf"({ match_pattern } )$" , class_name .replace (get_cased_name (model_name ), "" ))
1096
1098
if match :
1097
1099
file_type = TYPE_TO_FILE_TYPE [match .group (1 )]
1098
1100
else :
@@ -1175,7 +1177,7 @@ def get_needed_imports(body: dict[str, dict], all_imports: list[cst.CSTNode]) ->
1175
1177
return usual_import_nodes + protected_import_nodes
1176
1178
1177
1179
1178
- def split_all_assignment (node : cst .CSTNode ) -> dict [str , cst .CSTNode ]:
1180
+ def split_all_assignment (node : cst .CSTNode , model_name : str ) -> dict [str , cst .CSTNode ]:
1179
1181
"""Split the `__all__` assignment found in the modular between each corresponding files."""
1180
1182
all_all_per_file = {}
1181
1183
assign_node = node .body [0 ]
@@ -1186,7 +1188,7 @@ def split_all_assignment(node: cst.CSTNode) -> dict[str, cst.CSTNode]:
1186
1188
if isinstance (element .value , cst .SimpleString ):
1187
1189
# Remove quotes and add the string to the elements list
1188
1190
class_name = element .value .value
1189
- file = find_file_type (element .value .evaluated_value )
1191
+ file = find_file_type (element .value .evaluated_value , model_name )
1190
1192
all_all_to_add [file ] += [class_name ]
1191
1193
for file , new_alls in all_all_to_add .items ():
1192
1194
new_node = assign_node .with_changes (
@@ -1275,7 +1277,7 @@ def visit_SimpleStatementLine(self, node):
1275
1277
assigned_variable = node .body [0 ].targets [0 ].target .value
1276
1278
# __all__ is treated differently and not added to general assignments
1277
1279
if assigned_variable == "__all__" :
1278
- self .all_all_to_add = split_all_assignment (node )
1280
+ self .all_all_to_add = split_all_assignment (node , self . model_name )
1279
1281
else :
1280
1282
self .current_assignment = assigned_variable
1281
1283
self .assignments [assigned_variable ] = node
@@ -1531,7 +1533,7 @@ class NewNameModel(LlamaModel):
1531
1533
corrected_dependencies = new_dependencies .copy ()
1532
1534
new_imports = {}
1533
1535
for class_name in class_dependencies :
1534
- class_file_type = find_file_type (class_name )
1536
+ class_file_type = find_file_type (class_name , new_name )
1535
1537
# In this case, we need to remove it from the dependencies and create a new import instead
1536
1538
if class_file_type != file_type :
1537
1539
corrected_dependencies .remove (class_name )
@@ -1554,7 +1556,7 @@ class node based on the inherited classes if needed. Also returns any new import
1554
1556
]
1555
1557
super_class = model_specific_bases [0 ] if len (model_specific_bases ) == 1 else None
1556
1558
1557
- file_type = find_file_type (class_name )
1559
+ file_type = find_file_type (class_name , modular_mapper . model_name )
1558
1560
file_to_update = files [file_type ]
1559
1561
model_name = modular_mapper .model_name
1560
1562
0 commit comments