Skip to content

Commit e845204

Browse files
committed
Initial commit
1 parent 9ff327b commit e845204

11 files changed

+1290
-22
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ jobs:
1919
fail-fast: false
2020
matrix:
2121
version:
22+
- '1.9'
23+
- '1.10'
2224
- '1.11'
2325
os:
2426
- ubuntu-latest

Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,14 @@ version = "0.1.0"
66

77
[compat]
88
Aqua = "0.8"
9+
LinearAlgebra = "1.0"
910
Test = "1.0"
10-
julia = "1.11"
11+
julia = "1.9"
1112

1213
[extras]
1314
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
15+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1416
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1517

1618
[targets]
17-
test = ["Aqua", "Test"]
19+
test = ["Aqua", "LinearAlgebra", "Test"]

src/MultidimensionalSparseArrays.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module MultidimensionalSparseArrays
22

33
include("sparse_array.jl")
44

5-
export SparseArray, nnz, sparsity, stored_indices, stored_values, stored_pairs
5+
export SparseArray, nnz, sparsity, stored_indices, stored_values, stored_pairs,
6+
spzeros, spones, spfill, findnz, dropstored!, compress!
67

78
end

src/sparse_array.jl

Lines changed: 284 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -45,21 +45,48 @@ SparseArray{T, N}(::UndefInitializer, dims::Vararg{Int, N}) where {T, N} =
4545
SparseArray{T, N}(::UndefInitializer, dims::NTuple{N, Int}) where {T, N} =
4646
SparseArray{T, N}(dims, zero(T))
4747

48-
# Constructor from dense array
49-
function SparseArray(A::AbstractArray{T, N}) where {T, N}
48+
# Constructor from dense array with optional tolerance for floating point
49+
function SparseArray(A::AbstractArray{T, N}; atol::Real = 0) where {T, N}
5050
sparse_array = SparseArray{T, N}(size(A))
51+
default_val = sparse_array.default_value
52+
5153
for I in CartesianIndices(A)
5254
val = A[I]
53-
if val != sparse_array.default_value
54-
sparse_array.data[I] = val
55+
# Use tolerance for floating point comparison
56+
if T <: AbstractFloat
57+
if abs(val - default_val) > atol
58+
sparse_array.data[I] = val
59+
end
60+
else
61+
if val != default_val
62+
sparse_array.data[I] = val
63+
end
5564
end
5665
end
5766
return sparse_array
5867
end
5968

6069
# Required AbstractArray interface
6170
Base.size(A::SparseArray) = A.dims
62-
Base.IndexStyle(::Type{<:SparseArray}) = IndexCartesian()
71+
Base.IndexStyle(::Type{<:SparseArray}) = IndexLinear()
72+
73+
# Linear indexing support
74+
@inline function Base.getindex(A::SparseArray, i::Int)
75+
@boundscheck checkbounds(A, i)
76+
idx = CartesianIndices(A)[i]
77+
return get(A.data, idx, A.default_value)
78+
end
79+
80+
@inline function Base.setindex!(A::SparseArray, val, i::Int)
81+
@boundscheck checkbounds(A, i)
82+
idx = CartesianIndices(A)[i]
83+
if val == A.default_value
84+
delete!(A.data, idx)
85+
else
86+
A.data[idx] = val
87+
end
88+
return val
89+
end
6390

