Skip to content

Commit 477172a

Browse files
committed
fix ndim blocks
1 parent 2251ae9 commit 477172a

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/jlbackend.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ function gpu_call(f, A::JLArray, args::Tuple, blocks = nothing, threads = C_NULL
106106
blocks = (blocks,)
107107
end
108108
if threads == C_NULL
109-
threads = (1,)
109+
threads = map(x-> 1, blocks)
110110
end
111111
idx = ntuple(i-> 1, length(blocks))
112112
blockdim = ceil.(Int, blocks ./ threads)
113113
state = JLState(threads, blockdim)
114114
device_args = to_device.(state, args)
115-
tasks = Vector{Task}(threads...)
115+
tasks = Vector{Task}(prod(threads))
116116
for blockidx in CartesianRange(blockdim)
117117
state.blockidx = blockidx.I
118118
block_args = to_blocks.(state, device_args)
@@ -131,6 +131,7 @@ end
131131
struct JLDevice end
132132
device(x::JLArray) = JLDevice()
133133
threads(dev::JLDevice) = 256
134+
blocks(dev::JLDevice) = (256, 256, 256)
134135

135136

136137
@inline function synchronize_threads(::JLState)

0 commit comments

Comments
 (0)