Skip to content

Commit ef0d35f

Browse files
committed
Improve NVML wrapper generation.
1 parent 039b097 commit ef0d35f

File tree

3 files changed

+22
-39
lines changed

3 files changed

+22
-39
lines changed

res/wrap/libnvml_prologue.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,10 @@ end
2626

2727
return
2828
end
29+
30+
macro NVML_STRUCT_VERSION(typename, version)
31+
struct_typename = Symbol("nvml$(String(typename))_v$(version)_t")
32+
struct_type = getfield(__module__, struct_typename)
33+
struct_version = UInt32(sizeof(struct_type)) | (UInt32(version) << 24)
34+
return :($struct_version)
35+
end

res/wrap/nvml.toml

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,7 @@ library_name = "libnvml()"
33
output_file_path = "../../lib/nvml/libnvml.jl"
44
prologue_file_path = "./libnvml_prologue.jl"
55

6-
output_ignorelist = [
7-
# we can't handle NVML_STRUCT_VERSION
8-
"^nvmlMemory_v2$",
9-
"^nvmlGpuInstanceProfileInfo_v2$",
10-
"^nvmlComputeInstanceProfileInfo_v2$",
11-
"^nvmlPowerValue_v2$",
12-
"^nvmlProcessDetailList_v1$",
13-
"^nvmlC2cModeInfo_v1$",
14-
"^nvmlGpuInstanceProfileInfo_v3$",
15-
"^nvmlComputeInstanceProfileInfo_v3$",
16-
"nvmlPciInfoExt_v1",
17-
"nvmlVgpuHeterogeneousMode_v1",
18-
"nvmlVgpuPlacementId_v1",
19-
"nvmlVgpuPlacementList_v1",
20-
"nvmlVgpuInstancesUtilizationInfo_v1",
21-
"nvmlVgpuProcessesUtilizationInfo_v1",
22-
"nvmlProcessesUtilizationInfo_v1",
23-
"nvmlEccSramErrorStatus_v1",
24-
"nvmlSystemConfComputeSettings_v1",
25-
"nvmlConfComputeSetKeyRotationThresholdInfo_v1",
26-
"nvmlConfComputeGetKeyRotationThresholdInfo_v1",
27-
"nvmlGpuFabricInfo_v2",
28-
"NVML_DEVICE_PCI_BUS_ID_LEGACY_FMT",
29-
"NVML_DEVICE_PCI_BUS_ID_FMT",
30-
"nvmlClockOffset_v1",
31-
"nvmlDeviceCapabilities_v1",
32-
"nvmlVgpuTypeBar1Info_v1",
33-
"nvmlSystemDriverBranchInfo_v1",
34-
]
6+
output_ignorelist = []
357

368

379
[codegen]
@@ -42,7 +14,7 @@ always_NUL_terminated_string = true
4214
[api]
4315
checked_rettypes = [ "nvmlReturn_t" ]
4416

45-
[api.nvmlInit_v2]
17+
[api.nvmlInit]
4618
needs_context = false
4719

4820
[api.nvmlInitWithFlags]

res/wrap/wrap.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
#
88
# To update the types of arguments, add `api.<function>.argtypes` to a library's TOML file.
99

10-
# TODO
11-
# - deal with NVML's `NVML_STRUCT_VERSION` (workaround: we ignore these symbols)
12-
1310
using Clang
1411
using Clang.Generators
1512

@@ -126,7 +123,7 @@ function rewriter!(ctx, options)
126123
lhs, rhs = expr.args
127124
if rhs isa Expr && rhs.head == :call
128125
name = string(rhs.args[1])
129-
if endswith(name, "STRUCT_SIZE")
126+
if endswith(name, "STRUCT_SIZE") || endswith(name, "STRUCT_VERSION")
130127
rhs.head = :macrocall
131128
rhs.args[1] = Symbol("@", name)
132129
insert!(rhs.args, 2, nothing)
@@ -182,14 +179,20 @@ function rewriter!(ctx, options)
182179
push!(names, fn[1:end-3])
183180
end
184181

182+
# versioned functions are aliased to the unversioned name, so we can
183+
# reuse the same type rewrites.
184+
if occursin("_v", fn)
185+
push!(names, replace(fn, r"_v\d+" => ""))
186+
end
187+
185188
# look for a template rewrite: many libraries have very similar functions,
186189
# e.g., `cublas[SDHCZ]gemm`, for which we can use the same type rewrites
187190
# registered as `cublas𝕏gemm` template with `T` and `S` placeholders.
188191
for name in copy(names), (typcode,(T,S)) in ["S"=>("Cfloat","Cfloat"),
189-
"D"=>("Cdouble","Cdouble"),
190-
"H"=>("Float16","Float16"),
191-
"C"=>("cuComplex","Cfloat"),
192-
"Z"=>("cuDoubleComplex","Cdouble")]
192+
"D"=>("Cdouble","Cdouble"),
193+
"H"=>("Float16","Float16"),
194+
"C"=>("cuComplex","Cfloat"),
195+
"Z"=>("cuDoubleComplex","Cdouble")]
193196
idx = findfirst(typcode, name)
194197
while idx !== nothing
195198
template_name = name[1:idx.start-1] * "𝕏" * name[idx.stop+1:end]
@@ -265,7 +268,8 @@ function main(name="all")
265268
end
266269

267270
if name == "all" || name == "nvml"
268-
wrap("nvml", ["$cuda/nvml.h"]; include_dirs=[cuda])
271+
wrap("nvml", ["$cuda/nvml.h"]; include_dirs=[cuda],
272+
defines=["NVML_NO_UNVERSIONED_FUNC_DEFS=0"])
269273
end
270274

271275
if name == "all" || name == "cupti"

0 commit comments

Comments
 (0)