Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 9 additions & 31 deletions src/ndsparsearray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,8 @@ More efficient than `A = A + scalar`.
Works with any type that supports addition with the array's element type.
"""
function add!(A::NDSparseArray{T}, scalar) where {T}
# For generic types, we can't assume zero() exists, so we try to check if it's zero
# For most numeric types this will work, for others we'll skip the check
try
if scalar == zero(typeof(scalar))
return A
end
catch MethodError
# zero() not defined for this type, continue with the operation
if iszero(scalar)
return A
end

# Add scalar to all stored values, converting to A's type
Expand Down Expand Up @@ -552,14 +546,8 @@ More efficient than `A = A - scalar`.
Works with any type that supports subtraction with the array's element type.
"""
function sub!(A::NDSparseArray{T}, scalar) where {T}
# For generic types, we can't assume zero() exists, so we try to check if it's zero
# For most numeric types this will work, for others we'll skip the check
try
if scalar == zero(typeof(scalar))
return A
end
catch MethodError
# zero() not defined for this type, continue with the operation
if iszero(scalar)
return A
end

# Subtract scalar from all stored values, converting to A's type
Expand All @@ -579,23 +567,13 @@ More efficient than `A = A * scalar`.
Works with any type that supports multiplication with the array's element type.
"""
function mul!(A::NDSparseArray{T}, scalar) where {T}
# For generic types, we can't assume zero() or one() exist, so we try to check
# For most numeric types this will work, for others we'll skip the check
try
if scalar == zero(typeof(scalar))
empty!(A.data)
return A
end
catch MethodError
# zero() not defined for this type, continue
if iszero(scalar)
empty!(A.data)
return A
end

try
if scalar == one(typeof(scalar))
return A
end
catch MethodError
# one() not defined for this type, continue
if isone(scalar)
return A
end

# Multiply all stored values, converting scalar to A's type
Expand Down
1 change: 1 addition & 0 deletions test/test_inplace_operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ using Test
Base.:-(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value - b.value)
Base.:*(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value * b.value)
Base.convert(::Type{CustomNumber}, x::CustomNumber) = x
Base.iszero(x::CustomNumber) = iszero(x.value)

A = NDSparseArray{CustomNumber, 2}((2, 2))
A[1, 1] = CustomNumber(5.0)
Expand Down
Loading