Skip to content

Commit 908e7f8

Browse files
committed
fix CUDActx overdubs
1 parent 9d5769b commit 908e7f8

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/backends/cuda.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ end
201201
###
202202
# GPU implementation of shared memory
203203
###
204-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
204+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(SharedMemory), ::Type{T}, ::Val{Dims}, ::Val{Id}) where {T, Dims, Id}
205205
ptr = CUDAnative._shmem(Val(Id), T, Val(prod(Dims)))
206206
CUDAnative.CuDeviceArray(Dims, CUDAnative.DevicePtr{T, CUDAnative.AS.Shared}(ptr))
207207
end
@@ -211,6 +211,6 @@ end
211211
# - private memory for each workitem
212212
###
213213

214-
@inline function Cassette.overdub(ctx::CPUCtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
214+
@inline function Cassette.overdub(ctx::CUDACtx, ::typeof(Scratchpad), ::Type{T}, ::Val{Dims}) where {T, Dims}
215215
MArray{__size(Dims), T}(undef)
216216
end

0 commit comments

Comments
 (0)