Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 7812347

Browse files
committed
feat: oneDNN wrapper based on oneDNN_jll
1 parent 40d9192 commit 7812347

File tree

10 files changed

+8046
-1
lines changed

10 files changed

+8046
-1
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,5 @@ scripts
1212
test_ext
1313

1414
benchmarks/results
15+
16+
deps

.typos.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ numer = "numer"
33
nd = "nd"
44
Ba = "Ba"
55
skipt = "skipt"
6+
abd = "abd"

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.2.0"
4+
version = "1.3.0"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
8+
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
89
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1011
CpuId = "adafc99b-e345-5852-983c-f28acb93d879"
@@ -29,6 +30,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
2930
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
3031
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
3132
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"
33+
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"
3234

3335
[weakdeps]
3436
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
@@ -55,6 +57,7 @@ AMDGPU = "0.9.6, 1"
5557
AppleAccelerate = "0.4"
5658
ArrayInterface = "7.9"
5759
BLISBLAS = "0.1"
60+
CEnum = "0.5.0"
5861
CUDA = "5.3.2"
5962
ChainRulesCore = "1.24"
6063
Compat = "4.15.0"
@@ -85,3 +88,4 @@ Tracker = "0.2.34"
8588
UnrolledUtilities = "0.1.2"
8689
cuDNN = "1.3"
8790
julia = "1.10"
91+
oneDNN_jll = "3.5.3"

generators/Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
[deps]
2+
Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
3+
oneDNN_jll = "3523a63d-8698-5b6f-b2c2-68eaa6bde0f0"
4+
5+
[compat]
6+
Clang = "0.18"
7+
oneDNN_jll = "3.5.3"