6491
# Indexing
6592
@inline function Base.getindex(A::SparseArray{T, N}, I::Vararg{Int, N}) where {T, N}
@@ -145,26 +172,16 @@ Return an iterator over (index, value) pairs for stored elements.
145172
"""
146173
stored_pairs(A::SparseArray) = pairs(A.data)
147174

148-
# Display
149-
function Base.show(io::IO, ::MIME"text/plain", A::SparseArray{T, N}) where {T, N}
150-
println(io, "$(size(A)) SparseArray{$T, $N} with $(nnz(A)) stored entries:")
151-
if nnz(A) > 0
152-
for (idx, val) in stored_pairs(A)
153-
println(io, " $idx => $val")
154-
end
155-
end
156-
end
175+
# Display (basic version - improved version defined later)
157176

158177
# Basic arithmetic operations
159178
Base.:(==)(A::SparseArray, B::SparseArray) =
160179
size(A) == size(B) && A.default_value == B.default_value && A.data == B.data
161180

162-
# Copy
181+
# Copy (more efficient)
163182
function Base.copy(A::SparseArray{T, N}) where {T, N}
164183
B = SparseArray{T, N}(A.dims, A.default_value)
165-
for (k, v) in A.data
166-
B.data[k] = v
167-
end
184+
merge!(B.data, A.data)
168185
return B
169186
end
170187

@@ -173,3 +190,252 @@ Base.similar(A::SparseArray{T, N}) where {T, N} =
173190

174191
Base.similar(A::SparseArray{T, N}, ::Type{S}) where {T, S, N} =
175192
SparseArray{S, N}(A.dims, zero(S))
193+
194+
Base.similar(A::SparseArray, ::Type{S}, dims::Dims) where {S} =
195+
SparseArray{S, length(dims)}(dims, zero(S))
196+
197+
# Specialized constructors
198+
"""
199+
spzeros(T, dims...)
200+
201+
Create a sparse array of zeros with element type `T` and given dimensions.
202+
"""
203+
spzeros(::Type{T}, dims::Vararg{Int, N}) where {T, N} = SparseArray{T, N}(dims, zero(T))
204+
spzeros(::Type{T}, dims::NTuple{N, Int}) where {T, N} = SparseArray{T, N}(dims, zero(T))
205+
spzeros(dims::Vararg{Int, N}) where {N} = spzeros(Float64, dims...)
206+
207+
"""
208+
spones(T, dims...)
209+
210+
Create a sparse array filled with ones of type `T` and given dimensions.
211+
Note: This creates a dense-like structure, which may not be memory efficient for large arrays.
212+
"""
213+
function spones(::Type{T}, dims::Vararg{Int, N}) where {T, N}
214+
A = SparseArray{T, N}(dims, zero(T))
215+
one_val = one(T)
216+
for I in CartesianIndices(A)
217+
A.data[I] = one_val
218+
end
219+
return A
220+
end
221+
spones(dims::Vararg{Int, N}) where {N} = spones(Float64, dims...)
222+
223+
"""
224+
spfill(val, dims...)
225+
226+
Create a sparse array filled with the given value.
227+
"""
228+
function spfill(val::T, dims::Vararg{Int, N}) where {T, N}
229+
A = SparseArray{T, N}(dims, zero(T))
230+
if val != zero(T)
231+
for I in CartesianIndices(A)
232+
A.data[I] = val
233+
end
234+
end
235+
return A
236+
end
237+
238+
# Fill methods
239+
"""
240+
fill!(A::SparseArray, val)
241+
242+
Fill sparse array `A` with value `val`. If `val` is the default value,
243+
this efficiently clears all stored elements.
244+
"""
245+
function Base.fill!(A::SparseArray, val)
246+
if val == A.default_value
247+
empty!(A.data)
248+
else
249+
for I in CartesianIndices(A)
250+
A.data[I] = val
251+
end
252+
end
253+
return A
254+
end
255+
256+
# Finding functions
257+
"""
258+
findnz(A::SparseArray)
259+
260+
Return the indices and values of the stored (non-zero) elements in `A`.
261+
Returns `(I, V)` where `I` is a vector of `CartesianIndex` and `V` is a vector of values.
262+
"""
263+
function findnz(A::SparseArray{T, N}) where {T, N}
264+
indices = collect(keys(A.data))
265+
values = collect(vals for vals in Base.values(A.data))
266+
return (indices, values)
267+
end
268+
269+
"""
270+
findall(f, A::SparseArray)
271+
272+
Find all indices where function `f` returns true.
273+
"""
274+
function Base.findall(f::F, A::SparseArray) where {F<:Function}
275+
result = CartesianIndex{ndims(A)}[]
276+
277+
# Check stored values
278+
for (idx, val) in A.data
279+
if f(val)
280+
push!(result, idx)
281+
end
282+
end
283+
284+
# Check default values if predicate could match them
285+
if f(A.default_value)
286+
for I in CartesianIndices(A)
287+
if !haskey(A.data, I)
288+
push!(result, I)
289+
end
290+
end
291+
end
292+
293+
return result
294+
end
295+
296+
# Resolve ambiguity with Base.findall(pred::Base.Fix2{typeof(in)}, x::AbstractArray)
297+
function Base.findall(pred::Base.Fix2{typeof(in)}, A::SparseArray)
298+
# Use Base's implementation by converting to the function form
299+
return findall(x -> pred(x), A)
300+
end
301+
302+
# Arithmetic operations
303+
"""
304+
+(A::SparseArray, B::SparseArray)
305+
306+
Element-wise addition of two sparse arrays.
307+
"""
308+
function Base.:+(A::SparseArray{T, N}, B::SparseArray{S, N}) where {T, S, N}
309+
size(A) == size(B) || throw(DimensionMismatch("Array dimensions must match"))
310+
311+
R = promote_type(T, S)
312+
result = SparseArray{R, N}(size(A), zero(R))
313+
314+
# Add elements from A
315+
for (idx, val) in A.data
316+
result.data[idx] = val + B[idx]
317+
end
318+
319+
# Add elements from B that aren't in A
320+
for (idx, val) in B.data
321+
if !haskey(A.data, idx)
322+
result.data[idx] = A[idx] + val
323+
end
324+
end
325+
326+
return result
327+
end
328+
329+
"""
330+
-(A::SparseArray, B::SparseArray)
331+
332+
Element-wise subtraction of two sparse arrays.
333+
"""
334+
function Base.:-(A::SparseArray{T, N}, B::SparseArray{S, N}) where {T, S, N}
335+
size(A) == size(B) || throw(DimensionMismatch("Array dimensions must match"))
336+
337+
R = promote_type(T, S)
338+
result = SparseArray{R, N}(size(A), zero(R))
339+
340+
# Subtract elements
341+
for (idx, val) in A.data
342+
new_val = val - B[idx]
343+
if new_val != zero(R)
344+
result.data[idx] = new_val
345+
end
346+
end
347+
348+
# Handle elements only in B
349+
for (idx, val) in B.data
350+
if !haskey(A.data, idx)
351+
new_val = A[idx] - val
352+
if new_val != zero(R)
353+
result.data[idx] = new_val
354+
end
355+
end
356+
end
357+
358+
return result
359+
end
360+
361+
"""
362+
*(A::SparseArray, scalar)
363+
364+
Scalar multiplication of sparse array.
365+
"""
366+
function Base.:*(A::SparseArray{T, N}, scalar::Number) where {T, N}
367+
S = promote_type(T, typeof(scalar))
368+
result = SparseArray{S, N}(size(A), zero(S))
369+
370+
if scalar != 0
371+
for (idx, val) in A.data
372+
result.data[idx] = val * scalar
373+
end
374+
end
375+
376+
return result
377+
end
378+
379+
Base.:*(scalar::Number, A::SparseArray) = A * scalar
380+
381+
# Improved display with better formatting
382+
function Base.show(io::IO, ::MIME"text/plain", A::SparseArray{T, N}) where {T, N}
383+
compact = get(io, :compact, false)
384+
385+
if compact
386+
print(io, "$(size(A)) SparseArray{$T, $N}")
387+
return
388+
end
389+
390+
stored_count = nnz(A)
391+
total_elements = length(A)
392+
sparsity_pct = round(sparsity(A) * 100, digits=2)
393+
394+
println(io, "$(size(A)) SparseArray{$T, $N} with $stored_count stored entries:")
395+
println(io, " Sparsity: $sparsity_pct% ($(total_elements - stored_count) zeros)")
396+
397+
if stored_count > 0
398+
# Show up to 10 entries, sorted by index
399+
sorted_pairs = sort(collect(stored_pairs(A)), by = x -> x[1])
400+
display_count = min(10, length(sorted_pairs))
401+
402+
for i in 1:display_count
403+
idx, val = sorted_pairs[i]
404+
println(io, " $idx => $val")
405+
end
406+
407+
if stored_count > 10
408+
println(io, "")
409+
println(io, " ($(stored_count - 10) more entries)")
410+
end
411+
end
412+
end
413+
414+
# Memory efficiency method
415+
"""
416+
dropstored!(A::SparseArray, val)
417+
418+
Remove all stored entries that equal `val` from the sparse array.
419+
This can help reduce memory usage when entries become equal to the default value.
420+
"""
421+
function dropstored!(A::SparseArray, val)
422+
to_delete = CartesianIndex{ndims(A)}[]
423+
for (idx, stored_val) in A.data
424+
if stored_val == val
425+
push!(to_delete, idx)
426+
end
427+
end
428+
429+
for idx in to_delete
430+
delete!(A.data, idx)
431+
end
432+
433+
return A
434+
end
435+
436+
"""
437+
compress!(A::SparseArray)
438+
439+
Remove stored entries that equal the default value to reduce memory usage.
440+
"""
441+
compress!(A::SparseArray) = dropstored!(A, A.default_value)

0 commit comments

Comments
 (0)