Skip to content

Commit e5f6dc7

Browse files
authored
Minimal support for native Windows (#394)
Doesn't include support for vendor libraries, due to the binary parts of the oneAPI toolkit we need to build the oneAPI support library being incompatible with our MinGW build environment on Yggdrasil (as they're MSVC-generated, introducing C++ ABI incompatibilities).
1 parent ea463c1 commit e5f6dc7

File tree

13 files changed

+146
-105
lines changed

13 files changed

+146
-105
lines changed

Project.toml

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,37 @@ ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
1010
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
1111
GPUCompiler = "61eb1bfa-7361-4325-ad38-22787b887f55"
1212
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
13-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1413
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
14+
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
15+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
NEO_jll = "700fe977-ac61-5f37-bbc8-c6c4b2b6a9fd"
16-
oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
17-
oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
18-
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
1917
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2018
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2119
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
22-
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2320
SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361"
2421
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
22+
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
2523
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
24+
oneAPI_Level_Zero_Headers_jll = "f4bc562b-d309-54f8-9efb-476e56f0410d"
25+
oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
26+
oneAPI_Support_jll = "b049733a-a71d-5ed3-8eba-7d323ac00b36"
2627

2728
[compat]
2829
Adapt = "4"
2930
CEnum = "0.4, 0.5"
3031
ExprTools = "0.1"
3132
GPUArrays = "10"
3233
GPUCompiler = "0.23, 0.24, 0.25, 0.26"
33-
julia = "1.8"
3434
KernelAbstractions = "0.9.1"
3535
LLVM = "6"
3636
NEO_jll = "=24.05.28454"
37-
oneAPI_Level_Zero_Loader_jll = "1.9"
38-
oneAPI_Support_jll = "~0.3.1"
3937
Preferences = "1"
40-
SpecialFunctions = "1.3, 2"
4138
SPIRV_LLVM_Translator_unified_jll = "0.3"
39+
SpecialFunctions = "1.3, 2"
4240
StaticArrays = "1"
41+
julia = "1.8"
42+
oneAPI_Level_Zero_Loader_jll = "1.9"
43+
oneAPI_Support_jll = "~0.3.1"
4344

4445
[extras]
4546
libigc_jll = "94295238-5935-5bd7-bb0f-b00942e9bdd5"

lib/level-zero/oneL0.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,14 @@ using CEnum
66

77
using Printf
88

9-
using NEO_jll
10-
using oneAPI_Level_Zero_Loader_jll
9+
using Libdl
10+
11+
if Sys.iswindows()
12+
const libze_loader = "ze_loader"
13+
else
14+
using NEO_jll
15+
using oneAPI_Level_Zero_Loader_jll
16+
end
1117

1218
include("utils.jl")
1319
include("pointer.jl")
@@ -85,16 +91,23 @@ function __init__()
8591
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
8692
precompiling && return
8793

88-
if !oneAPI_Level_Zero_Loader_jll.is_available()
89-
@error """No oneAPI Level Zero loader found for your platform. Currently, only Linux x86 is supported.
90-
If you have a local oneAPI toolchain, you can use that; refer to the documentation for more details."""
91-
return
92-
end
94+
if Sys.iswindows()
95+
if Libdl.dlopen(libze_loader; throw_error=false) === nothing
96+
@error "The oneAPI Level Zero loader was not found. Please ensure the Intel GPU drivers are installed."
97+
return
98+
end
99+
else
100+
if !oneAPI_Level_Zero_Loader_jll.is_available()
101+
@error """No oneAPI Level Zero loader found for your platform. Currently, only Linux x86 is supported.
102+
If you have a local oneAPI toolchain, you can use that; refer to the documentation for more details."""
103+
return
104+
end
93105

94-
if !NEO_jll.is_available()
95-
@error """No oneAPI driver found for your platform. Currently, only Linux x86_64 is supported.
96-
If you have a local oneAPI toolchain, you can use that; refer to the documentation for more details."""
97-
return
106+
if !NEO_jll.is_available()
107+
@error """No oneAPI driver found for your platform. Currently, only Linux x86_64 is supported.
108+
If you have a local oneAPI toolchain, you can use that; refer to the documentation for more details."""
109+
return
110+
end
98111
end
99112

100113
try

lib/support/Support.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ end
2222
function __init__()
2323
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
2424
precompiling && return
25-
25+
2626
if !oneAPI_Support_jll.is_available()
2727
@error """oneAPI support wrapper not available for your platform."""
2828
return