generators/generator.toml

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
[general]
2+
# it could also be an expression as long as `Meta.parse` can parse this string successfully.
3+
# basically, it should be the `expression` in the following code:
4+
# ccall((function_name, expression), returntype, (argtype1, ...), argvalue1, ...)
5+
library_name = "libdnnl"
6+
7+
# this entry allows you to specify different library names for different headers.
8+
# in the following example:
9+
# library_names = {"config.h" = "libclang_config", "libclang_p.*.h" = "libclang_patch"}
10+
# those functions in the `config.h` will be generated as:
11+
# ccall((function_name, libclang_config), returntype, (argtype1, ...), argvalue1, ...)
12+
library_names = {}
13+
14+
# output file path relative to the working directory
15+
output_file_path = "../src/onednn/lib.jl"
16+
17+
# if these are set, common file (types and constants) and API file (functions) will be separated
18+
# this is for compatibility, so prologue and epilogue are not supported.
19+
# output_api_file_path = "api.jl"
20+
# output_common_file_path = "common.jl"
21+
22+
# if this entry is not empty, the generator will print the code below to the `output_file_path`.
23+
# module module_name
24+
#
25+
# end # module
26+
module_name = "Lib"
27+
28+
# if this entry is not empty, the generator will print the code below to the `output_file_path`.
29+
# using jll_pkg_name
30+
# export jll_pkg_name
31+
jll_pkg_name = "oneDNN_jll"
32+
33+
# for packages that have extra JLL package dependencies
34+
jll_pkg_extra = []
35+
36+
# identifiers that starts with the string listed in this entry will be exported.
37+
export_symbol_prefixes = ["CX", "clang_"]
38+
39+
# the code in the following file will be copy-pasted to `output_file_path` before the generated code.
40+
# this is often used for applying custom patches, e.g. adding missing definitions.
41+
prologue_file_path = "prologue.jl"
42+
43+
# the code in the following file will be copy-pasted to `output_file_path` after the generated code.
44+
# this is often used for applying custom patches.
45+
epilogue_file_path = ""
46+
47+
# node with an id in the `output_ignorelist` will be ignored in the printing passes.
48+
# this is very useful for custom editing.
49+
output_ignorelist = [
50+
"CINDEX_EXPORTS",
51+
"CINDEX_VERSION",
52+
"CINDEX_VERSION_STRING",
53+
"CINDEX_LINKAGE",
54+
"CINDEX_DEPRECATED",
55+
"LLVM_CLANG_C_STRICT_PROTOTYPES_BEGIN",
56+
"LLVM_CLANG_C_STRICT_PROTOTYPES_END",
57+
"LLVM_CLANG_C_EXTERN_C_BEGIN",
58+
"LLVM_CLANG_C_EXTERN_C_END"
59+
]
60+
61+
# Julia's `@enum` do not allow duplicated values, so by default, C enums are translated to
62+
# CEnum.jl's `@cenum`.
63+
# if this entry is true, `@enum` is used and those duplicated enum constants are just commented.
64+
use_julia_native_enum_type = false
65+
66+
# use `@cenum` but do not print `using CEnum`.
67+
# this is useful in the case of using `CEnum` directly in the source tree instead of using `CEnum` as a dependency
68+
print_using_CEnum = false
69+
70+
# Print enums directly as integers without @(c)enum wrapper
71+
# Override above two options
72+
print_enum_as_integer = false
73+
74+
# use deterministic symbol instead of `gensym`-generated `var"##XXX"`
75+
use_deterministic_symbol = true
76+
77+
# by default, only those declarations in the local header file are processed.
78+
# those declarations in the system headers will be treated specially and will be generated if necessary.
79+
# if you'd like to generate all of the symbols in the system headers, please set this option to false.
80+
is_local_header_only = true
81+
82+
# set this option to false if you'd like to ignore the symbols(even if necessary) in the system headers.
83+
generate_isystem_symbols = true
84+
85+
# if this option is set to true, C code with a style of
86+
# ```c
87+
# typedef struct {
88+
# int x;
89+
# } my_struct;
90+
# ```
91+
# will be generated as:
92+
# ```julia
93+
# struct my_struct
94+
# x::Cint
95+
# end
96+
# ```
97+
# instead of
98+
# ```julia
99+
# struct var"##Ctag#NUM"
100+
# x::Cint
101+
# end
102+
# const my_struct = var"##Ctag#NUM"
103+
# ```
104+
smart_de_anonymize = true
105+
106+
# if set to true, static functions will be ignored
107+
skip_static_functions = false
108+
109+
# EXPERIMENTAL
110+
# if this option is set to true, those structs that are not necessary to be an
111+
# immutable struct will be generated as a mutable struct.
112+
# this option is default to false, do read the paragraph below before using this feature.
113+
auto_mutability = false
114+
115+
# add inner constructor `Foo() = new()`
116+
auto_mutability_with_new = true
117+
118+
# if you feel like certain structs should not be generated as mutable struct, please add them in the following list.
119+
# for example, if a C function accepts a `Vector` of some type as its argument like:
120+
# void foo(mutable_type *list, int n);
121+
# when calling this function via `ccall`, passing a `Vector{mutable_type}(undef, n)` to the first
122+
# argument will trigger a crash, the reason is mutable structs are not stored inline within a `Vector`,
123+
# one should use `Ref{NTuple{n,mutable_type}}()` instead.
124+
# this is not convenient and that's where the `auto_mutability_ignorelist` comes in.
125+
auto_mutability_ignorelist = []
126+
127+
# opposite to `auto_mutability_ignorelist` and has a higher priority
128+
auto_mutability_includelist = []
129+
130+
# if set to "raw", extract and dump raw c comment;
131+
# if set to "doxygen", parse and format doxygen comment.
132+
# note: by default, Clang only parses doxygen comment, pass `-fparse-all-comments` to Clang in order to parse non-doxygen comments.
133+
extract_c_comment_style = "doxygen"
134+
135+
# Pass a function to explicitly generate documentation. It will be called like
136+
# `callback_documentation(node::ExprNode, doc::Vector{String})` if it is
137+
# set. The `doc` argument will contain the docs parsed from the headers if
138+
# `extract_c_comment_style` is set, otherwise it will be an empty vector.
139+
#
140+
# Do *not* set this in the TOML file, it should be set in the generator script
141+
# to a function that takes in an ExprNode and returns a String[] (string
142+
# vector).
143+
# callback_documentation = ""
144+
145+
# if set to true, single line comment will be printed as """comment""" instead of """\ncomment\n"""
146+
fold_single_line_comment = false
147+
148+
# if set to "outofline", documentation of struct fields will be collected at the "Fields" section of the struct
149+
# if set to "inline", documentation of struct fields will go right above struct definition
150+
struct_field_comment_style = "outofline"
151+
152+
# if set to "outofline", documentation of enumerators will be collected at the "Enumerators" section of the enum
153+
enumerator_comment_style = "outofline"
154+
155+
# if set to true, C function prototype will be included in documentation
156+
show_c_function_prototype = false
157+
158+
[codegen]
159+
# map C's bool to Julia's Bool instead of `Cuchar` a.k.a `UInt8`.
160+
use_julia_bool = true
161+
162+
# set this to true if the C routine always expects a NUL-terminated string.
163+
# TODO: support filtering
164+
always_NUL_terminated_string = true
165+
166+
# generate strictly typed function
167+
is_function_strictly_typed = false
168+
169+
# if true, opaque pointers in function arguments will be translated to `Ptr{Cvoid}`.
170+
opaque_func_arg_as_PtrCvoid = false
171+
172+
# if true, opaque types are translated to `mutable struct` instead of `Cvoid`.
173+
opaque_as_mutable_struct = true
174+
175+
# if true, use Julia 1.5's new `@ccall` macro
176+
use_ccall_macro = true
177+
178+
# if true, variadic functions are wrapped with `@ccall` macro. Otherwise variadic functions are ignored.
179+
wrap_variadic_function = false
180+
181+
# generate getproperty/setproperty! methods for the types in the following list
182+
field_access_method_list = []
183+
184+
# the generator will prefix the function argument names in the following list with a "_" to
185+
# prevent the generated symbols from conflicting with the symbols defined and exported in Base.
186+
function_argument_conflict_symbols = []
187+
188+
# emit constructors for all custom-layout structs like bitfield in the list,
189+
# or set to `true` to do so for all such structs
190+
add_record_constructors = []
191+
192+
[codegen.macro]
193+
# it‘s highly recommended to set this entry to "basic".
194+
# if you'd like to skip all of the macros, please set this entry to "disable".
195+
# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive".
196+
macro_mode = "basic"
197+
198+
# function-like macros in the following list will always be translated.
199+
functionlike_macro_includelist = [
200+
"CINDEX_VERSION_ENCODE",
201+
]
202+
203+
# if true, the generator prints the following message as comments.
204+
# "# Skipping MacroDefinition: ..."
205+
add_comment_for_skipped_macro = true
206+
207+
# if true, ignore any macros that is suffixed with "_H" or in the `ignore_header_guards_with_suffixes` list
208+
ignore_header_guards = true
209+
ignore_header_guards_with_suffixes = []
210+
211+
# if true, ignore those pure definition macros in the C code
212+
ignore_pure_definition = true

