Skip to content

Commit 66208dd

Browse files
committed
Deduct eltype from UniformScaling
1 parent 59f4887 commit 66208dd

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/host/construction.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,14 @@ function uniformscaling_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}
2828
return
2929
end
3030

31-
function (T::Type{<: AbstractGPUArray})(s::UniformScaling, dims::Dims{2})
31+
function (T::Type{<: AbstractGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
3232
res = zeros(T, dims)
3333
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
3434
res
3535
end
36+
37+
(T::Type{<: AbstractGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
38+
3639
(T::Type{<: AbstractGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
3740

3841
function indexstyle(x::T) where T

0 commit comments

Comments
 (0)