From 38dae0e7013809fa8418423e468993e36b523795 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Wed, 8 Feb 2023 16:23:12 +0800 Subject: [PATCH 1/6] add MMReader and MMFormat --- src/MatrixMarket.jl | 1 + src/format.jl | 65 ++++++++++++++++++++++++++++++++++++++++ src/mmread.jl | 73 +++++++++++++++++++++++++++++---------------- 3 files changed, 113 insertions(+), 26 deletions(-) create mode 100644 src/format.jl diff --git a/src/MatrixMarket.jl b/src/MatrixMarket.jl index f20b711..82a1f4e 100644 --- a/src/MatrixMarket.jl +++ b/src/MatrixMarket.jl @@ -7,6 +7,7 @@ using TranscodingStreams, CodecZlib export mmread, mmwrite, mminfo +include("format.jl") include("mminfo.jl") include("mmread.jl") include("mmwrite.jl") diff --git a/src/format.jl b/src/format.jl new file mode 100644 index 0000000..2994b0b --- /dev/null +++ b/src/format.jl @@ -0,0 +1,65 @@ +abstract type MMFormat end + +function readout(f::MMFormat, nrow::Int, ncol::Int, nentry::Int, symm) + rep = formattext(f) + field = generate_eltype(eltype(f)) + return (Tuple(f)..., nrow, ncol, nentry, rep, field, symm) +end + +struct CoordinateFormat{T} <: MMFormat + rows::Vector{Int} + cols::Vector{Int} + vals::Vector{T} +end + +function CoordinateFormat(field, nentry) + T = parse_eltype(field) + rows = Vector{Int}(undef, nentry) + cols = Vector{Int}(undef, nentry) + vals = Vector{T}(undef, nentry) + return CoordinateFormat{T}(rows, cols, vals) +end + +Base.eltype(::CoordinateFormat{T}) where T = T + +formattext(::CoordinateFormat) = "coordinate" + +Base.Tuple(f::CoordinateFormat) = (f.rows, f.cols, f.vals) + +function writeat!(f::CoordinateFormat{T}, i::Int, line::String) where T + f.rows[i], f.cols[i], f.vals[i] = parseline(T, line) + return f +end + +function readout(f::CoordinateFormat, nrow::Int, ncol::Int, symm) + symfunc = parse_symmetric(symm) + return symfunc(sparse(f.rows, f.cols, f.vals, nrow, ncol)) +end + +struct ArrayFormat{T} <: MMFormat + vals::Vector{T} +end + +function ArrayFormat(::Type{T}, nentry::Int) where {T} + vals = Vector{T}(undef, nentry) + return ArrayFormat{T}(vals) +end + +ArrayFormat(nentry::Int) = ArrayFormat(Float64, nentry) + +Base.eltype(::ArrayFormat{T}) where T = T + +formattext(::ArrayFormat) = "array" + +Base.Tuple(f::ArrayFormat) = (f.vals,) + +function writeat!(f::ArrayFormat{T}, i::Int, line::String) where T + f.vals[i] = parse(T, line) + return f +end + +function readout(f::ArrayFormat, nrow::Int, ncol::Int, symm) + A = reshape(f.vals, nrow, ncol) + symfunc = parse_symmetric(symm) + return symfunc(A) +end diff --git a/src/mmread.jl b/src/mmread.jl index 6c3ec60..563a5a3 100644 --- a/src/mmread.jl +++ b/src/mmread.jl @@ -27,36 +27,17 @@ function mmread(filename::String, infoonly::Bool=false, retcoord::Bool=false) end function mmread(stream::IO, infoonly::Bool=false, retcoord::Bool=false) - rows, cols, entries, rep, field, symm = mminfo(stream) - - infoonly && return rows, cols, entries, rep, field, symm - - T = parse_eltype(field) - symfunc = parse_symmetric(symm) - - if rep == "coordinate" - rn = Vector{Int}(undef, entries) - cn = Vector{Int}(undef, entries) - vals = Vector{T}(undef, entries) - for i in 1:entries - line = readline(stream) - splits = find_splits(line, num_splits(T)) - rn[i] = parse_row(line, splits) - cn[i] = parse_col(line, splits, T) - vals[i] = parse_val(line, splits, T) - end + nrow, ncol, nentry, rep, field, symm = mminfo(stream) - result = retcoord ? (rn, cn, vals, rows, cols, entries, rep, field, symm) : - symfunc(sparse(rn, cn, vals, rows, cols)) - else - vals = [parse(Float64, readline(stream)) for _ in 1:entries] - A = reshape(vals, rows, cols) - result = symfunc(A) - end + infoonly && return nrow, ncol, nentry, rep, field, symm - return result + reader = MMReader(nrow, ncol, nentry, rep, field, symm) + readlines!(reader, stream) + return readout(reader, retcoord) end +## Parsing + function parse_eltype(field::String) if field == "real" return Float64 @@ -107,6 +88,14 @@ end parse_val(line, splits, ::Type{Bool}) = true parse_val(line, splits, ::Type{T}) where {T} = parse(T, line[splits[2]:length(line)]) +function parseline(::Type{T}, line) where T + splits = find_splits(line, num_splits(T)) + r = parse_row(line, splits) + c = parse_col(line, splits, T) + v = parse_val(line, splits, T) + return r, c, v +end + num_splits(::Type{ComplexF64}) = 3 num_splits(::Type{Bool}) = 1 num_splits(elty) = 2 @@ -130,3 +119,35 @@ function find_splits(s::String, num) splits end + +## Reader + +struct MMReader{F <: MMFormat} + nrow::Int + ncol::Int + nentry::Int + rep::String + symm::String + format::F +end + +function MMReader(nrow::Integer, ncol::Integer, nentry::Integer, rep, field, symm) + format = (rep == "coordinate") ? CoordinateFormat(field, nentry) : ArrayFormat(nentry) + return MMReader{typeof(format)}(nrow, ncol, nentry, rep, symm, format) +end + +function readlines!(reader::MMReader, stream::IO) + for i in 1:reader.nentry + line = readline(stream) + writeat!(reader.format, i, line) + end + return reader +end + +function readout(reader::MMReader, retcoord::Bool=false) + if retcoord + return readout(reader.format, reader.nrow, reader.ncol, reader.nentry, reader.symm) + else + return readout(reader.format, reader.nrow, reader.ncol, reader.symm) + end +end From 6b88fdb00701601f6e4f1fc931945445e5857c2c Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Thu, 9 Feb 2023 15:32:30 +0800 Subject: [PATCH 2/6] add MMWriter and support iterate over formats --- src/format.jl | 29 ++++++++++ src/mmwrite.jl | 83 ++++++++++++++++++--------- test/mtx.jl | 149 +++++++++++++++++++++++++++---------------------- 3 files changed, 169 insertions(+), 92 deletions(-) diff --git a/src/format.jl b/src/format.jl index 2994b0b..9ddd884 100644 --- a/src/format.jl +++ b/src/format.jl @@ -1,5 +1,7 @@ abstract type MMFormat end +Base.length(f::MMFormat) = length(f.vals) + function readout(f::MMFormat, nrow::Int, ncol::Int, nentry::Int, symm) rep = formattext(f) field = generate_eltype(eltype(f)) @@ -20,6 +22,15 @@ function CoordinateFormat(field, nentry) return CoordinateFormat{T}(rows, cols, vals) end +function CoordinateFormat(A::SparseMatrixCSC{T}) where {T} + rows = rowvals(A) + vals = nonzeros(A) + n = size(A, 2) + cols = [repeat([j], length(nzrange(A, j))) for j in 1:n] + cols = collect(Iterators.flatten(cols)) + return CoordinateFormat{T}(rows, cols, vals) +end + Base.eltype(::CoordinateFormat{T}) where T = T formattext(::CoordinateFormat) = "coordinate" @@ -36,6 +47,15 @@ function readout(f::CoordinateFormat, nrow::Int, ncol::Int, symm) return symfunc(sparse(f.rows, f.cols, f.vals, nrow, ncol)) end +function Base.iterate(f::CoordinateFormat, i::Integer=zero(length(f))) + i += oneunit(i) + if i <= length(f) + return (f.rows[i], f.cols[i], f.vals[i]), i + else + return nothing + end +end + struct ArrayFormat{T} <: MMFormat vals::Vector{T} end @@ -63,3 +83,12 @@ function readout(f::ArrayFormat, nrow::Int, ncol::Int, symm) symfunc = parse_symmetric(symm) return symfunc(A) end + +function Base.iterate(f::ArrayFormat, i::Integer=zero(length(f))) + i += oneunit(i) + if i <= length(f) + return f.vals[i], i + else + return nothing + end +end diff --git a/src/mmwrite.jl b/src/mmwrite.jl index 29ec352..52e6e69 100644 --- a/src/mmwrite.jl +++ b/src/mmwrite.jl @@ -19,32 +19,24 @@ function mmwrite(filename::String, matrix::SparseMatrixCSC) close(stream) end -function mmwrite(stream::IO, matrix::SparseMatrixCSC) +function mmwrite(stream::IO, matrix::SparseMatrixCSC{T}) where {T} nl = get_newline() - elem = generate_eltype(eltype(matrix)) - sym = generate_symmetric(matrix) + elem = generate_eltype(T) + writer = MMWriter(matrix) + write(stream, header(writer)) + write(stream, nl) + write(stream, sizetext(writer)) + write(stream, nl) - # write header - write(stream, "%%MatrixMarket matrix coordinate $elem $sym$nl") - - # only use lower triangular part of symmetric and Hermitian matrices - if issymmetric(matrix) || ishermitian(matrix) - matrix = tril(matrix) - end - - # write matrix size and number of nonzeros - write(stream, "$(size(matrix, 1)) $(size(matrix, 2)) $(nnz(matrix))$nl") - - rows = rowvals(matrix) - vals = nonzeros(matrix) - for i in 1:size(matrix, 2) - for j in nzrange(matrix, i) - entity = generate_entity(i, j, rows, vals, elem) - write(stream, entity) - end + for (r, c, v) in writer.format + entity = generate_entity(r, c, v, elem) + write(stream, entity) + write(stream, nl) end end +## Generating + generate_eltype(::Type{<:Bool}) = "pattern" generate_eltype(::Type{<:Integer}) = "integer" generate_eltype(::Type{<:AbstractFloat}) = "real" @@ -61,14 +53,13 @@ function generate_symmetric(m::AbstractMatrix) end end -function generate_entity(i, j, rows, vals, kind::String) - nl = get_newline() +function generate_entity(r, c, v, kind::String) if kind == "pattern" - return "$(rows[j]) $i$nl" + return "$r $c" elseif kind == "complex" - return "$(rows[j]) $i $(real(vals[j])) $(imag(vals[j]))$nl" + return "$r $c $(real(v)) $(imag(v))" else - return "$(rows[j]) $i $(vals[j])$nl" + return "$r $c $v" end end @@ -79,3 +70,43 @@ function get_newline() return "\n" end end + +## Writer + +struct MMWriter{F <: MMFormat} + nrow::Int + ncol::Int + nentry::Int + symm::String + format::F +end + +function MMWriter(A::AbstractMatrix{T}) where {T} + nrow, ncol = size(A) + vals = reshape(A, :) + symm = generate_symmetric(A) + format = ArrayFormat{T}(vals) + return MMWriter{typeof(format)}(nrow, ncol, nentry, symm, format) +end + +function MMWriter(A::SparseMatrixCSC) + nrow, ncol = size(A) + symm = generate_symmetric(A) + + # only use lower triangular part of symmetric and Hermitian matrices + if symm == "symmetric" || symm == "hermitian" + A = tril(A) + end + + nentry = nnz(A) + format = CoordinateFormat(A) + return MMWriter{typeof(format)}(nrow, ncol, nentry, symm, format) +end + +function header(writer::MMWriter) + rep = formattext(writer.format) + elem = generate_eltype(eltype(writer.format)) + return "%%MatrixMarket matrix $rep $elem $(writer.symm)" +end + +sizetext(writer::MMWriter) = "$(writer.nrow) $(writer.ncol) $(writer.nentry)" diff --git a/test/mtx.jl b/test/mtx.jl index 2ce6674..c65aa94 100644 --- a/test/mtx.jl +++ b/test/mtx.jl @@ -10,72 +10,87 @@ testmatrices = download_unzip_nist_files() @testset "read/write mtx" begin - rows, cols, entries, rep, field, symm = mminfo(mtx_filename) - @test rows == 11 - @test cols == 12 - @test entries == 5 - @test rep == "coordinate" - @test field == "integer" - @test symm == "general" - - A = mmread(mtx_filename) - @test A isa SparseMatrixCSC - @test A == res - - newfilename = replace(mtx_filename, "test.mtx" => "test_write.mtx") - mmwrite(newfilename, res) - - f = open(mtx_filename) - sha_test = bytes2hex(sha256(read(f, String))) - close(f) - - f = open(newfilename) - sha_new = bytes2hex(sha256(read(f, String))) - close(f) - - @test sha_test == sha_new - rm(newfilename) + @testset "mminfo test.mtx" begin + rows, cols, entries, rep, field, symm = mminfo(mtx_filename) + @test rows == 11 + @test cols == 12 + @test entries == 5 + @test rep == "coordinate" + @test field == "integer" + @test symm == "general" + end + + @testset "mmread test.mtx" begin + A = mmread(mtx_filename) + @test A isa SparseMatrixCSC + @test A == res + end + + @testset "mmwrite test.mtx" begin + newfilename = replace(mtx_filename, "test.mtx" => "test_write.mtx") + mmwrite(newfilename, res) + + f = open(mtx_filename) + sha_test = bytes2hex(sha256(read(f, String))) + close(f) + + f = open(newfilename) + sha_new = bytes2hex(sha256(read(f, String))) + close(f) + + @test sha_test == sha_new + rm(newfilename) + end end @testset "read/write mtx.gz" begin gz_filename = mtx_filename * ".gz" - rows, cols, entries, rep, field, symm = mminfo(gz_filename) - @test rows == 11 - @test cols == 12 - @test entries == 5 - @test rep == "coordinate" - @test field == "integer" - @test symm == "general" - - A = mmread(gz_filename) - @test A isa SparseMatrixCSC - @test A == res - - newfilename = replace(gz_filename, "test.mtx.gz" => "test_write.mtx.gz") - mmwrite(newfilename, res) - - stream = GzipDecompressorStream(open(gz_filename)) - adjusted_content = replace(read(stream, String), "\n" => get_newline()) - sha_test = bytes2hex(sha256(adjusted_content)) - close(stream) - - stream = GzipDecompressorStream(open(newfilename)) - sha_new = bytes2hex(sha256(read(stream, String))) - close(stream) - - @test sha_test == sha_new - rm(newfilename) + @testset "mminfo test.mtx.gz" begin + rows, cols, entries, rep, field, symm = mminfo(gz_filename) + @test rows == 11 + @test cols == 12 + @test entries == 5 + @test rep == "coordinate" + @test field == "integer" + @test symm == "general" + end + + @testset "mmread test.mtx.gz" begin + A = mmread(gz_filename) + @test A isa SparseMatrixCSC + @test A == res + end + + @testset "mmwrite test.mtx.gz" begin + newfilename = replace(gz_filename, "test.mtx.gz" => "test_write.mtx.gz") + mmwrite(newfilename, res) + + stream = GzipDecompressorStream(open(gz_filename)) + adjusted_content = replace(read(stream, String), "\n" => get_newline()) + sha_test = bytes2hex(sha256(adjusted_content)) + close(stream) + + stream = GzipDecompressorStream(open(newfilename)) + sha_new = bytes2hex(sha256(read(stream, String))) + close(stream) + + @test sha_test == sha_new + rm(newfilename) + end end @testset "read/write NIST mtx files" begin # verify mmread(mmwrite(A)) == A for filename in filter(t -> endswith(t, ".mtx"), readdir()) new_filename = replace(filename, ".mtx" => "_.mtx") - A = MatrixMarket.mmread(filename) - MatrixMarket.mmwrite(new_filename, A) - new_A = MatrixMarket.mmread(new_filename) - @test new_A == A + @testset "$filename" begin + A = MatrixMarket.mmread(filename) + MatrixMarket.mmwrite(new_filename, A) + new_A = MatrixMarket.mmread(new_filename) + @test new_A == A + end + rm(new_filename) end end @@ -83,19 +98,21 @@ @testset "read/write NIST mtx.gz files" begin for gz_filename in filter(t -> endswith(t, ".mtx.gz"), readdir()) mtx_filename = replace(gz_filename, ".mtx.gz" => ".mtx") - - # reading from .mtx and .mtx.gz must be identical - A_gz = MatrixMarket.mmread(gz_filename) - - A = MatrixMarket.mmread(mtx_filename) - @test A_gz == A - - # writing to .mtx and .mtx.gz must be identical new_filename = replace(gz_filename, ".mtx.gz" => "_.mtx.gz") - mmwrite(new_filename, A) + A = MatrixMarket.mmread(mtx_filename) - new_A = MatrixMarket.mmread(new_filename) - @test new_A == A + @testset "mmread $gz_filename" begin + # reading from .mtx and .mtx.gz must be identical + A_gz = MatrixMarket.mmread(gz_filename) + @test A_gz == A + end + + @testset "mmwrite $gz_filename" begin + # writing to .mtx and .mtx.gz must be identical + mmwrite(new_filename, A) + new_A = MatrixMarket.mmread(new_filename) + @test new_A == A + end rm(new_filename) end From 27e85070a84ce99176708a13b3eb39367b6c909b Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Mon, 13 Feb 2023 23:05:35 +0800 Subject: [PATCH 3/6] add test cases --- test/mtx.jl | 17 +++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 18 insertions(+) diff --git a/test/mtx.jl b/test/mtx.jl index c65aa94..c979377 100644 --- a/test/mtx.jl +++ b/test/mtx.jl @@ -118,6 +118,23 @@ end end + @testset "read from online NIST mtx.gz files" begin + for (collectionname, setname, matrixname) in testmatrices + url = "https://math.nist.gov/pub/MatrixMarket2/$collectionname/$setname/$matrixname.mtx.gz" + mtx_filename = string(collectionname, '_', setname, '_', matrixname, ".mtx") + A = MatrixMarket.mmread(mtx_filename) + + @testset "mmread $matrixname.mtx.gz" begin + # reading from .mtx and .mtx.gz must be identical + buffer = PipeBuffer() + stream = TranscodingStream(GzipDecompressor(), buffer) + Downloads.download(url, buffer) + A_gz = MatrixMarket.mmread(stream) + @test A_gz == A + end + end + end + # clean up for filename in filter(t -> endswith(t, ".mtx"), readdir()) rm(filename) diff --git a/test/runtests.jl b/test/runtests.jl index 6ec7556..4536740 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using Downloads using GZip using SparseArrays using SHA +using TranscodingStreams using Test include("test_utils.jl") From d12a57034e68a528394cc089f194b4f54cefec67 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 14 Feb 2023 11:19:23 +0800 Subject: [PATCH 4/6] add tests for format --- src/format.jl | 7 +++++++ test/format.jl | 37 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 +++ 3 files changed, 47 insertions(+) create mode 100644 test/format.jl diff --git a/src/format.jl b/src/format.jl index 9ddd884..f534dee 100644 --- a/src/format.jl +++ b/src/format.jl @@ -37,6 +37,9 @@ formattext(::CoordinateFormat) = "coordinate" Base.Tuple(f::CoordinateFormat) = (f.rows, f.cols, f.vals) +Base.:(==)(x::CoordinateFormat, y::CoordinateFormat) = (x.rows == y.rows) && + (x.cols == y.cols) && (x.vals == y.vals) + function writeat!(f::CoordinateFormat{T}, i::Int, line::String) where T f.rows[i], f.cols[i], f.vals[i] = parseline(T, line) return f @@ -67,12 +70,16 @@ end ArrayFormat(nentry::Int) = ArrayFormat(Float64, nentry) +ArrayFormat(A::AbstractMatrix{T}) where {T} = ArrayFormat{T}(reshape(A, :)) + Base.eltype(::ArrayFormat{T}) where T = T formattext(::ArrayFormat) = "array" Base.Tuple(f::ArrayFormat) = (f.vals,) +Base.:(==)(x::ArrayFormat, y::ArrayFormat) = (x.vals == y.vals) + function writeat!(f::ArrayFormat{T}, i::Int, line::String) where T f.vals[i] = parse(T, line) return f diff --git a/test/format.jl b/test/format.jl new file mode 100644 index 0000000..e41029c --- /dev/null +++ b/test/format.jl @@ -0,0 +1,37 @@ +@testset "format" begin + @testset "CoordinateFormat" begin + T = Float64 + rows = [1, 2, 2, 3, 5, 7] + cols = [1, 1, 2, 3, 4, 4] + vals = T[1, 2, 3, 4, 5, 6] + A = sparse(rows, cols, vals) + + f = MatrixMarket.CoordinateFormat(rows, cols, vals) + @test MatrixMarket.CoordinateFormat(A) == f + @test length(f) == length(vals) + @test eltype(f) == T + @test MatrixMarket.formattext(f) == "coordinate" + @test Tuple(f) == (rows, cols, vals) + @test MatrixMarket.readout(f, 7, 4, "general") == A + + MatrixMarket.writeat!(f, 2, "3 1 7") + @test (f.rows[2], f.cols[2], f.vals[2]) == (3, 1, 7) + end + + @testset "ArrayFormat" begin + T = Float64 + vals = T[1, 2, 3, 4, 5, 6] + A = reshape(vals, 2, 3) + + f = MatrixMarket.ArrayFormat(vals) + @test MatrixMarket.ArrayFormat(A) == f + @test length(f) == length(vals) + @test eltype(f) == T + @test MatrixMarket.formattext(f) == "array" + @test Tuple(f) == (vals, ) + @test MatrixMarket.readout(f, 2, 3, "general") == A + + MatrixMarket.writeat!(f, 2, "7") + @test f.vals[2] == 7 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 4536740..d71ce96 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -14,6 +14,9 @@ const NIST_FILELIST = download_nist_filelist() tests = [ "mtx", + "reader", + "writer", + "format", ] @testset "MatrixMarket.jl" begin From 220e6762fb2e0d54b5fc626a064f3d175f842d3d Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 14 Feb 2023 15:20:49 +0800 Subject: [PATCH 5/6] add tests for reader and writer --- src/format.jl | 5 +++++ src/mmread.jl | 3 ++- src/mmwrite.jl | 1 + test/reader.jl | 16 ++++++++++++++++ test/writer.jl | 23 +++++++++++++++++++++++ 5 files changed, 47 insertions(+), 1 deletion(-) create mode 100644 test/reader.jl create mode 100644 test/writer.jl diff --git a/src/format.jl b/src/format.jl index f534dee..b9b83b1 100644 --- a/src/format.jl +++ b/src/format.jl @@ -63,6 +63,11 @@ struct ArrayFormat{T} <: MMFormat vals::Vector{T} end +function ArrayFormat(field, nentry::Int) + T = parse_eltype(field) + return ArrayFormat(T, nentry) +end + function ArrayFormat(::Type{T}, nentry::Int) where {T} vals = Vector{T}(undef, nentry) return ArrayFormat{T}(vals) diff --git a/src/mmread.jl b/src/mmread.jl index 563a5a3..b1a5606 100644 --- a/src/mmread.jl +++ b/src/mmread.jl @@ -132,7 +132,8 @@ struct MMReader{F <: MMFormat} end function MMReader(nrow::Integer, ncol::Integer, nentry::Integer, rep, field, symm) - format = (rep == "coordinate") ? CoordinateFormat(field, nentry) : ArrayFormat(nentry) + @assert nentry <= nrow * ncol "given nentry ($nentry) is greater than the product of nrow and ncol ($(nrow * ncol))" + format = (rep == "coordinate") ? CoordinateFormat(field, nentry) : ArrayFormat(field, nentry) return MMReader{typeof(format)}(nrow, ncol, nentry, rep, symm, format) end diff --git a/src/mmwrite.jl b/src/mmwrite.jl index 52e6e69..61ab335 100644 --- a/src/mmwrite.jl +++ b/src/mmwrite.jl @@ -83,6 +83,7 @@ end function MMWriter(A::AbstractMatrix{T}) where {T} nrow, ncol = size(A) + nentry = nrow * ncol vals = reshape(A, :) symm = generate_symmetric(A) format = ArrayFormat{T}(vals) diff --git a/test/reader.jl b/test/reader.jl new file mode 100644 index 0000000..47c1192 --- /dev/null +++ b/test/reader.jl @@ -0,0 +1,16 @@ +@testset "reader" begin + reader = MatrixMarket.MMReader(7, 4, 6, "coordinate", "real", "general") + @test reader.nrow == 7 + @test reader.ncol == 4 + @test reader.nentry == 6 + @test eltype(reader.format) == Float64 + @test reader.format isa MatrixMarket.CoordinateFormat + @test_throws AssertionError MatrixMarket.MMReader(7, 4, 100, "coordinate", "real", "general") + + reader = MatrixMarket.MMReader(2, 3, 6, "array", "integer", "general") + @test reader.nrow == 2 + @test reader.ncol == 3 + @test reader.nentry == 6 + @test eltype(reader.format) == Int64 + @test reader.format isa MatrixMarket.ArrayFormat +end diff --git a/test/writer.jl b/test/writer.jl new file mode 100644 index 0000000..f3c31ff --- /dev/null +++ b/test/writer.jl @@ -0,0 +1,23 @@ +@testset "writer" begin + A = sparse(rand([0, 1], 3, 4)) + writer = MatrixMarket.MMWriter(A) + @test writer.nrow == size(A, 1) + @test writer.ncol == size(A, 2) + @test writer.nentry == nnz(A) + @test writer.symm == "general" + @test eltype(writer.format) == Int64 + @test writer.format isa MatrixMarket.CoordinateFormat + @test MatrixMarket.header(writer) == "%%MatrixMarket matrix coordinate integer general" + @test MatrixMarket.sizetext(writer) == "$(size(A, 1)) $(size(A, 2)) $(nnz(A))" + + A = rand(ComplexF64, 3, 4) + writer = MatrixMarket.MMWriter(A) + @test writer.nrow == size(A, 1) + @test writer.ncol == size(A, 2) + @test writer.nentry == length(A) + @test writer.symm == "general" + @test eltype(writer.format) == ComplexF64 + @test writer.format isa MatrixMarket.ArrayFormat + @test MatrixMarket.header(writer) == "%%MatrixMarket matrix array complex general" + @test MatrixMarket.sizetext(writer) == "$(size(A, 1)) $(size(A, 2)) $(length(A))" +end From b20520898f92c690221e445ce2baf344ddbdd975 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Wed, 15 Feb 2023 11:08:11 +0800 Subject: [PATCH 6/6] add tests for exceptions --- test/format.jl | 1 + test/reader.jl | 7 +++++++ test/writer.jl | 2 ++ 3 files changed, 10 insertions(+) diff --git a/test/format.jl b/test/format.jl index e41029c..31b2435 100644 --- a/test/format.jl +++ b/test/format.jl @@ -25,6 +25,7 @@ f = MatrixMarket.ArrayFormat(vals) @test MatrixMarket.ArrayFormat(A) == f + @test eltype(MatrixMarket.ArrayFormat(length(vals))) == Float64 @test length(f) == length(vals) @test eltype(f) == T @test MatrixMarket.formattext(f) == "array" diff --git a/test/reader.jl b/test/reader.jl index 47c1192..4594024 100644 --- a/test/reader.jl +++ b/test/reader.jl @@ -6,6 +6,7 @@ @test eltype(reader.format) == Float64 @test reader.format isa MatrixMarket.CoordinateFormat @test_throws AssertionError MatrixMarket.MMReader(7, 4, 100, "coordinate", "real", "general") + @test MatrixMarket.readout(reader, true)[4:end] == (7, 4, 6, "coordinate", "real", "general") reader = MatrixMarket.MMReader(2, 3, 6, "array", "integer", "general") @test reader.nrow == 2 @@ -13,4 +14,10 @@ @test reader.nentry == 6 @test eltype(reader.format) == Int64 @test reader.format isa MatrixMarket.ArrayFormat + @test MatrixMarket.readout(reader, true)[2:end] == (2, 3, 6, "array", "integer", "general") + + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_eltype("aaa") + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_symmetric("aaa") + @test MatrixMarket.parse_dimension("3 4", "array") == (3, 4, 12) + @test_throws MatrixMarket.FileFormatException MatrixMarket.parse_dimension("3 4", "coordinate") end diff --git a/test/writer.jl b/test/writer.jl index f3c31ff..66fb774 100644 --- a/test/writer.jl +++ b/test/writer.jl @@ -20,4 +20,6 @@ @test writer.format isa MatrixMarket.ArrayFormat @test MatrixMarket.header(writer) == "%%MatrixMarket matrix array complex general" @test MatrixMarket.sizetext(writer) == "$(size(A, 1)) $(size(A, 2)) $(length(A))" + + @test_throws ErrorException MatrixMarket.generate_eltype(String) end