@@ -5,16 +5,16 @@ GPUArrays.device(x::MtlArray) = x.dev
55import KernelAbstractions
66import KernelAbstractions: Backend
77
8- @inline function GPUArrays. launch_heuristic(:: MetalBackend , f :: F , args:: Vararg{Any,N} ;
9- elements:: Int , elements_per_thread:: Int ) where {F ,N}
8+ @inline function GPUArrays. launch_heuristic(:: MetalBackend , obj :: O , args:: Vararg{Any,N} ;
9+ elements:: Int , elements_per_thread:: Int ) where {O ,N}
1010
11- ndrange, workgroupsize, iterspace, dynamic = KA. launch_config(obj, nothing ,
11+ ndrange = ceil(Int, elements / elements_per_thread)
12+ ndrange, workgroupsize, iterspace, dynamic = KA. launch_config(obj, ndrange,
1213 nothing )
1314
14- # this might not be the final context, since we may tune the workgroupsize
1515 ctx = KA. mkcontext(obj, ndrange, iterspace)
1616
17- kernel = @metal launch= false f(ctx, args... )
17+ kernel = @metal launch= false obj . f(ctx, args... )
1818
1919 # The pipeline state automatically computes occupancy stats
2020 threads = min(elements, kernel. pipeline. maxTotalThreadsPerThreadgroup)
0 commit comments