Skip to content

Commit 2e72e6e

Browse files
committed
Initial commit
1 parent c50b1b7 commit 2e72e6e

File tree

2 files changed

+107
-22
lines changed

2 files changed

+107
-22
lines changed

src/ndsparsearray.jl

Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -514,15 +514,21 @@ function add!(A::NDSparseArray{T, N}, B::NDSparseArray{S, N}) where {T, S, N}
514514
end
515515

516516
"""
517-
add!(A::NDSparseArray, scalar::Number)
517+
add!(A::NDSparseArray, scalar)
518518
519519
In-place addition of scalar to all stored elements in sparse array `A`.
520520
More efficient than `A = A + scalar`.
521-
"""
522-
function add!(A::NDSparseArray{T}, scalar::Number) where {T}
523-
# If scalar is zero, no operation needed
524-
if scalar == zero(typeof(scalar))
525-
return A
521+
Works with any type that supports addition with the array's element type.
522+
"""
523+
function add!(A::NDSparseArray{T}, scalar) where {T}
524+
# For generic types, we can't assume zero() exists, so we try to check if it's zero
525+
# For most numeric types this will work, for others we'll skip the check
526+
try
527+
if scalar == zero(typeof(scalar))
528+
return A
529+
end
530+
catch MethodError
531+
# zero() not defined for this type, continue with the operation
526532
end
527533

528534
# Add scalar to all stored values, converting to A's type
@@ -568,15 +574,21 @@ function sub!(A::NDSparseArray{T, N}, B::NDSparseArray{S, N}) where {T, S, N}
568574
end
569575

570576
"""
571-
sub!(A::NDSparseArray, scalar::Number)
577+
sub!(A::NDSparseArray, scalar)
572578
573579
In-place subtraction of scalar from all stored elements in sparse array `A`.
574580
More efficient than `A = A - scalar`.
575-
"""
576-
function sub!(A::NDSparseArray{T}, scalar::Number) where {T}
577-
# If scalar is zero, no operation needed
578-
if scalar == zero(typeof(scalar))
579-
return A
581+
Works with any type that supports subtraction with the array's element type.
582+
"""
583+
function sub!(A::NDSparseArray{T}, scalar) where {T}
584+
# For generic types, we can't assume zero() exists, so we try to check if it's zero
585+
# For most numeric types this will work, for others we'll skip the check
586+
try
587+
if scalar == zero(typeof(scalar))
588+
return A
589+
end
590+
catch MethodError
591+
# zero() not defined for this type, continue with the operation
580592
end
581593

582594
# Subtract scalar from all stored values, converting to A's type
@@ -589,21 +601,30 @@ function sub!(A::NDSparseArray{T}, scalar::Number) where {T}
589601
end
590602

591603
"""
592-
mul!(A::NDSparseArray, scalar::Number)
604+
mul!(A::NDSparseArray, scalar)
593605
594606
In-place scalar multiplication of sparse array `A`.
595607
More efficient than `A = A * scalar`.
596-
"""
597-
function mul!(A::NDSparseArray{T}, scalar::Number) where {T}
598-
# If scalar is zero, clear all elements
599-
if scalar == zero(typeof(scalar))
600-
empty!(A.data)
601-
return A
608+
Works with any type that supports multiplication with the array's element type.
609+
"""
610+
function mul!(A::NDSparseArray{T}, scalar) where {T}
611+
# For generic types, we can't assume zero() or one() exist, so we try to check
612+
# For most numeric types this will work, for others we'll skip the check
613+
try
614+
if scalar == zero(typeof(scalar))
615+
empty!(A.data)
616+
return A
617+
end
618+
catch MethodError
619+
# zero() not defined for this type, continue
602620
end
603621

604-
# If scalar is one, no operation needed
605-
if scalar == one(typeof(scalar))
606-
return A
622+
try
623+
if scalar == one(typeof(scalar))
624+
return A
625+
end
626+
catch MethodError
627+
# one() not defined for this type, continue
607628
end
608629

609630
# Multiply all stored values, converting scalar to A's type

test/test_inplace_operations.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,70 @@ using Test
342342
@test A[2, 2, 2] == 20
343343
@test A[1, 2, 1] == 14
344344
end
345+
346+
@testset "Generic scalar types" begin
347+
# Test with custom numeric type that supports +, -, *
348+
struct CustomNumber
349+
value::Float64
350+
end
351+
Base.:+(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value + b.value)
352+
Base.:-(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value - b.value)
353+
Base.:*(a::CustomNumber, b::CustomNumber) = CustomNumber(a.value * b.value)
354+
Base.convert(::Type{CustomNumber}, x::CustomNumber) = x
355+
356+
A = NDSparseArray{CustomNumber, 2}((2, 2))
357+
A[1, 1] = CustomNumber(5.0)
358+
A[2, 2] = CustomNumber(10.0)
359+
360+
add!(A, CustomNumber(3.0))
361+
362+
@test A[1, 1].value == 8.0
363+
@test A[2, 2].value == 13.0
364+
@test nnz(A) == 2
365+
366+
# Test with rational numbers
367+
B = NDSparseArray{Rational{Int}, 2}((2, 2))
368+
B[1, 1] = 1//2
369+
B[2, 2] = 3//4
370+
371+
add!(B, 1//4)
372+
373+
@test B[1, 1] == 3//4
374+
@test B[2, 2] == 1//1
375+
@test nnz(B) == 2
376+
377+
# Test with complex numbers and complex scalar
378+
C = NDSparseArray{Complex{Int}, 2}((2, 2))
379+
C[1, 1] = 2 + 3im
380+
C[2, 2] = 1 + 1im
381+
382+
add!(C, 1 + 2im)
383+
384+
@test C[1, 1] == 3 + 5im
385+
@test C[2, 2] == 2 + 3im
386+
@test nnz(C) == 2
387+
388+
# Test multiplication with rational
389+
mul!(B, 2//3)
390+
391+
@test B[1, 1] == 1//2 # (3//4) * (2//3) = 1//2
392+
@test B[2, 2] == 2//3 # (1//1) * (2//3) = 2//3
393+
@test nnz(B) == 2
394+
end
395+
396+
@testset "Generic type error handling" begin
397+
# Test that incompatible types still give sensible errors
398+
A = NDSparseArray{Int, 2}((2, 2))
399+
A[1, 1] = 5
400+
401+
# Test with incompatible scalar type that can't convert
402+
struct IncompatibleType
403+
value::String
404+
end
405+
406+
incompatible = IncompatibleType("test")
407+
@test_throws MethodError add!(A, incompatible)
408+
end
345409
end
346410

347411
end

0 commit comments

Comments
 (0)