Skip to content

Commit 7f58602

Browse files
committed
mimicking CUDA
1 parent c3553fb commit 7f58602

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

src/gpuarrays.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@ GPUArrays.device(x::MtlArray) = x.dev
 old  new
   5    5    import KernelAbstractions
   6    6    import KernelAbstractions: Backend
   7    7
   8       - @inline function GPUArrays.launch_heuristic(::MetalBackend, f::F, args::Vararg{Any,N};
   9       - elements::Int, elements_per_thread::Int) where {F,N}
        8  + @inline function GPUArrays.launch_heuristic(::MetalBackend, obj::O, args::Vararg{Any,N};
        9  + elements::Int, elements_per_thread::Int) where {O,N}
  10   10
  11       - ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, nothing,
       11  + ndrange = ceil(Int, elements / elements_per_thread)
       12  + ndrange, workgroupsize, iterspace, dynamic = KA.launch_config(obj, ndrange,
  12   13    nothing)
  13   14
  14       - # this might not be the final context, since we may tune the workgroupsize
  15   15    ctx = KA.mkcontext(obj, ndrange, iterspace)
  16   16
  17       - kernel = @metal launch=false f(ctx, args...)
       17  + kernel = @metal launch=false obj.f(ctx, args...)
  18   18
  19   19    # The pipeline state automatically computes occupancy stats
  20   20    threads = min(elements, kernel.pipeline.maxTotalThreadsPerThreadgroup)

0 commit comments

Comments
 (0)