Skip to content

Commit 9148e33

Browse files
committed
prevent get_backend from overflowing the stack
Prevent the `get_backend` methods from overflowing the stack/recurring without bound. Hoping this doesn't cause inference issues due to deeper call stacks. Fixes #588
1 parent 110d784 commit 9148e33

File tree

3 files changed

+16
-3
lines changed

3 files changed

+16
-3
lines changed

ext/LinearAlgebraExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LinearAlgebraExt
33
using KernelAbstractions: KernelAbstractions
44
using LinearAlgebra: Tridiagonal, Diagonal
55

6-
KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend(A.diag)
7-
KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend(A.d)
6+
KernelAbstractions.get_backend(A::Diagonal) = KernelAbstractions.get_backend_recur(x -> x.diag, A)
7+
KernelAbstractions.get_backend(A::Tridiagonal) = KernelAbstractions.get_backend_recur(x -> x.d, A)
88

99
end

src/KernelAbstractions.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,8 +510,17 @@ Get a [`Backend`](@ref) instance suitable for array `A`.
510510
"""
511511
function get_backend end
512512

513+
function get_backend_recur(f::F, x) where {F}
514+
t() = throw(ArgumentError("throwing to prevent a stack overflow, possibly a `get_backend` method is missing?"))
515+
y = f(x)
516+
if y isa typeof(x)
517+
@noinline t()
518+
end
519+
get_backend(y)
520+
end
521+
513522
# Should cover SubArray, ReshapedArray, ReinterpretArray, Hermitian, AbstractTriangular, etc.:
514-
get_backend(A::AbstractArray) = get_backend(parent(A))
523+
get_backend(A::AbstractArray) = get_backend_recur(parent, A)
515524

516525
# Define:
517526
# adapt_storage(::Backend, a::Array) = adapt(BackendArray, a)

test/test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@ using Adapt
77

88
identity(x) = x
99

10+
struct UnknownAbstractVector <: AbstractVector{Float32} # issue #588
11+
end
12+
1013
function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; skip_tests = Set{String}())
1114
@conditional_testset "partition" skip_tests begin
1215
backend = Backend()
@@ -80,6 +83,7 @@ function unittest_testsuite(Backend, backend_str, backend_mod, BackendArrayT; sk
8083
@test @inferred(KernelAbstractions.get_backend(view(A, 2:4, 1:3))) isa backendT
8184
@test @inferred(KernelAbstractions.get_backend(Diagonal(x))) isa backendT
8285
@test @inferred(KernelAbstractions.get_backend(Tridiagonal(A))) isa backendT
86+
@test_throws ArgumentError KernelAbstractions.get_backend(UnknownAbstractVector()) # issue #588
8387
end
8488

8589
@conditional_testset "sparse" skip_tests begin

0 commit comments

Comments
 (0)