Skip to content

Commit b5eaa25

Browse files
authored
Fix indeterminism bug in LLVM datalayout (#6303)
The data layout depends on the flag `nvptx-short-ptr` but this flag was being set after we set the data layout. Since this is a global this was only affecting the first kernel compiled.
1 parent 6917a7f commit b5eaa25

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

third_party/nvidia/backend/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ def make_llir(self, src, metadata, options, capability):
331331
proc = sm_arch_from_capability(capability)
332332
features = get_features(options, self.target.arch)
333333
triple = 'nvptx64-nvidia-cuda'
334+
nvidia.set_short_ptr()
334335
llvm.attach_datalayout(llvm_mod, triple, proc, features)
335336
nvidia.set_nvvm_reflect_ftz(llvm_mod)
336337

@@ -366,7 +367,7 @@ def make_ptx(self, src, metadata, opt, capability):
366367
triple = 'nvptx64-nvidia-cuda'
367368
proc = sm_arch_from_capability(capability)
368369
features = get_features(opt, self.target.arch)
369-
ret = llvm.translate_to_asm(src, triple, proc, features, ['nvptx-short-ptr'], opt.enable_fp_fusion, False)
370+
ret = llvm.translate_to_asm(src, triple, proc, features, [], opt.enable_fp_fusion, False)
370371
# Find kernel names (there should only be one)
371372
names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret)
372373
assert len(names) == 1

third_party/nvidia/triton_nvidia.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ void init_triton_nvidia(py::module &&m) {
7979
context.loadAllAvailableDialects();
8080
});
8181

82+
// Set short point option, this needs to be set before setting the data
83+
// layout.
84+
m.def("set_short_ptr", []() {
85+
auto options = llvm::cl::getRegisteredOptions();
86+
const char *flag = "nvptx-short-ptr";
87+
auto *shortPtr = static_cast<llvm::cl::opt<bool> *>(options[flag]);
88+
assert(shortPtr);
89+
shortPtr->setValue(true);
90+
});
91+
8292
// TODO: could be done in python if we had a generic interface to set metadata
8393
m.def("set_nvvm_reflect_ftz", [](llvm::Module *mod) {
8494
// please check https://llvm.org/docs/NVPTXUsage.html#reflection-parameters

0 commit comments

Comments
 (0)