Skip to content

Commit edb3814

Browse files
committed
new rocm patch
1 parent 0503b61 commit edb3814

File tree

1 file changed

+260
-49
lines changed

1 file changed

+260
-49
lines changed

patches/xla.patch

Lines changed: 260 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,269 @@
1+
diff --git a/third_party/rocm_device_libs/build_defs.bzl b/third_party/rocm_device_libs/build_defs.bzl
2+
index 845618bd9c..352882efb5 100644
3+
--- a/third_party/rocm_device_libs/build_defs.bzl
4+
+++ b/third_party/rocm_device_libs/build_defs.bzl
5+
@@ -2,110 +2,144 @@
6+
7+
load("@bazel_skylib//lib:paths.bzl", "paths")
8+
9+
-def bitcode_library(
10+
- name,
11+
- srcs = [],
12+
- hdrs = [],
13+
- file_specific_flags = {}):
14+
- """Builds a bitcode library
15+
+BitcodeLibraryInfo = provider(fields = ["bc_file"])
16+
17+
- Args:
18+
- name: Unique name of the build rule.
19+
- srcs: List of source files (*.cl, *.ll).
20+
- hdrs: List of header files (*.h).
21+
- file_specific_flags: Per-file dict of flags to be passed to clang.
22+
- """
23+
- # Takes the CL sources and compiles them into bitcode files.
24+
- # Merges those bitcode files together with any given .ll files into a single bitcode file.
25+
- # Strips unnecessary metadata and forces linkonce local visibility for symbols.
26+
- # Adapted from:
27+
- # https://github.com/ROCm/llvm-project/blob/22ee53fa53edc3a5f25feb08dc840f5b0fc362da/amd/device-libs/cmake/OCL.cmake#L73
28+
+def _bitcode_library_impl(ctx):
29+
+ """Implements a bitcode library rule."""
30+
+ srcs = ctx.files.srcs
31+
+ hdrs = ctx.files.hdrs
32+
33+
- clang_tool = "@llvm-project//clang:clang"
34+
- llvm_link_tool = "@llvm-project//llvm:llvm-link"
35+
- opt_tool = "@llvm-project//llvm:opt"
36+
- prepare_builtins_tool = ":prepare_builtins"
37+
- clang_includes = "@llvm-project//clang:builtin_headers_gen"
38+
+ bc_outputs = []
39+
40+
- # Just for calculating the include path.
41+
- clang_header = "@llvm-project//clang:staging/include/opencl-c.h"
42+
-
43+
- include_paths = dict([(paths.dirname(h), None) for h in hdrs]).keys()
44+
-
45+
- #TODO(rocm): Maybe compute this in cmd not to pass dirs as srcs
46+
- includes = " ".join(["-I$(location {})".format(inc) for inc in include_paths])
47+
- flags = ("-fcolor-diagnostics -Werror -Wno-error=atomic-alignment -x cl -Xclang " +
48+
- "-cl-std=CL2.0 --target=amdgcn-amd-amdhsa -fvisibility=hidden -fomit-frame-pointer " +
49+
- "-Xclang -finclude-default-header -Xclang -fexperimental-strict-floating-point " +
50+
- "-Xclang -fdenormal-fp-math=dynamic -Xclang -Qn " +
51+
- "-nogpulib -cl-no-stdinc -Xclang -mcode-object-version=none")
52+
-
53+
- link_inputs = []
54+
+ include_dirs = dict([(paths.dirname(h.path), None) for h in ctx.files.hdrs]).keys()
55+
56+
+ # Compile .cl files to .bc
57+
for src in srcs:
58+
- filename = paths.basename(src)
59+
- (basename, _, ext) = filename.partition(".")
60+
-
61+
- if (ext == "ll"):
62+
- link_inputs.append(src)
63+
- continue
64+
-
65+
- out = basename + ".bc"
66+
- link_inputs.append(out)
67+
- extra_flags = " ".join(file_specific_flags.get(filename, []))
68+
- native.genrule(
69+
- name = "compile_" + basename,
70+
- srcs = [src] + hdrs + include_paths + [clang_includes, clang_header],
71+
- outs = [out],
72+
- cmd = "$(location {}) -I$$(dirname $(location {})) {} {} {} -emit-llvm -c $(location {}) -o $@".format(
73+
- clang_tool,
74+
- clang_header,
75+
- includes,
76+
- flags,
77+
- extra_flags,
78+
- src,
79+
- ),
80+
- tools = [clang_tool],
81+
- message = "Compiling {} ...".format(filename),
82+
- )
83+
-
84+
- link_message = "Linking {}.bc ...".format(name)
85+
-
86+
- prelink_out = name + ".link0.lib.bc"
87+
- native.genrule(
88+
- name = "prelink_" + name,
89+
- srcs = link_inputs,
90+
- outs = [prelink_out],
91+
- cmd = "$(location {}) $(SRCS) -o $@".format(llvm_link_tool),
92+
- tools = [llvm_link_tool],
93+
- message = link_message,
94+
+ if src.path.endswith(".cl"):
95+
+ out = ctx.actions.declare_file(src.basename + ".bc")
96+
+ bc_outputs.append(out)
97+
+
98+
+ extra_flags = ctx.attr.file_specific_flags.get(src.basename, "")
99+
+ include_flags = ["-I{}".format(dir) for dir in include_dirs]
100+
+ include_flags += ["-I{}".format(ctx.files._clang_header[0].dirname)]
101+
+ include_flags += ["-I{}".format(ctx.files._clang_includes[0].dirname)]
102+
+ args = [
103+
+ "-x",
104+
+ "cl",
105+
+ "--target=amdgcn-amd-amdhsa",
106+
+ "-emit-llvm",
107+
+ "-fcolor-diagnostics",
108+
+ "-Werror",
109+
+ "-Wno-error=atomic-alignment",
110+
+ "-Xclang",
111+
+ "-cl-std=CL2.0",
112+
+ "-fvisibility=hidden",
113+
+ "-fomit-frame-pointer",
114+
+ "-Xclang",
115+
+ "-finclude-default-header",
116+
+ "-Xclang",
117+
+ "-fexperimental-strict-floating-point",
118+
+ "-Xclang",
119+
+ "-fdenormal-fp-math=dynamic",
120+
+ "-Xclang",
121+
+ "-Qn",
122+
+ "-nogpulib",
123+
+ "-cl-no-stdinc",
124+
+ "-Xclang",
125+
+ "-mcode-object-version=none",
126+
+ "-c",
127+
+ ] + include_flags + [src.path, "-o", out.path] + extra_flags.split(" ")
128+
+
129+
+ ctx.actions.run(
130+
+ executable = ctx.executable._clang,
131+
+ inputs = [src] + hdrs + ctx.files._clang_includes + ctx.files._clang_header,
132+
+ outputs = [out],
133+
+ arguments = args,
134+
+ progress_message = "Compiling {} → bitcode".format(src.basename),
135+
+ )
136+
+
137+
+ elif src.path.endswith(".ll"):
138+
+ # Directly include .ll files in linking
139+
+ bc_outputs.append(src)
140+
+
141+
+ # Link all .bc files into one prelinked .bc
142+
+ prelink_out = ctx.actions.declare_file(ctx.label.name + ".link0.lib.bc")
143+
+ ctx.actions.run(
144+
+ executable = ctx.executable._llvm_link,
145+
+ inputs = bc_outputs,
146+
+ outputs = [prelink_out],
147+
+ arguments = [f.path for f in bc_outputs] + ["-o", prelink_out.path],
148+
+ progress_message = "Linking {} bitcode files".format(ctx.label.name),
149+
)
150+
151+
- internalize_out = name + ".lib.bc"
152+
- native.genrule(
153+
- name = "internalize_" + name,
154+
- srcs = [prelink_out],
155+
- outs = [internalize_out],
156+
- cmd = "$(location {}) -internalize -only-needed $< -o $@".format(llvm_link_tool),
157+
- tools = [llvm_link_tool],
158+
- message = link_message,
159+
+ # Internalize symbols (llvm-link + -internalize)
160+
+ internalize_out = ctx.actions.declare_file(ctx.label.name + ".lib.bc")
161+
+ ctx.actions.run(
162+
+ executable = ctx.executable._llvm_link,
163+
+ inputs = [prelink_out],
164+
+ outputs = [internalize_out],
165+
+ arguments = ["-internalize", "-only-needed", prelink_out.path, "-o", internalize_out.path],
166+
+ progress_message = "Internalizing symbols for {}".format(ctx.label.name),
167+
)
168+
169+
- strip_out = name + ".strip.bc"
170+
- native.genrule(
171+
- name = "strip_" + name,
172+
- srcs = [internalize_out],
173+
- outs = [strip_out],
174+
- cmd = "$(location {}) -passes=strip -o $@ $<".format(opt_tool),
175+
- tools = [opt_tool],
176+
- message = link_message,
177+
+ # Strip unnecessary metadata
178+
+ strip_out = ctx.actions.declare_file(ctx.label.name + ".strip.bc")
179+
+ ctx.actions.run(
180+
+ executable = ctx.executable._opt,
181+
+ inputs = [internalize_out],
182+
+ outputs = [strip_out],
183+
+ arguments = ["-passes=strip", "-o", strip_out.path, internalize_out.path],
184+
+ progress_message = "Stripping {}".format(ctx.label.name),
185+
)
186+
187+
- native.genrule(
188+
- name = name,
189+
- srcs = [strip_out],
190+
- outs = [name + ".bc"],
191+
- cmd = "$(location {}) -o $@ $<".format(prepare_builtins_tool),
192+
- tools = [prepare_builtins_tool],
193+
- message = link_message,
194+
+ # Final preparation of bitcode (custom prepare_builtins tool)
195+
+ final_bc = ctx.actions.declare_file(ctx.label.name + ".bc")
196+
+ ctx.actions.run(
197+
+ executable = ctx.executable._prepare_builtins,
198+
+ inputs = [strip_out],
199+
+ outputs = [final_bc],
200+
+ arguments = [strip_out.path, "-o", final_bc.path],
201+
+ progress_message = "Preparing final bitcode for {}".format(ctx.label.name),
202+
)
203+
+
204+
+ return [
205+
+ DefaultInfo(files = depset([final_bc])),
206+
+ BitcodeLibraryInfo(bc_file = final_bc),
207+
+ ]
208+
+
209+
+bitcode_library = rule(
210+
+ implementation = _bitcode_library_impl,
211+
+ attrs = {
212+
+ "srcs": attr.label_list(allow_files = [".cl", ".ll"]),
213+
+ "hdrs": attr.label_list(allow_files = [".h"]),
214+
+ "file_specific_flags": attr.string_dict(),
215+
+ "_clang": attr.label(
216+
+ default = Label("@llvm-project//clang:clang"),
217+
+ executable = True,
218+
+ cfg = "exec",
219+
+ ),
220+
+ "_llvm_link": attr.label(
221+
+ default = Label("@llvm-project//llvm:llvm-link"),
222+
+ executable = True,
223+
+ cfg = "exec",
224+
+ ),
225+
+ "_opt": attr.label(
226+
+ default = Label("@llvm-project//llvm:opt"),
227+
+ executable = True,
228+
+ cfg = "exec",
229+
+ ),
230+
+ "_prepare_builtins": attr.label(
231+
+ default = Label(":prepare_builtins"),
232+
+ executable = True,
233+
+ cfg = "exec",
234+
+ ),
235+
+ "_clang_includes": attr.label(
236+
+ default = Label("@llvm-project//clang:builtin_headers_gen"),
237+
+ allow_files = True,
238+
+ ),
239+
+ "_clang_header": attr.label(
240+
+ default = Label("@llvm-project//clang:staging/include/opencl-c.h"),
241+
+ allow_files = True,
242+
+ ),
243+
+ },
244+
+)
1245
diff --git a/third_party/rocm_device_libs/rocm_device_libs.BUILD b/third_party/rocm_device_libs/rocm_device_libs.BUILD
2-
index 11795b3537..c6e953d577 100644
246+
index 11795b3537..966c3605ce 100644
3247
--- a/third_party/rocm_device_libs/rocm_device_libs.BUILD
4248
+++ b/third_party/rocm_device_libs/rocm_device_libs.BUILD
5-
@@ -28,35 +28,24 @@ cc_binary(
249+
@@ -39,9 +39,9 @@ bitcode_library(
250+
"oclc/inc/*.h",
251+
]),
252+
file_specific_flags = {
253+
- "native_logF.cl": ["-fapprox-func"],
254+
- "native_expF.cl": ["-fapprox-func"],
255+
- "sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
256+
+ "native_logF.cl": "-fapprox-func",
257+
+ "native_expF.cl": "-fapprox-func",
258+
+ "sqrtF.cl": "-cl-fp32-correctly-rounded-divide-sqrt",
259+
},
6260
)
7261

8-
bitcode_library(
9-
- name = "ocml",
10-
+ name = "ockl",
11-
srcs = glob([
12-
"ocml/src/*.cl",
13-
+ "ockl/src/*.cl",
14-
+ "ockl/src/*.ll",
15-
]),
16-
hdrs = glob([
17-
"ocml/src/*.h",
18-
"ocml/inc/*.h",
19-
"irif/inc/*.h",
20-
- "oclc/inc/*.h",
21-
+ "oclc/inc/*.h",
22-
+ "ockl/inc/*.h",
262+
@@ -57,6 +57,6 @@ bitcode_library(
263+
"oclc/inc/*.h",
23264
]),
24265
file_specific_flags = {
25-
"native_logF.cl": ["-fapprox-func"],
26-
"native_expF.cl": ["-fapprox-func"],
27-
"sqrtF.cl": ["-cl-fp32-correctly-rounded-divide-sqrt"],
28-
- },
29-
-)
30-
-
31-
-bitcode_library(
32-
- name = "ockl",
33-
- srcs = glob([
34-
- "ockl/src/*.cl",
35-
- "ockl/src/*.ll",
36-
- ]),
37-
- hdrs = glob([
38-
- "ockl/inc/*.h",
39-
- "irif/inc/*.h",
40-
- "oclc/inc/*.h",
41-
- ]),
42-
- file_specific_flags = {
43-
"gaaf.cl": ["-munsafe-fp-atomics"],
266+
- "gaaf.cl": ["-munsafe-fp-atomics"],
267+
+ "gaaf.cl": "-munsafe-fp-atomics",
44268
},
45269
)
46-
+
47-
diff --git a/xla/service/gpu/llvm_gpu_backend/BUILD b/xla/service/gpu/llvm_gpu_backend/BUILD
48-
index 101b0580e1..1df2b6accb 100644
49-
--- a/xla/service/gpu/llvm_gpu_backend/BUILD
50-
+++ b/xla/service/gpu/llvm_gpu_backend/BUILD
51-
@@ -130,7 +130,6 @@ genrule(
52-
name = "generate_amdgpu_device_lib_data",
53-
srcs = [
54-
"@rocm_device_libs//:ockl",
55-
- "@rocm_device_libs//:ocml",
56-
],
57-
outs = ["amdgpu_device_lib_data.h"],
58-
cmd = "$(location {}) --llvm_link_bin $(location {}) $(SRCS) -o $@ --cpp_identifier=kAMDGPUDeviceLibData".format(

0 commit comments

Comments
 (0)