Skip to content

Commit 5bc0bd0

Browse files
bors[bot] and haampie authored
Merge #254
254: Fix UniformScaling constructor + copyto! r=haampie a=haampie .. and improve the tests a bit to test rectangular matrices. Fixes the issue where `CuArray(1f0I, 3, 4)` did not deduce the element type. Co-authored-by: Harmen Stoppels <[email protected]>
2 parents b1be744 + 121f5ff commit 5bc0bd0

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/host/construction.jl

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,13 +28,22 @@ function uniformscaling_kernel(ctx::AbstractKernelContext, res::AbstractArray{T}
2828
return
2929
end
3030

31-
function (T::Type{<: AbstractGPUArray})(s::UniformScaling, dims::Dims{2})
31+
function (T::Type{<: AbstractGPUArray{U}})(s::UniformScaling, dims::Dims{2}) where {U}
3232
res = zeros(T, dims)
3333
gpu_call(uniformscaling_kernel, res, size(res, 1), s; total_threads=minimum(dims))
3434
res
3535
end
36+
37+
(T::Type{<: AbstractGPUArray})(s::UniformScaling{U}, dims::Dims{2}) where U = T{U}(s, dims)
38+
3639
(T::Type{<: AbstractGPUArray})(s::UniformScaling, m::Integer, n::Integer) = T(s, Dims((m, n)))
3740

41+
function Base.copyto!(A::AbstractGPUMatrix{T}, s::UniformScaling) where T
42+
fill!(A, zero(T))
43+
gpu_call(uniformscaling_kernel, A, size(A, 1), s; total_threads=minimum(size(A)))
44+
A
45+
end
46+
3847
function indexstyle(x::T) where T
3948
style = try
4049
Base.IndexStyle(x)

test/testsuite/construction.jl

Lines changed: 15 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -164,15 +164,26 @@ function value_constructor(AT)
164164
@test all(x-> x == 2f0, Array(x1))
165165
@test all(x-> x == Int32(77), Array(x2))
166166

167-
x = Matrix{T}(I, 2, 2)
167+
x = Matrix{T}(I, 4, 2)
168168

169-
x1 = AT{T, 2}(I, 2, 2)
170-
x2 = AT{T}(I, (2, 2))
171-
x3 = AT{T, 2}(I, (2, 2))
169+
x1 = AT{T, 2}(I, 4, 2)
170+
x2 = AT{T}(I, (4, 2))
171+
x3 = AT{T, 2}(I, (4, 2))
172172

173173
@test Array(x1) ≈ x
174174
@test Array(x2) ≈ x
175175
@test Array(x3) ≈ x
176+
177+
x = Matrix(T(3) * I, 2, 4)
178+
x1 = AT(T(3) * I, 2, 4)
179+
@test eltype(x1) == T
180+
@test Array(x1) ≈ x
181+
182+
x = fill(T(3), (2, 4))
183+
x1 = fill(AT{T}, T(3), (2, 4))
184+
copyto!(x, 2I)
185+
copyto!(x1, 2I)
186+
@test Array(x1) ≈ x
176187
end
177188
end
178189
end

0 commit comments

Comments
 (0)