diff --git a/src/ndsparsearray.jl b/src/ndsparsearray.jl index 30b74a6..74cef4d 100644 --- a/src/ndsparsearray.jl +++ b/src/ndsparsearray.jl @@ -492,14 +492,8 @@ More efficient than `A = A + scalar`. Works with any type that supports addition with the array's element type. """ function add!(A::NDSparseArray{T}, scalar) where {T} - # For generic types, we can't assume zero() exists, so we try to check if it's zero - # For most numeric types this will work, for others we'll skip the check - try - if scalar == zero(typeof(scalar)) - return A - end - catch MethodError - # zero() not defined for this type, continue with the operation + if iszero(scalar) + return A end # Add scalar to all stored values, converting to A's type @@ -552,14 +546,8 @@ More efficient than `A = A - scalar`. Works with any type that supports subtraction with the array's element type. """ function sub!(A::NDSparseArray{T}, scalar) where {T} - # For generic types, we can't assume zero() exists, so we try to check if it's zero - # For most numeric types this will work, for others we'll skip the check - try - if scalar == zero(typeof(scalar)) - return A - end - catch MethodError - # zero() not defined for this type, continue with the operation + if iszero(scalar) + return A end # Subtract scalar from all stored values, converting to A's type @@ -579,23 +567,13 @@ More efficient than `A = A * scalar`. Works with any type that supports multiplication with the array's element type. """ function mul!(A::NDSparseArray{T}, scalar) where {T} - # For generic types, we can't assume zero() or one() exist, so we try to check - # For most numeric types this will work, for others we'll skip the check - try - if scalar == zero(typeof(scalar)) - empty!(A.data) - return A - end - catch MethodError - # zero() not defined for this type, continue + if iszero(scalar) + empty!(A.data) + return A end - try - if scalar == one(typeof(scalar)) - return A - end - catch MethodError - # one() not defined for this type, continue + if isone(scalar) + return A end # Multiply all stored values, converting scalar to A's type diff --git a/test/test_inplace_operations.jl b/test/test_inplace_operations.jl index a5a52c9..68af8da 100644 --- a/test/test_inplace_operations.jl +++ b/test/test_inplace_operations.jl @@ -352,6 +352,7 @@ using Test Base.:-(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value - b.value) Base.:*(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value * b.value) Base.convert(::Type{CustomNumber}, x::CustomNumber) = x + Base.iszero(x::CustomNumber) = iszero(x.value) A = NDSparseArray{CustomNumber, 2}((2, 2)) A[1, 1] = CustomNumber(5.0)