Skip to content

Commit cd93b46

Browse files
tkoeppejax authors
authored andcommitted
Add initialization annotations (for the benefit of MSAN) to variables that are initialized by external functions.
PiperOrigin-RevId: 641879836
1 parent 991797a commit cd93b46

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

jaxlib/cuda/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -513,15 +513,14 @@ cc_library(
513513
deps = [
514514
":cuda_gpu_kernel_helpers",
515515
":cuda_vendor",
516-
"//jaxlib:absl_status_casters",
517-
"//jaxlib:kernel_nanobind_helpers",
518516
"@xla//xla/tsl/cuda:cublas",
519517
"@xla//xla/tsl/cuda:cudart",
520518
"@xla//xla/tsl/cuda:cudnn",
521519
"@xla//xla/tsl/cuda:cufft",
522520
"@xla//xla/tsl/cuda:cupti",
523521
"@xla//xla/tsl/cuda:cusolver",
524522
"@xla//xla/tsl/cuda:cusparse",
523+
"@com_google_absl//absl/base:dynamic_annotations",
525524
],
526525
)
527526

jaxlib/cuda/versions_helpers.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ limitations under the License.
1818
#include <cstddef>
1919
#include <stdexcept>
2020

21+
#include "absl/base/dynamic_annotations.h"
2122
#include "jaxlib/gpu/gpu_kernel_helpers.h"
2223
#include "jaxlib/gpu/vendor.h"
2324

@@ -30,39 +31,45 @@ namespace jax::cuda {
3031
int CudaRuntimeGetVersion() {
3132
int version;
3233
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaRuntimeGetVersion(&version)));
34+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
3335
return version;
3436
}
3537

3638
int CudaDriverGetVersion() {
3739
int version;
3840
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cudaDriverGetVersion(&version)));
41+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
3942
return version;
4043
}
4144

4245
uint32_t CuptiGetVersion() {
4346
uint32_t version;
4447
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuptiGetVersion(&version)));
48+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
4549
return version;
4650
}
4751

4852
int CufftGetVersion() {
4953
int version;
5054
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cufftGetVersion(&version)));
55+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
5156
return version;
5257
}
5358

5459
int CusolverGetVersion() {
5560
int version;
5661
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusolverGetVersion(&version)));
62+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
5763
return version;
5864
}
5965

6066
int CublasGetVersion() {
6167
int version;
62-
// NVIDIA promise that it's safe to parse nullptr as the handle to this
68+
// NVIDIA promise that it's safe to pass a null pointer as the handle to this
6369
// function.
6470
JAX_THROW_IF_ERROR(
6571
JAX_AS_STATUS(cublasGetVersion(/*handle=*/nullptr, &version)));
72+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
6673
return version;
6774
}
6875

@@ -73,6 +80,9 @@ int CusparseGetVersion() {
7380
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MAJOR_VERSION, &major)));
7481
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(MINOR_VERSION, &minor)));
7582
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cusparseGetProperty(PATCH_LEVEL, &patch)));
83+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
84+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
85+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&patch, sizeof patch);
7686
return major * 1000 + minor * 100 + patch;
7787
}
7888
size_t CudnnGetVersion() {
@@ -82,6 +92,7 @@ size_t CudnnGetVersion() {
8292
if (version == 0) {
8393
throw std::runtime_error("cuDNN not found.");
8494
}
95+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&version, sizeof version);
8596
return version;
8697
}
8798
int CudaComputeCapability(int device) {
@@ -91,6 +102,8 @@ int CudaComputeCapability(int device) {
91102
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
92103
JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpuDeviceGetAttribute(
93104
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
105+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&major, sizeof major);
106+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&minor, sizeof minor);
94107
return major * 10 + minor;
95108
}
96109

@@ -99,6 +112,7 @@ int CudaDeviceCount() {
99112
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuInit(0)));
100113
JAX_THROW_IF_ERROR(JAX_AS_STATUS(cuDeviceGetCount(&device_count)));
101114

115+
ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(&device_count, sizeof device_count);
102116
return device_count;
103117
}
104118

0 commit comments

Comments
 (0)