@@ -8,6 +8,7 @@ function default_backend()
8
8
JLBackend
9
9
end
10
10
end
11
+
11
12
let compute_contexts = Context[]
12
13
function current_context ()
13
14
if isempty (compute_contexts)
45
46
function free (x:: AbstractArray )
46
47
47
48
end
49
+ #=
50
+ Functions to select contexts
51
+ =#
52
+
53
+ is_gpu (ctx) = false
54
+ is_cpu (ctx) = false
55
+ is_opencl (ctx) = false
56
+ is_cudanative (ctx) = false
57
+ is_julia (ctx) = false
58
+ is_opengl (ctx) = false
59
+ has_atleast (ctx, attribute, value) = error (" has_atleast not implemented yet" )
48
60
49
61
# BLAS support
50
62
hasblas (x) = false
51
63
include (" blas.jl" )
52
64
include (" supported_backends.jl" )
53
65
include (" shared.jl" )
54
66
55
- function init (sym:: Symbol , args... ; kw_args... )
56
- if sym == :julia
57
- JLBackend. init (args... ; kw_args... )
58
- elseif sym == :cudanative
59
- CUBackend. init (args... ; kw_args... )
60
- elseif sym == :opencl
61
- CLBackend. init (args... ; kw_args... )
62
- elseif sym == :opengl
63
- GLBackend. init (args... ; kw_args... )
67
+ function to_backend_module (backend:: Symbol )
68
+ if backend in supported_backends ()
69
+ if sym == :julia
70
+ JLBackend
71
+ elseif sym == :cudanative
72
+ CUBackend
73
+ elseif sym == :opencl
74
+ CLBackend
75
+ elseif sym == :opengl
76
+ GLBackend
77
+ end
64
78
else
65
79
error (" $sym not a supported backend. Try one of: $(supported_backends ()) " )
66
80
end
67
81
end
82
+ function init (sym:: Symbol , args... ; kw_args... )
83
+ backend_module (sym). init (args... ; kw_args... )
84
+ end
85
+ function init (filterfuncs:: Function... ; kw_args... )
86
+ init_from_device (first (devices (filterfuncs... )))
87
+ end
88
+ backend_modules () = to_backend_module .(supported_backends ())
89
+
90
+
68
91
69
92
93
+ function devices (filter_funcs... )
94
+ result = []
95
+ for Module in backend_modules ()
96
+ for device in Module. devices ()
97
+ if all (f-> f (device), filter_funcs)
98
+ push! (result, device)
99
+ end
100
+ end
101
+ end
102
+ result
103
+ end
104
+
70
105
"""
71
106
Iterates through all backends and calls `f` after initializing the current one!
72
107
"""
73
108
function perbackend (f)
74
109
for backend in supported_backends ()
75
110
ctx = GPUArrays. init (backend)
76
- f (backend)
111
+ f (ctx)
112
+ end
113
+ end
114
+
115
+ """
116
+ Iterates through all available devices and calls `f` after initializing the current one!
117
+ """
118
+ function forall_devices (f, filterfuncs... )
119
+ for device in devices (filterfunc)
120
+ make_current (device)
121
+ f (device)
77
122
end
78
123
end
0 commit comments