Skip to content

Commit fde929c

Browse files
xuhancnpytorchmergebot
authored andcommitted
[AOTI] Fix model_package_loader get_cpp_compile_command (pytorch#163561)
It should fix AOTI UTs of `test_aot_inductor_package.py`, these cases are failed at `compile_so`. reproducer: ```cmd pytest test\inductor\test_aot_inductor_package.py -v -k test_multiple_methods ``` <img width="1262" height="95" alt="image" src="https://github.com/user-attachments/assets/49458536-1cfe-498e-a12a-2bfd8da67a9e" /> Major fix at `get_cpp_compile_command`. The code is aligned to cpp_builder frontend code: https://github.com/pytorch/pytorch/blob/3ef1bef36c73b4def0e1b71847e27fde1556c0fb/torch/_inductor/cpp_builder.py#L1780-L1790 https://github.com/pytorch/pytorch/blob/3ef1bef36c73b4def0e1b71847e27fde1556c0fb/torch/_inductor/cpp_builder.py#L1959-L1976 Fixed on Windows: <img width="1261" height="89" alt="Image" src="https://github.com/user-attachments/assets/9bf43b11-aac1-4161-a625-e602e313a299" /> Also validated on Linux: <img width="1039" height="81" alt="Image" src="https://github.com/user-attachments/assets/46063e16-6cf1-4a28-8466-0496871b8619" /> Pull Request resolved: pytorch#163561 Approved by: https://github.com/jansel
1 parent 134dfbe commit fde929c

File tree

1 file changed

+75
-26
lines changed

1 file changed

+75
-26
lines changed

torch/csrc/inductor/aoti_package/model_package_loader.cpp

Lines changed: 75 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@ namespace {
4040

4141
const std::string k_separator = "/";
4242

43+
std::string remove_duplicate_separator_of_path(const std::string& path) {
44+
/*
45+
On Windows, temp file path maybe has duplicate separator.
46+
Need to remove the duplication:
47+
Origin: C:/Users/Xuhan/AppData/Local/Temp//tmpl10jfwef/filename
48+
Processed: C:/Users/Xuhan/AppData/Local/Temp/tmpl10jfwef/filename
49+
*/
50+
std::string result = path;
51+
size_t pos = 0;
52+
53+
while ((pos = result.find("//", pos)) != std::string::npos) {
54+
result.replace(pos, 2, "/");
55+
}
56+
57+
return result;
58+
}
59+
4360
std::string normalize_path_separator(const std::string& orig_path) {
4461
/*
4562
On Windows and Linux have different separator:
@@ -57,6 +74,7 @@ std::string normalize_path_separator(const std::string& orig_path) {
5774
#ifdef _WIN32
5875
std::replace(normalized_path.begin(), normalized_path.end(), '\\', '/');
5976
#endif
77+
normalized_path = remove_duplicate_separator_of_path(normalized_path);
6078
return normalized_path;
6179
}
6280

@@ -108,6 +126,22 @@ const char* extension_file_ext() {
108126
#endif
109127
}
110128

129+
const char* get_output_flags(bool compile_only) {
130+
if (compile_only) {
131+
#ifdef _WIN32
132+
return "/c /Fo"; // codespell:ignore
133+
#else
134+
return "-c -o";
135+
#endif
136+
}
137+
138+
#ifdef _WIN32
139+
return "/Fe";
140+
#else
141+
return "-o";
142+
#endif
143+
}
144+
111145
bool _is_windows_os() {
112146
#ifdef _WIN32
113147
return true;
@@ -146,7 +180,7 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
146180

147181
std::string source_args;
148182
for (const std::string& source : sources) {
149-
source_args += source + " ";
183+
source_args += normalize_path_separator(source) + " ";
150184
}
151185

152186
std::string file_ext =
@@ -160,24 +194,28 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
160194

161195
std::string cflags_args;
162196
for (auto& arg : compile_options["cflags"]) {
163-
cflags_args += _is_windows_os() ? "/" : "-" + arg.get<std::string>() + " ";
197+
// [Windows compiler need it] convert first char arg to std::string, for
198+
// following plus(+) strings.
199+
cflags_args += std::string(_is_windows_os() ? "/" : "-") +
200+
arg.get<std::string>() + " ";
164201
}
165202

166203
std::string definitions_args;
167204
for (auto& arg : compile_options["definitions"]) {
168-
definitions_args +=
169-
_is_windows_os() ? "/D" : "-D " + arg.get<std::string>() + " ";
205+
definitions_args += std::string(_is_windows_os() ? "/D" : "-D ") +
206+
arg.get<std::string>() + " ";
170207
}
171208

172209
std::string include_dirs_args;
173210
for (auto& arg : compile_options["include_dirs"]) {
174-
include_dirs_args +=
175-
_is_windows_os() ? "/I" : "-I" + arg.get<std::string>() + " ";
211+
include_dirs_args += std::string(_is_windows_os() ? "/I" : "-I") +
212+
arg.get<std::string>() + " ";
176213
}
177214

178215
std::string ldflags_args;
179216
for (auto& arg : compile_options["ldflags"]) {
180-
ldflags_args += _is_windows_os() ? "/" : "-" + arg.get<std::string>() + " ";
217+
ldflags_args += std::string(_is_windows_os() ? "/" : "-") +
218+
arg.get<std::string>() + " ";
181219
}
182220

183221
std::string libraries_dirs_args;
@@ -209,38 +247,48 @@ std::tuple<std::string, std::string> get_cpp_compile_command(
209247
passthrough_parameters_args += arg_str + " ";
210248
}
211249

212-
std::string compile_only_arg =
213-
compile_only ? (_is_windows_os() ? "/c" : "-c") : "";
250+
std::string output_flags = get_output_flags(compile_only);
214251

215252
std::string cmd;
253+
/*
254+
Format command as python frontend cpp_builder:
255+
https://github.com/pytorch/pytorch/blob/3ef1bef36c73b4def0e1b71847e27fde1556c0fb/torch/_inductor/cpp_builder.py#L1780-L1790
256+
https://github.com/pytorch/pytorch/blob/3ef1bef36c73b4def0e1b71847e27fde1556c0fb/torch/_inductor/cpp_builder.py#L1959-L1976
257+
*/
216258
if (_is_windows_os()) {
217-
cmd = normalize_path_separator(fmt::format(
218-
"{} {} {} {} {} {} /LD /Fe{} {} /link {} {} {}",
259+
cmd = fmt::format(
260+
"{} {} {} {} {} {} {}{}",
219261
compiler,
220262
include_dirs_args,
221263
definitions_args,
222264
cflags_args,
223265
source_args,
224266
passthrough_parameters_args,
225-
target_file,
226-
compile_only_arg,
227-
libraries_dirs_args,
228-
libraries_args,
229-
ldflags_args));
267+
output_flags,
268+
target_file);
269+
if (compile_only == false) {
270+
cmd += fmt::format(
271+
" /LD /link {} {} {}",
272+
libraries_dirs_args,
273+
libraries_args,
274+
ldflags_args);
275+
}
276+
cmd = normalize_path_separator(cmd);
230277
} else {
231-
cmd = normalize_path_separator(fmt::format(
232-
"{} {} {} {} {} {} {} {} {} {} -o {}",
278+
cmd = fmt::format(
279+
"{} {} {} {} {} {} {} {}",
233280
compiler,
234281
source_args,
235282
definitions_args,
236283
cflags_args,
237284
include_dirs_args,
238285
passthrough_parameters_args,
239-
ldflags_args,
240-
libraries_args,
241-
libraries_dirs_args,
242-
compile_only_arg,
243-
target_file));
286+
output_flags,
287+
target_file);
288+
if (compile_only == false) {
289+
cmd += fmt::format(
290+
" {} {} {}", ldflags_args, libraries_args, libraries_dirs_args);
291+
}
244292
}
245293

246294
return std::make_tuple(cmd, target_file);
@@ -350,14 +398,15 @@ std::string compile_so(
350398
size_t lastindex = cpp_filename.find_last_of('.');
351399
std::string filename = cpp_filename.substr(0, lastindex);
352400

353-
std::string compile_flags_path = filename + "_compile_flags.json";
401+
std::string compile_flags_path =
402+
normalize_path_separator(filename + "_compile_flags.json");
354403
const nlohmann::json compile_flags = load_json_file(compile_flags_path);
355404

356405
auto [compile_cmd, output_o] =
357406
get_cpp_compile_command(filename, {cpp_filename}, compile_flags);
358407

359-
std::string linker_flags_path =
360-
cpp_filename.substr(0, lastindex) + "_linker_flags.json";
408+
std::string linker_flags_path = normalize_path_separator(
409+
cpp_filename.substr(0, lastindex) + "_linker_flags.json");
361410
const nlohmann::json linker_flags = load_json_file(linker_flags_path);
362411

363412
obj_filenames.push_back(output_o);

0 commit comments

Comments
 (0)