Skip to content

Commit c58ec95

Browse files
hlkyArthurZucker
authored andcommitted
Fixes for Modular Converter on Windows (huggingface#34266)
* Separator in regex * Standardize separator for relative path in auto generated message * open() encoding * Replace `\` on `os.path.abspath` --------- Co-authored-by: Arthur <[email protected]>
1 parent 40f93a0 commit c58ec95

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

utils/modular_model_converter.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def get_module_source_from_name(module_name: str) -> str:
5656
if spec is None or spec.origin is None:
5757
return f"Module {module_name} not found"
5858

59-
with open(spec.origin, "r") as file:
59+
with open(spec.origin, "r", encoding="utf-8") as file:
6060
source_code = file.read()
6161
return source_code
6262

@@ -1132,7 +1132,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
11321132
if pattern is not None:
11331133
model_name = pattern.groups()[0]
11341134
# Parse the Python file
1135-
with open(modular_file, "r") as file:
1135+
with open(modular_file, "r", encoding="utf-8") as file:
11361136
code = file.read()
11371137
module = cst.parse_module(code)
11381138
wrapper = MetadataWrapper(module)
@@ -1143,7 +1143,7 @@ def convert_modular_file(modular_file, old_model_name=None, new_model_name=None,
11431143
if node != {}:
11441144
# Get relative path starting from src/transformers/
11451145
relative_path = re.search(
1146-
rf"(src{os.sep}transformers{os.sep}.*|examples{os.sep}.*)", os.path.abspath(modular_file)
1146+
r"(src/transformers/.*|examples/.*)", os.path.abspath(modular_file).replace("\\", "/")
11471147
).group(1)
11481148

11491149
header = AUTO_GENERATED_MESSAGE.format(
@@ -1164,15 +1164,15 @@ def save_modeling_file(modular_file, converted_file):
11641164
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
11651165
)
11661166
if len(converted_file[file_type][0].strip()) > 0 and non_comment_lines > 0:
1167-
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
1167+
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
11681168
f.write(converted_file[file_type][0])
11691169
else:
11701170
non_comment_lines = len(
11711171
[line for line in converted_file[file_type][0].strip().split("\n") if not line.strip().startswith("#")]
11721172
)
11731173
if len(converted_file[file_type][1].strip()) > 0 and non_comment_lines > 0:
11741174
logger.warning("The modeling code contains errors, it's written without formatting")
1175-
with open(modular_file.replace("modular_", f"{file_type}_"), "w") as f:
1175+
with open(modular_file.replace("modular_", f"{file_type}_"), "w", encoding="utf-8") as f:
11761176
f.write(converted_file[file_type][1])
11771177

11781178

0 commit comments

Comments
 (0)