@@ -18,6 +18,7 @@ limitations under the License.
18
18
#include < cstddef>
19
19
#include < stdexcept>
20
20
21
+ #include " absl/base/dynamic_annotations.h"
21
22
#include " jaxlib/gpu/gpu_kernel_helpers.h"
22
23
#include " jaxlib/gpu/vendor.h"
23
24
@@ -30,39 +31,45 @@ namespace jax::cuda {
30
31
int CudaRuntimeGetVersion () {
31
32
int version;
32
33
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cudaRuntimeGetVersion (&version)));
34
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
33
35
return version;
34
36
}
35
37
36
38
int CudaDriverGetVersion () {
37
39
int version;
38
40
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cudaDriverGetVersion (&version)));
41
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
39
42
return version;
40
43
}
41
44
42
45
uint32_t CuptiGetVersion () {
43
46
uint32_t version;
44
47
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cuptiGetVersion (&version)));
48
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
45
49
return version;
46
50
}
47
51
48
52
int CufftGetVersion () {
49
53
int version;
50
54
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cufftGetVersion (&version)));
55
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
51
56
return version;
52
57
}
53
58
54
59
int CusolverGetVersion () {
55
60
int version;
56
61
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cusolverGetVersion (&version)));
62
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
57
63
return version;
58
64
}
59
65
60
66
int CublasGetVersion () {
61
67
int version;
62
- // NVIDIA promise that it's safe to parse nullptr as the handle to this
68
+ // NVIDIA promise that it's safe to pass a null pointer as the handle to this
63
69
// function.
64
70
JAX_THROW_IF_ERROR (
65
71
JAX_AS_STATUS (cublasGetVersion (/* handle=*/ nullptr , &version)));
72
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
66
73
return version;
67
74
}
68
75
@@ -73,6 +80,9 @@ int CusparseGetVersion() {
73
80
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cusparseGetProperty (MAJOR_VERSION, &major)));
74
81
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cusparseGetProperty (MINOR_VERSION, &minor)));
75
82
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cusparseGetProperty (PATCH_LEVEL, &patch)));
83
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&major, sizeof major);
84
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&minor, sizeof minor);
85
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&patch, sizeof patch);
76
86
return major * 1000 + minor * 100 + patch;
77
87
}
78
88
size_t CudnnGetVersion () {
@@ -82,6 +92,7 @@ size_t CudnnGetVersion() {
82
92
if (version == 0 ) {
83
93
throw std::runtime_error (" cuDNN not found." );
84
94
}
95
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&version, sizeof version);
85
96
return version;
86
97
}
87
98
int CudaComputeCapability (int device) {
@@ -91,6 +102,8 @@ int CudaComputeCapability(int device) {
91
102
&major, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device)));
92
103
JAX_THROW_IF_ERROR (JAX_AS_STATUS (gpuDeviceGetAttribute (
93
104
&minor, GPU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device)));
105
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&major, sizeof major);
106
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&minor, sizeof minor);
94
107
return major * 10 + minor;
95
108
}
96
109
@@ -99,6 +112,7 @@ int CudaDeviceCount() {
99
112
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cuInit (0 )));
100
113
JAX_THROW_IF_ERROR (JAX_AS_STATUS (cuDeviceGetCount (&device_count)));
101
114
115
+ ABSL_ANNOTATE_MEMORY_IS_INITIALIZED (&device_count, sizeof device_count);
102
116
return device_count;
103
117
}
104
118
0 commit comments