@@ -55,9 +55,22 @@ def build_ArmComputeLibrary() -> None:
5555 shutil .copytree (f"{ acl_checkout_dir } /{ d } " , f"{ acl_install_dir } /{ d } " )
5656
5757
58- def update_wheel (wheel_path , desired_cuda ) -> None :
58+ def replace_tag (filename ) -> None :
59+ with open (filename ) as f :
60+ lines = f .readlines ()
61+ for i , line in enumerate (lines ):
62+ if line .startswith ("Tag:" ):
63+ lines [i ] = line .replace ("-linux_" , "-manylinux_2_28_" )
64+ print (f"Updated tag from { line } to { lines [i ]} " )
65+ break
66+
67+ with open (filename , "w" ) as f :
68+ f .writelines (lines )
69+
70+
71+ def package_cuda_wheel (wheel_path , desired_cuda ) -> None :
5972 """
60- Update the cuda wheel libraries
73+ Package the cuda wheel libraries
6174 """
6275 folder = os .path .dirname (wheel_path )
6376 wheelname = os .path .basename (wheel_path )
@@ -88,30 +101,19 @@ def update_wheel(wheel_path, desired_cuda) -> None:
88101 "/usr/lib64/libgfortran.so.5" ,
89102 "/acl/build/libarm_compute.so" ,
90103 "/acl/build/libarm_compute_graph.so" ,
104+ "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0" ,
105+ "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ,
106+ "/usr/local/lib/libnvpl_lapack_core.so.0" ,
107+ "/usr/local/lib/libnvpl_blas_core.so.0" ,
91108 ]
92- if enable_cuda :
93- libs_to_copy += [
94- "/usr/local/lib/libnvpl_lapack_lp64_gomp.so.0" ,
95- "/usr/local/lib/libnvpl_blas_lp64_gomp.so.0" ,
96- "/usr/local/lib/libnvpl_lapack_core.so.0" ,
97- "/usr/local/lib/libnvpl_blas_core.so.0" ,
98- ]
99- if "126" in desired_cuda :
100- libs_to_copy += [
101- "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.6" ,
102- "/usr/local/cuda/lib64/libcufile.so.0" ,
103- "/usr/local/cuda/lib64/libcufile_rdma.so.1" ,
104- ]
105- elif "128" in desired_cuda :
106- libs_to_copy += [
107- "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8" ,
108- "/usr/local/cuda/lib64/libcufile.so.0" ,
109- "/usr/local/cuda/lib64/libcufile_rdma.so.1" ,
110- ]
111- else :
109+
110+ if "128" in desired_cuda :
112111 libs_to_copy += [
113- "/opt/OpenBLAS/lib/libopenblas.so.0" ,
112+ "/usr/local/cuda/lib64/libnvrtc-builtins.so.12.8" ,
113+ "/usr/local/cuda/lib64/libcufile.so.0" ,
114+ "/usr/local/cuda/lib64/libcufile_rdma.so.1" ,
114115 ]
116+
115117 # Copy libraries to unzipped_folder/a/lib
116118 for lib_path in libs_to_copy :
117119 lib_name = os .path .basename (lib_path )
@@ -120,6 +122,13 @@ def update_wheel(wheel_path, desired_cuda) -> None:
120122 f"cd { folder } /tmp/torch/lib/; "
121123 f"patchelf --set-rpath '$ORIGIN' --force-rpath { folder } /tmp/torch/lib/{ lib_name } "
122124 )
125+
126+ # Make sure the wheel is tagged with manylinux_2_28
127+ for f in os .scandir (f"{ folder } /tmp/" ):
128+ if f .is_dir () and f .name .endswith (".dist-info" ):
129+ replace_tag (f"{ f .path } /WHEEL" )
130+ break
131+
123132 os .mkdir (f"{ folder } /cuda_wheel" )
124133 os .system (f"cd { folder } /tmp/; zip -r { folder } /cuda_wheel/{ wheelname } *" )
125134 shutil .move (
@@ -242,6 +251,6 @@ def parse_arguments():
242251 print ("Updating Cuda Dependency" )
243252 filename = os .listdir ("/pytorch/dist/" )
244253 wheel_path = f"/pytorch/dist/{ filename [0 ]} "
245- update_wheel (wheel_path , desired_cuda )
254+ package_cuda_wheel (wheel_path , desired_cuda )
246255 pytorch_wheel_name = complete_wheel ("/pytorch/" )
247256 print (f"Build Complete. Created { pytorch_wheel_name } .." )
0 commit comments