|
7 | 7 | #
|
8 | 8 | # To update the types of arguments, add `api.<function>.argtypes` to a library's TOML file.
|
9 | 9 |
|
10 |
| -# TODO |
11 |
| -# - deal with NVML's `NVML_STRUCT_VERSION` (workaround: we ignore these symbols) |
12 |
| - |
13 | 10 | using Clang
|
14 | 11 | using Clang.Generators
|
15 | 12 |
|
@@ -126,7 +123,7 @@ function rewriter!(ctx, options)
|
126 | 123 | lhs, rhs = expr.args
|
127 | 124 | if rhs isa Expr && rhs.head == :call
|
128 | 125 | name = string(rhs.args[1])
|
129 |
| - if endswith(name, "STRUCT_SIZE") |
| 126 | + if endswith(name, "STRUCT_SIZE") || endswith(name, "STRUCT_VERSION") |
130 | 127 | rhs.head = :macrocall
|
131 | 128 | rhs.args[1] = Symbol("@", name)
|
132 | 129 | insert!(rhs.args, 2, nothing)
|
@@ -182,14 +179,20 @@ function rewriter!(ctx, options)
|
182 | 179 | push!(names, fn[1:end-3])
|
183 | 180 | end
|
184 | 181 |
|
| 182 | + # versioned functions (e.g. `foo_v2`) are aliased to the unversioned name |
| 183 | + # (`foo`), so also look up type rewrites under the name with `_v<N>` removed. |
| 184 | + if occursin("_v", fn) |
| 185 | + push!(names, replace(fn, r"_v\d+" => "")) |
| 186 | + end |
| 187 | + |
185 | 188 | # look for a template rewrite: many libraries have very similar functions,
|
186 | 189 | # e.g., `cublas[SDHCZ]gemm`, for which we can use the same type rewrites
|
187 | 190 | # registered as `cublas𝕏gemm` template with `T` and `S` placeholders.
|
188 | 191 | for name in copy(names), (typcode,(T,S)) in ["S"=>("Cfloat","Cfloat"),
|
189 |
| - "D"=>("Cdouble","Cdouble"), |
190 |
| - "H"=>("Float16","Float16"), |
191 |
| - "C"=>("cuComplex","Cfloat"), |
192 |
| - "Z"=>("cuDoubleComplex","Cdouble")] |
| 192 | + "D"=>("Cdouble","Cdouble"), |
| 193 | + "H"=>("Float16","Float16"), |
| 194 | + "C"=>("cuComplex","Cfloat"), |
| 195 | + "Z"=>("cuDoubleComplex","Cdouble")] |
193 | 196 | idx = findfirst(typcode, name)
|
194 | 197 | while idx !== nothing
|
195 | 198 | template_name = name[1:idx.start-1] * "𝕏" * name[idx.stop+1:end]
|
@@ -265,7 +268,8 @@ function main(name="all")
|
265 | 268 | end
|
266 | 269 |
|
267 | 270 | if name == "all" || name == "nvml"
|
268 |
| - wrap("nvml", ["$cuda/nvml.h"]; include_dirs=[cuda]) |
| 271 | + wrap("nvml", ["$cuda/nvml.h"]; include_dirs=[cuda], |
| 272 | + defines=["NVML_NO_UNVERSIONED_FUNC_DEFS=0"]) |
269 | 273 | end
|
270 | 274 |
|
271 | 275 | if name == "all" || name == "cupti"
|
|
0 commit comments