src/device/opencl/integer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Integer Functions
22

33
# TODO: vector types
4-
const generic_integer_types = [Cchar, Cuchar, Cshort, Cushort, Cint, Cuint, Clong, Culong]
4+
const generic_integer_types = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64]
55

66

77
# generically typed

src/device/opencl/math.jl

Lines changed: 50 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Math Functions
22

33
# TODO: vector types
4-
const generic_types = [Cfloat,Cdouble]
5-
const generic_types_float = [Cfloat]
6-
const generic_types_double = [Cdouble]
4+
const generic_types = [Float32,Float64]
5+
const generic_types_float = [Float32]
6+
const generic_types_double = [Float64]
77

88

99
# generically typed
@@ -112,10 +112,10 @@ end
112112

113113
for gentypef in generic_types_float
114114

115-
if gentypef !== Cfloat
115+
if gentypef !== Float32
116116
@eval begin
117-
@device_override Base.max(x::$gentypef, y::Cfloat) = @builtin_ccall("fmax", $gentypef, ($gentypef, Cfloat), x, y)
118-
@device_override Base.min(x::$gentypef, y::Cfloat) = @builtin_ccall("fmin", $gentypef, ($gentypef, Cfloat), x, y)
117+
@device_override Base.max(x::$gentypef, y::Float32) = @builtin_ccall("fmax", $gentypef, ($gentypef, Float32), x, y)
118+
@device_override Base.min(x::$gentypef, y::Float32) = @builtin_ccall("fmin", $gentypef, ($gentypef, Float32), x, y)
119119
end
120120
end
121121

@@ -126,10 +126,10 @@ end
126126

127127
for gentyped in generic_types_double
128128

129-
if gentyped !== Cdouble
129+
if gentyped !== Float64
130130
@eval begin
131-
@device_override Base.min(x::$gentyped, y::Cdouble) = @builtin_ccall("fmin", $gentyped, ($gentyped, Cdouble), x, y)
132-
@device_override Base.max(x::$gentyped, y::Cdouble) = @builtin_ccall("fmax", $gentyped, ($gentyped, Cdouble), x, y)
131+
@device_override Base.min(x::$gentyped, y::Float64) = @builtin_ccall("fmin", $gentyped, ($gentyped, Float64), x, y)
132+
@device_override Base.max(x::$gentyped, y::Float64) = @builtin_ccall("fmax", $gentyped, ($gentyped, Float64), x, y)
133133
end
134134
end
135135

@@ -138,47 +138,47 @@ end
138138

139139
# specifically typed
140140

