Skip to content

Commit 769d5da

Browse files
pytorchbotatalman
andauthored
[binary builds] Linux aarch64 CUDA builds. Make sure tag is set correctly (pytorch#154136)
[binary builds] Linux aarch64 CUDA builds. Make sure tag is set correctly (pytorch#154045) 1. This should set the Manylinux 2.28 tag correctly for CUDA Aarch builds. I believe we used to have something similar in the old script: https://github.com/pytorch/pytorch/blob/main/.ci/aarch64_linux/build_aarch64_wheel.py#L811 ``Tag: cp311-cp311-linux_aarch64 ``-> ``Tag: cp311-cp311-manylinux_2_28_aarch64`` 2. Remove section for CUDA 12.6, since we no longer building CUDA 12.6 aarch64 builds Pull Request resolved: pytorch#154045 Approved by: https://github.com/Camyll, https://github.com/malfet (cherry picked from commit 4277907) Co-authored-by: atalman <[email protected]>
1 parent 306ba12 commit 769d5da

File tree

1 file changed

+33
-24
lines changed

1 file changed

+33
-24
lines changed

.ci/aarch64_linux/aarch64_wheel_ci_build.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,22 @@ def build_ArmComputeLibrary() -> None:
5555
shutil.copytree(f"{acl_checkout_dir}/{d}", f"{acl_install_dir}/{d}")
5656

5757

58-
def update_wheel(wheel_path, desired_cuda) -> None:
58+
def replace_tag(filename) -> None:
59+
with open(filename) as f:
60+
lines = f.readlines()
61+
for i, line in enumerate(lines):
62+
if line.startswith("Tag:"):
63+
lines[i] = line.replace("-linux_", "-manylinux_2_28_")
64+
print(f"Updated tag from {line} to {lines[i]}")
65+
break
66+
67+
with open(filename, "w") as f:
68+
f.writelines(lines)
69+
70+
71+
def package_cuda_wheel(wheel_path, desired_cuda) -> None:
5972
"""
60-
Update the cuda wheel libraries
73+
Package the cuda wheel libraries
6174
"""
6275
folder = os.path.dirname(wheel_path)
6376
wheelname = os.path.basename(wheel_path)
@@ -88,30 +101,19 @@ def update_wheel(wheel_path, desired_cuda) -> None:
88101
"/usr/lib64/libgfortran.so.5",
89102
"/acl/build/libarm_compute.so",
90103
"/acl/build/libarm_compute_graph.so",
104+
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
105+
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
106+
"/usr/local/lib/libnvpl_lapack_core.so.0",
107+
"/usr/local/lib/libnvpl_blas_core.so.0",
91108
]
92-
if enable_cuda:
93-
libs_to_copy += [
94-
"/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0",
95-
"/usr/local/lib/libnvpl_blas_lp64_gomp.so.0",
96-
"/usr/local/lib/libnvpl_lapack_core.so.0",
97-
"/usr/local/lib/libnvpl_blas_core.so.0",
98-
]
99-
if "126" in desired_cuda:
100-
libs_to_copy += [
101-
"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.6",
102-
"/usr/local/cuda/lib64/libcufile.so.0",
103-
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
104-
]
105-
elif "128" in desired_cuda:
106-
libs_to_copy += [
107-
"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8",
108-
"/usr/local/cuda/lib64/libcufile.so.0",
109-
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
110-
]
111-
else:
109+
110+
if "128" in desired_cuda:
112111
libs_to_copy += [
113-
"/opt/OpenBLAS/lib/libopenblas.so.0",
112+
"/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8",
113+
"/usr/local/cuda/lib64/libcufile.so.0",
114+
"/usr/local/cuda/lib64/libcufile_rdma.so.1",
114115
]
116+
115117
# Copy libraries to unzipped_folder/a/lib
116118
for lib_path in libs_to_copy:
117119
lib_name = os.path.basename(lib_path)
@@ -120,6 +122,13 @@ def update_wheel(wheel_path, desired_cuda) -> None:
120122
f"cd {folder}/tmp/torch/lib/; "
121123
f"patchelf --set-rpath '$ORIGIN' --force-rpath {folder}/tmp/torch/lib/{lib_name}"
122124
)
125+
126+
# Make sure the wheel is tagged with manylinux_2_28
127+
for f in os.scandir(f"{folder}/tmp/"):
128+
if f.is_dir() and f.name.endswith(".dist-info"):
129+
replace_tag(f"{f.path}/WHEEL")
130+
break
131+
123132
os.mkdir(f"{folder}/cuda_wheel")
124133
os.system(f"cd {folder}/tmp/; zip -r {folder}/cuda_wheel/{wheelname} *")
125134
shutil.move(
@@ -242,6 +251,6 @@ def parse_arguments():
242251
print("Updating Cuda Dependency")
243252
filename = os.listdir("/pytorch/dist/")
244253
wheel_path = f"/pytorch/dist/{filename[0]}"
245-
update_wheel(wheel_path, desired_cuda)
254+
package_cuda_wheel(wheel_path, desired_cuda)
246255
pytorch_wheel_name = complete_wheel("/pytorch/")
247256
print(f"Build Complete. Created {pytorch_wheel_name}..")

0 commit comments

Comments
 (0)