generators/prologue.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
using CEnum: @cenum
2+
3+
const NULL = C_NULL
4+
5+
# This file is automatically generated by Clang.jl. Don't edit it manually. If needed,
6+
# look at the "generators/" directory and modify the relevant files there.

generators/wrap.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using Clang.Generators
2+
using oneDNN_jll
3+
4+
cur_dir = pwd()
5+
6+
cd(@__DIR__)
7+
8+
include_dir = joinpath(oneDNN_jll.artifact_dir, "include")
9+
10+
options = load_options(joinpath(@__DIR__, "generator.toml"))
11+
12+
onednn_headers = [
13+
joinpath(include_dir, "dnnl.h"),
14+
joinpath(include_dir, "dnnl_types.h"),
15+
joinpath(include_dir, "dnnl_config.h"),
16+
joinpath(include_dir, "dnnl_version.h")
17+
]
18+
19+
args = get_default_args()
20+
push!(args, "-I$include_dir")
21+
22+
ctx = create_context(onednn_headers, args, options)
23+
24+
# run generator
25+
build!(ctx, BUILDSTAGE_NO_PRINTING)
26+
27+
function rewrite!(e::Expr)
28+
# const DNNL_RUNTIME_SIZE_VAL = size_t(DNNL_RUNTIME_DIM_VAL)
29+
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
30+
e.args[1].args[1] == :DNNL_RUNTIME_SIZE_VAL && e.args[1].args[2] isa Expr &&
31+
e.args[1].args[2].head == :call && e.args[1].args[2].args[1] == :size_t &&
32+
e.args[1].args[2].args[2] == :DNNL_RUNTIME_DIM_VAL
33+
e.args[1].args[2] = unsigned(typemin(Int64))
34+
return
35+
end
36+
# const DNNL_RUNTIME_DIM_VAL = INT64_MIN
37+
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
38+
e.args[1].args[1] == :DNNL_RUNTIME_DIM_VAL && e.args[1].args[2] == :INT64_MIN
39+
e.args[1].args[2] = typemin(Int64)
40+
return
41+
end
42+
# const DNNL_RUNTIME_S32_VAL = DNNL_RUNTIME_S32_VAL_REP
43+
if e.head == :const && e.args[1] isa Expr && e.args[1].head == :(=) &&
44+
e.args[1].args[1] == :DNNL_RUNTIME_S32_VAL &&
45+
e.args[1].args[2] == :DNNL_RUNTIME_S32_VAL_REP
46+
e.args[1].args[2] = 0
47+
return
48+
end
49+
return
50+
end
51+
52+
function rewrite!(dag::ExprDAG)
53+
for node in get_nodes(dag), expr in get_exprs(node)
54+
rewrite!(expr)
55+
end
56+
end
57+
58+
rewrite!(ctx.dag)
59+
60+
build!(ctx, BUILDSTAGE_PRINTING_ONLY)
61+
62+
cd(cur_dir)

src/LuxLib.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@ const CRC = ChainRulesCore
1818

1919
include("utils.jl")
2020
include("traits.jl")
21+
22+
include("onednn/oneDNN.jl")
23+
2124
include("impl/Impl.jl")
25+
2226
include("api/API.jl")
2327

2428
@compat(public,

0 commit comments

Comments
 (0)