141-
# frexp(x::Cfloat{n}, Cint{n} *exp) = @builtin_ccall("frexp", Cfloat{n}, (Cfloat{n}, Cint{n} *), x, exp)
142-
# frexp(x::Cfloat, Cint *exp) = @builtin_ccall("frexp", Cfloat, (Cfloat, Cint *), x, exp)
143-
# frexp(x::Cdouble{n}, Cint{n} *exp) = @builtin_ccall("frexp", Cdouble{n}, (Cdouble{n}, Cint{n} *), x, exp)
144-
# frexp(x::Cdouble, Cint *exp) = @builtin_ccall("frexp", Cdouble, (Cdouble, Cint *), x, exp)
145-
146-
# ilogb(x::Cfloat{n}) = @builtin_ccall("ilogb", Cint{n}, (Cfloat{n},), x)
147-
@device_function ilogb(x::Cfloat) = @builtin_ccall("ilogb", Cint, (Cfloat,), x)
148-
# ilogb(x::Cdouble{n}) = @builtin_ccall("ilogb", Cint{n}, (Cdouble{n},), x)
149-
@device_function ilogb(x::Cdouble) = @builtin_ccall("ilogb", Cint, (Cdouble,), x)
150-
151-
# ldexp(x::Cfloat{n}, k::Cint{n}) = @builtin_ccall("ldexp", Cfloat{n}, (Cfloat{n}, Cint{n}), x, k)
152-
# ldexp(x::Cfloat{n}, k::Cint) = @builtin_ccall("ldexp", Cfloat{n}, (Cfloat{n}, Cint), x, k)
153-
@device_override Base.ldexp(x::Cfloat, k::Cint) = @builtin_ccall("ldexp", Cfloat, (Cfloat, Cint), x, k)
154-
# ldexp(x::Cdouble{n}, k::Cint{n}) = @builtin_ccall("ldexp", Cdouble{n}, (Cdouble{n}, Cint{n}), x, k)
155-
# ldexp(x::Cdouble{n}, k::Cint) = @builtin_ccall("ldexp", Cdouble{n}, (Cdouble{n}, Cint), x, k)
156-
@device_override Base.ldexp(x::Cdouble, k::Cint) = @builtin_ccall("ldexp", Cdouble, (Cdouble, Cint), x, k)
157-
158-
# lgamma_r(x::Cfloat{n}, Cint{n} *signp) = @builtin_ccall("lgamma_r", Cfloat{n}, (Cfloat{n}, Cint{n} *), x, signp)
159-
# lgamma_r(x::Cfloat, Cint *signp) = @builtin_ccall("lgamma_r", Cfloat, (Cfloat, Cint *), x, signp)
160-
# lgamma_r(x::Cdouble{n}, Cint{n} *signp) = @builtin_ccall("lgamma_r", Cdouble{n}, (Cdouble{n}, Cint{n} *), x, signp)
161-
# Cdouble lgamma_r(x::Cdouble, Cint *signp) = @builtin_ccall("lgamma_r", Cdouble, (Cdouble, Cint *), x, signp)
162-
163-
# nan(nancode::uintn) = @builtin_ccall("nan", Cfloat{n}, (uintn,), nancode)
164-
@device_function nan(nancode::Cuint) = @builtin_ccall("nan", Cfloat, (Cuint,), nancode)
165-
# nan(nancode::Culong{n}) = @builtin_ccall("nan", Cdouble{n}, (Culong{n},), nancode)
166-
@device_function nan(nancode::Culong) = @builtin_ccall("nan", Cdouble, (Culong,), nancode)
167-
168-
# pown(x::Cfloat{n}, y::Cint{n}) = @builtin_ccall("pown", Cfloat{n}, (Cfloat{n}, Cint{n}), x, y)
169-
@device_override Base.:(^)(x::Cfloat, y::Cint) = @builtin_ccall("pown", Cfloat, (Cfloat, Cint), x, y)
170-
# pown(x::Cdouble{n}, y::Cint{n}) = @builtin_ccall("pown", Cdouble{n}, (Cdouble{n}, Cint{n}), x, y)
171-
@device_override Base.:(^)(x::Cdouble, y::Cint) = @builtin_ccall("pown", Cdouble, (Cdouble, Cint), x, y)
172-
173-
# remquo(x::Cfloat{n}, y::Cfloat{n}, Cint{n} *quo) = @builtin_ccall("remquo", Cfloat{n}, (Cfloat{n}, Cfloat{n}, Cint{n} *), x, y, quo)
174-
# remquo(x::Cfloat, y::Cfloat, Cint *quo) = @builtin_ccall("remquo", Cfloat, (Cfloat, Cfloat, Cint *), x::Cfloat, y, quo)
175-
# remquo(x::Cdouble{n}, y::Cdouble{n}, Cint{n} *quo) = @builtin_ccall("remquo", Cdouble{n}, (Cdouble{n}, Cdouble{n}, Cint{n} *), x, y, quo)
176-
# remquo(x::Cdouble, y::Cdouble, Cint *quo) = @builtin_ccall("remquo", Cdouble, (Cdouble, Cdouble, Cint *), x, y, quo)
177-
178-
# rootn(x::Cfloat{n}, y::Cint{n}) = @builtin_ccall("rootn", Cfloat{n}, (Cfloat{n}, Cint{n}), x, y)
179-
@device_function rootn(x::Cfloat, y::Cint) = @builtin_ccall("rootn", Cfloat, (Cfloat, Cint), x, y)
180-
# rootn(x::Cdouble{n}, y::Cint{n}) = @builtin_ccall("rootn", Cdouble{n}, (Cdouble{n}, Cint{n}), x, y)
181-
# rootn(x::Cdouble, y::Cint) = @builtin_ccall("rootn", Cdouble{n}, (Cdouble, Cint), x, y)
141+
# frexp(x::Float32{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float32{n}, (Float32{n}, Int32{n} *), x, exp)
142+
# frexp(x::Float32, Int32 *exp) = @builtin_ccall("frexp", Float32, (Float32, Int32 *), x, exp)
143+
# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp)
144+
# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp)
145+
146+
# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x)
147+
@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x)
148+
# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x)
149+
@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x)
150+
151+
# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k)
152+
# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k)
153+
@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k)
154+
# ldexp(x::Float64{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float64{n}, (Float64{n}, Int32{n}), x, k)
155+
# ldexp(x::Float64{n}, k::Int32) = @builtin_ccall("ldexp", Float64{n}, (Float64{n}, Int32), x, k)
156+
@device_override Base.ldexp(x::Float64, k::Int32) = @builtin_ccall("ldexp", Float64, (Float64, Int32), x, k)
157+
158+
# lgamma_r(x::Float32{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float32{n}, (Float32{n}, Int32{n} *), x, signp)
159+
# lgamma_r(x::Float32, Int32 *signp) = @builtin_ccall("lgamma_r", Float32, (Float32, Int32 *), x, signp)
160+
# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp)
161+
# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp)
162+
163+
# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode)
164+
@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode)
165+
# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode)
166+
@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode)
167+
168+
# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y)
169+
@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y)
170+
# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y)
171+
@device_override Base.:(^)(x::Float64, y::Int32) = @builtin_ccall("pown", Float64, (Float64, Int32), x, y)
172+
173+
# remquo(x::Float32{n}, y::Float32{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float32{n}, (Float32{n}, Float32{n}, Int32{n} *), x, y, quo)
174+
# remquo(x::Float32, y::Float32, Int32 *quo) = @builtin_ccall("remquo", Float32, (Float32, Float32, Int32 *), x::Float32, y, quo)
175+
# remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo)
176+
# remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo)
177+
178+
# rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y)
179+
@device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y)
180+
# rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y)
181+
# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y)
182182

