Skip to content

Commit c50b1b7

Browse files
committed
Initial commit
1 parent 39d39c2 commit c50b1b7

File tree

3 files changed

+478
-1
lines changed

3 files changed

+478
-1
lines changed

src/NDimensionalSparseArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ module NDimensionalSparseArrays
33
include("ndsparsearray.jl")
44

55
export NDSparseArray, nnz, sparsity, stored_indices, stored_values, stored_pairs,
6-
spzeros, spones, spfill, findnz, dropstored!, compress!, hasindex, to_dense
6+
spzeros, spones, spfill, findnz, dropstored!, compress!, hasindex, to_dense,
7+
add!, sub!, mul!
78

89
end

src/ndsparsearray.jl

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,3 +485,132 @@ Collect only the stored values in the sparse array.
485485
To get a dense representation, use `to_dense(A)`.
486486
"""
487487
Base.collect(A::NDSparseArray) = collect(stored_values(A))
488+
489+
# In-place arithmetic operations
490+
# Julia doesn't have += as a base function - it's syntax sugar for a = a + b
491+
# We'll create optimized in-place functions that users can call directly
492+
493+
"""
494+
add!(A::NDSparseArray, B::NDSparseArray)
495+
496+
In-place addition of sparse array `B` to sparse array `A`.
497+
Modifies `A` and returns it. More efficient than `A = A + B`.
498+
"""
499+
function add!(A::NDSparseArray{T, N}, B::NDSparseArray{S, N}) where {T, S, N}
500+
size(A) == size(B) || throw(DimensionMismatch("Array dimensions must match"))
501+
502+
# Add elements from B, converting to A's type
503+
for (idx, val_b) in B.data
504+
if haskey(A.data, idx)
505+
# Both arrays have this index
506+
A.data[idx] = A.data[idx] + convert(T, val_b)
507+
else
508+
# Only B has this index (A is effectively zero here)
509+
A.data[idx] = convert(T, val_b)
510+
end
511+
end
512+
513+
return A
514+
end
515+
516+
"""
517+
add!(A::NDSparseArray, scalar::Number)
518+
519+
In-place addition of scalar to all stored elements in sparse array `A`.
520+
More efficient than `A = A + scalar`.
521+
"""
522+
function add!(A::NDSparseArray{T}, scalar::Number) where {T}
523+
# If scalar is zero, no operation needed
524+
if scalar == zero(typeof(scalar))
525+
return A
526+
end
527+
528+
# Add scalar to all stored values, converting to A's type
529+
scalar_converted = convert(T, scalar)
530+
for (idx, val) in A.data
531+
A.data[idx] = val + scalar_converted
532+
end
533+
534+
return A
535+
end
536+
537+
"""
538+
sub!(A::NDSparseArray, B::NDSparseArray)
539+
540+
In-place subtraction of sparse array `B` from sparse array `A`.
541+
Modifies `A` and returns it. More efficient than `A = A - B`.
542+
"""
543+
function sub!(A::NDSparseArray{T, N}, B::NDSparseArray{S, N}) where {T, S, N}
544+
size(A) == size(B) || throw(DimensionMismatch("Array dimensions must match"))
545+
546+
# Subtract elements from B, converting to A's type
547+
for (idx, val_b) in B.data
548+
val_b_converted = convert(T, val_b)
549+
if haskey(A.data, idx)
550+
# Both arrays have this index
551+
new_val = A.data[idx] - val_b_converted
552+
if new_val != zero(T)
553+
A.data[idx] = new_val
554+
else
555+
# Remove zero values to maintain sparsity
556+
delete!(A.data, idx)
557+
end
558+
else
559+
# Only B has this index (A is effectively zero here)
560+
new_val = zero(T) - val_b_converted
561+
if new_val != zero(T)
562+
A.data[idx] = new_val
563+
end
564+
end
565+
end
566+
567+
return A
568+
end
569+
570+
"""
571+
sub!(A::NDSparseArray, scalar::Number)
572+
573+
In-place subtraction of scalar from all stored elements in sparse array `A`.
574+
More efficient than `A = A - scalar`.
575+
"""
576+
function sub!(A::NDSparseArray{T}, scalar::Number) where {T}
577+
# If scalar is zero, no operation needed
578+
if scalar == zero(typeof(scalar))
579+
return A
580+
end
581+
582+
# Subtract scalar from all stored values, converting to A's type
583+
scalar_converted = convert(T, scalar)
584+
for (idx, val) in A.data
585+
A.data[idx] = val - scalar_converted
586+
end
587+
588+
return A
589+
end
590+
591+
"""
592+
mul!(A::NDSparseArray, scalar::Number)
593+
594+
In-place scalar multiplication of sparse array `A`.
595+
More efficient than `A = A * scalar`.
596+
"""
597+
function mul!(A::NDSparseArray{T}, scalar::Number) where {T}
598+
# If scalar is zero, clear all elements
599+
if scalar == zero(typeof(scalar))
600+
empty!(A.data)
601+
return A
602+
end
603+
604+
# If scalar is one, no operation needed
605+
if scalar == one(typeof(scalar))
606+
return A
607+
end
608+
609+
# Multiply all stored values, converting scalar to A's type
610+
scalar_converted = convert(T, scalar)
611+
for (idx, val) in A.data
612+
A.data[idx] = val * scalar_converted
613+
end
614+
615+
return A
616+
end

0 commit comments

Comments
 (0)