183183

184184
# TODO: half and native

src/device/opencl/synchronization.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
export barrier
44

5-
const cl_mem_fence_flags = Cuint
5+
const cl_mem_fence_flags = UInt32
66
const CLK_LOCAL_MEM_FENCE = cl_mem_fence_flags(1)
77
const CLK_GLOBAL_MEM_FENCE = cl_mem_fence_flags(2)
88

9-
#barrier(flags=0) = @builtin_ccall("barrier", Cvoid, (Cuint,), flags)
9+
#barrier(flags=0) = @builtin_ccall("barrier", Cvoid, (UInt32,), flags)
1010
barrier(flags=0) = Base.llvmcall(("""
1111
declare void @_Z7barrierj(i32) #0
1212
define void @entry(i32 %0) #1 {

src/device/opencl/work_item.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@ export get_work_dim,
1212

1313
# TODO: 1-indexed dimension selection?
1414

15-
get_work_dim() = @builtin_ccall("get_work_dim", Cuint, ()) % Int
15+
get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int
1616

17-
get_global_size(dimindx::Integer=0) = @builtin_ccall("get_global_size", Csize_t, (Cuint,), dimindx) % Int
18-
get_global_id(dimindx::Integer=0) = @builtin_ccall("get_global_id", Csize_t, (Cuint,), dimindx) % Int + 1
17+
get_global_size(dimindx::Integer=0) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx) % Int
18+
get_global_id(dimindx::Integer=0) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx) % Int + 1
1919

20-
get_local_size(dimindx::Integer=0) = @builtin_ccall("get_local_size", Csize_t, (Cuint,), dimindx) % Int
21-
get_enqueued_local_size(dimindx::Integer=0) = @builtin_ccall("get_enqueued_local_size", Csize_t, (Cuint,), dimindx) % Int
22-
get_local_id(dimindx::Integer=0) = @builtin_ccall("get_local_id", Csize_t, (Cuint,), dimindx) % Int + 1
20+
get_local_size(dimindx::Integer=0) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx) % Int
21+
get_enqueued_local_size(dimindx::Integer=0) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx) % Int
22+
get_local_id(dimindx::Integer=0) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx) % Int + 1
2323

24-
get_num_groups(dimindx::Integer=0) = @builtin_ccall("get_num_groups", Csize_t, (Cuint,), dimindx) % Int
25-
get_group_id(dimindx::Integer=0) = @builtin_ccall("get_group_id", Csize_t, (Cuint,), dimindx) % Int + 1
24+
get_num_groups(dimindx::Integer=0) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx) % Int
25+
get_group_id(dimindx::Integer=0) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx) % Int + 1
2626

27-
get_global_offset(dimindx::Integer=0) = @builtin_ccall("get_global_offset", Csize_t, (Cuint,), dimindx) % Int + 1
27+
get_global_offset(dimindx::Integer=0) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx) % Int + 1
2828

29-
get_global_linear_id() = @builtin_ccall("get_global_linear_id", Csize_t, ()) % Int + 1
30-
get_local_linear_id() = @builtin_ccall("get_local_linear_id", Csize_t, ()) % Int + 1
29+
get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1
30+
get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1

src/device/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ macro builtin_ccall(name, ret, argtypes, args...)
1919
"l"
2020
elseif T == Culong
2121
"m"
22+
elseif T == Clonglong
23+
"x"
24+
elseif T == Culonglong
25+
"y"
2226
elseif T == Cshort
2327
"s"
2428
elseif T == Cushort

src/oneAPI.jl

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,12 @@ using Core: LLVMPtr
1717

1818
using SPIRV_LLVM_Translator_unified_jll, SPIRV_Tools_jll
1919

20-
export oneL0, SYCL
20+
export oneL0
2121

2222
# core library
2323
include("../lib/utils/APIUtils.jl")
2424
include("../lib/level-zero/oneL0.jl")
25-
include("../lib/support/Support.jl")
26-
include("../lib/sycl/SYCL.jl")
27-
using .oneL0, .SYCL
25+
using .oneL0
2826
functional() = oneL0.functional[]
2927

3028
# device functionality (needs to be loaded first, because of generated functions)
@@ -54,9 +52,17 @@ include("compiler/compilation.jl")
5452
include("compiler/execution.jl")
5553
include("compiler/reflection.jl")
5654

55+
if Sys.islinux()
56+
# library interop
57+
include("../lib/support/Support.jl")
58+
include("../lib/sycl/SYCL.jl")
59+
using .SYCL
60+
export SYCL
61+
5762
# array libraries
5863
include("../lib/mkl/oneMKL.jl")
5964
export oneMKL
65+
end
6066

6167
# integrations and specialized functionality
6268
include("broadcast.jl")
@@ -73,13 +79,15 @@ function __init__()
7379
precompiling = ccall(:jl_generating_output, Cint, ()) != 0
7480
precompiling && return
7581

76-
if !Sys.islinux()
77-
@error("oneAPI.jl is only supported on Linux")
78-
return
82+
if Sys.iswindows()
83+
@warn """oneAPI.jl support for native Windows is experimental and incomplete.
84+
For the time being, it is recommended to use WSL or Linux instead."""
7985
end
8086

81-
# ensure that the OpenCL runtime dispatcher finds the ICD files from our artifacts
82-
ENV["OCL_ICD_VENDORS"] = oneL0.NEO_jll.libigdrcl
87+
if Sys.islinux()
88+
# ensure that the OpenCL runtime dispatcher finds the ICD files from our artifacts
89+
ENV["OCL_ICD_VENDORS"] = oneL0.NEO_jll.libigdrcl
90+
end
8391
end
8492

8593
function set_debug!(debug::Bool)

src/utils.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,20 +8,22 @@ function versioninfo(io::IO=stdout)
88
deps = Pkg.dependencies()
99
versions = Dict(map(uuid->deps[uuid].name => deps[uuid].version, collect(keys(deps))))
1010

11-
println(io, "Binary dependencies:")
12-
for jll in [oneL0.NEO_jll, oneL0.NEO_jll.libigc_jll, oneL0.NEO_jll.gmmlib_jll,
13-
SPIRV_LLVM_Translator_unified_jll, SPIRV_Tools_jll]
14-
name = string(jll)
15-
print(io, "- $(name[1:end-4]): $(versions[name])")
16-
if jll.host_platform !== nothing
17-
debug = tryparse(Bool, get(jll.host_platform.tags, "debug", "false"))
18-
if debug === true
19-
print(io, " (debug)")
11+
if Sys.islinux()
12+
println(io, "Binary dependencies:")
13+
for jll in [oneL0.NEO_jll, oneL0.NEO_jll.libigc_jll, oneL0.NEO_jll.gmmlib_jll,
14+
SPIRV_LLVM_Translator_unified_jll, SPIRV_Tools_jll]
15+
name = string(jll)
16+
print(io, "- $(name[1:end-4]): $(versions[name])")
17+
if jll.host_platform !== nothing
18+
debug = tryparse(Bool, get(jll.host_platform.tags, "debug", "false"))
19+
if debug === true
20+
print(io, " (debug)")
21+
end
2022
end
23+
println(io)
2124
end
2225
println(io)
2326
end
24-
println(io)
2527

2628
println(io, "Toolchain:")
2729
println(io, "- Julia: $VERSION")

0 commit comments

Comments
 (0)