Skip to content

Commit 4b165d7

Browse files
authored
add nicer show methods (#13)
1 parent 7f38ec8 commit 4b165d7

File tree

8 files changed

+62
-8
lines changed

8 files changed

+62
-8
lines changed

src/AMGX.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,15 @@ function warn_not_destroyed_on_finalize(x::AMGXObject)
120120
!(x.handle == C_NULL) && @async @warn("AMGX: likely memory leak: a `$(typeof(x))` was finalized without having been `close`d")
121121
end
122122

123+
function Base.show(io::IO, ::MIME"text/plain", object::AMGXObject)
124+
ptr_str = object.handle == C_NULL ? "uninitialized" : "@" * sprint(show, UInt(object.handle))
125+
print(io, typeof(object), " ", ptr_str)
126+
object.handle == C_NULL && return
127+
if hasfield(typeof(object), :mode)
128+
print(io, " ", object.mode)
129+
end
130+
end
131+
123132

124133
############
125134
# Includes #

src/Matrix.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,20 @@ Base.@kwdef mutable struct AMGXMatrix <: AMGXObject
1111
return m
1212
end
1313
end
14+
function Base.show(io::IO, mime::MIME"text/plain", m::AMGXMatrix)
15+
invoke(show, Tuple{IO, MIME"text/plain", AMGXObject}, io, mime, m)
16+
m.handle == C_NULL && return
17+
n, block_dims = matrix_get_size(m)
18+
if n !== 0
19+
nz = amgx_nnz(m)
20+
if block_dims == (1, 1)
21+
print(io, " of size $n×$n with $nz stored entries")
22+
else
23+
block_dim_x, block_dim_y = block_dims
24+
print(io, " of size $n$block_dim_x×$n$block_dim_y with $nz stored block entries")
25+
end
26+
end
27+
end
1428
get_api_destroy_call(::Type{AMGXMatrix}) = API.AMGX_matrix_destroy
1529
function dec_refcount_parents(m::AMGXMatrix)
1630
dec_refcount!(m.resources)
@@ -82,18 +96,18 @@ function upload!(matrix::AMGXMatrix, cu_matrix::CUDA.CUSPARSE.CuSparseMatrixCSR)
8296
upload!(matrix, cu_matrix.rowPtr, cu_matrix.colVal, cu_matrix.nzVal)
8397
end
8498

85-
function get_size(matrix::AMGXMatrix)
99+
function matrix_get_size(matrix::AMGXMatrix)
86100
n_ptr, block_dim_x_ptr, block_dim_y_ptr = Ref{Cint}(), Ref{Cint}(), Ref{Cint}()
87101
@checked API.AMGX_matrix_get_size(matrix.handle, n_ptr, block_dim_x_ptr, block_dim_y_ptr)
88102
return Int(n_ptr[]), (Int(block_dim_x_ptr[]), Int(block_dim_y_ptr[]))
89103
end
90104
function Base.size(matrix::AMGXMatrix)
91-
n, block_dims = get_size(matrix)
105+
n, block_dims = matrix_get_size(matrix)
92106
return n * block_dims[1], n * block_dims[2]
93107
end
94108

95109
function replace_coefficients!(m::AMGXMatrix, data::VectorOrCuVector{T}, diag_data::Union{VectorOrCuVector{T}, Nothing}=nothing) where {T <: Union{Float64, Float32}}
96-
n, block_dims = get_size(m)
110+
n, block_dims = matrix_get_size(m)
97111
_amgx_nnz = amgx_nnz(m)
98112
if length(data) != _amgx_nnz * prod(block_dims)
99113
throw(ArgumentError("can not change the number of nnz entries in `replace_coefficients!`"))
@@ -125,6 +139,6 @@ function amgx_nnz(matrix)
125139
end
126140

127141
function SparseArrays.nnz(matrix::AMGXMatrix)
128-
_, block_dims = get_size(matrix)
142+
_, block_dims = matrix_get_size(matrix)
129143
return amgx_nnz(matrix) * prod(block_dims)
130144
end

src/Solver.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ Base.@kwdef mutable struct Solver <: AMGXObject
22
handle::API.AMGX_solver_handle = API.AMGX_solver_handle(C_NULL)
33
resources::Union{Resources, Nothing} = nothing
44
config::Union{Config, Nothing} = nothing
5+
mode::Union{Mode, Nothing} = nothing
56
bound_matrix::Union{AMGXMatrix, Nothing} = nothing
6-
function Solver(handle::API.AMGX_solver_handle, resources::Union{Resources, Nothing},
7+
function Solver(handle::API.AMGX_solver_handle, resources::Union{Resources, Nothing}, mode::Union{Mode, Nothing},
78
config::Union{Config, Nothing}, bound_matrix::Union{AMGXMatrix, Nothing})
8-
solver = new(handle, resources, config, bound_matrix)
9+
solver = new(handle, resources, config, mode, bound_matrix)
910
finalizer(warn_not_destroyed_on_finalize, solver)
1011
return solver
1112
end
@@ -16,6 +17,7 @@ function dec_refcount_parents(solver::Solver)
1617
solver.resources = nothing
1718
solver.config = nothing
1819
solver.bound_matrix = nothing
20+
solver.mode = nothing
1921
nothing
2022
end
2123
get_api_destroy_call(::Type{Solver}) = API.AMGX_solver_destroy
@@ -27,6 +29,7 @@ function create!(solver::Solver, res::Resources, mode::Mode, config::Config)
2729
solver.handle = solver_handle_ptr[]
2830
solver.resources = res
2931
solver.config = config
32+
solver.mode = mode
3033
inc_refcount!(solver.resources)
3134
inc_refcount!(solver.config)
3235
return solver

src/Vector.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,18 @@ Base.@kwdef mutable struct AMGXVector <: AMGXObject
99
return v
1010
end
1111
end
12+
function Base.show(io::IO, mime::MIME"text/plain", v::AMGXVector)
13+
invoke(show, Tuple{IO, MIME"text/plain", AMGXObject}, io, mime, v)
14+
v.handle == C_NULL && return
15+
n, block_dim = vector_get_size(v)
16+
if n !== 0
17+
if block_dim == 1
18+
print(io, " of length $n")
19+
else
20+
print(io, " of length $n$block_dim")
21+
end
22+
end
23+
end
1224
get_api_destroy_call(::Type{AMGXVector}) = API.AMGX_vector_destroy
1325
function dec_refcount_parents(v::AMGXVector)
1426
dec_refcount!(v.resources)

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
using AMGX
22

3+
repl_output(x) = sprint((io, x) -> show(io, MIME("text/plain"), x), x)
4+
35
# Hide annoying output from the library
46
AMGX.register_print_callback(x -> nothing)
57

test/test_config.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
module TestConfig
22

3+
import ..repl_output
34
using AMGX, Defer, Test
45
using AMGX: Config, AMGXException
56

67
@testset "Config" begin
78
@scope @testset "Dict" begin
89
cfg = @! Config()
10+
@test occursin("uninitialized", repl_output(cfg))
911
d = Dict("max_levels" => 10)
1012
AMGX.create!(cfg, d)
13+
str = sprint((io, cfg) -> show(io, MIME("text/plain"), cfg), cfg)
14+
@test !occursin("uninitialized", repl_output(cfg))
15+
@test occursin("@", repl_output(cfg))
1116

1217
d = Dict("max_lovels" => 10)
1318
@test_throws AMGXException("Incorrect amgx configuration provided.") Config(d)

test/test_matrix.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ module TestMatrix
22

33
# TODO: Test `diag_dat` argument to `upload!` and `replace_coefficients!`
44

5+
import ..repl_output
56
using AMGX, Defer, Test, JSON, CUDA, SparseArrays
67
using AMGX: Config, Resources, AMGXMatrix, dDDI, dFFI
78

89
@scope @testset "Matrix" begin
910
c = @! Config("")
1011
r = @! Resources(c)
1112
m = @! AMGXMatrix(r, dDDI)
13+
@test occursin("dDDI", repl_output(m))
1214

1315
@testset "upload" begin
1416
AMGX.upload!(m,
@@ -17,8 +19,9 @@ using AMGX: Config, Resources, AMGXMatrix, dDDI, dFFI
1719
[1.0, 2.0, 3.0]
1820
)
1921
@test nnz(m) == 3
20-
@test AMGX.get_size(m) == (2, (1,1))
22+
@test AMGX.matrix_get_size(m) == (2, (1,1))
2123
@test size(m) == (2, 2)
24+
@test occursin("of size 2×2 with 3 stored entries", repl_output(m))
2225

2326
@testset "replace coefficients" begin
2427
# TODO: Should test this does something
@@ -48,9 +51,11 @@ using AMGX: Config, Resources, AMGXMatrix, dDDI, dFFI
4851
blocks_flatten;
4952
block_dims = (2,2)
5053
)
51-
@test AMGX.get_size(m) == (2, (2,2))
54+
@test AMGX.matrix_get_size(m) == (2, (2,2))
5255
@test size(m) == (4, 4)
5356
@test nnz(m) == 12
57+
@test occursin("of size 2⋅2×2⋅2 with 3 stored block entries", repl_output(m))
58+
5459
@testset "replace coefficients" begin
5560
AMGX.replace_coefficients!(m, ones(Float64, length(blocks_flatten)))
5661
@test_throws ArgumentError AMGX.replace_coefficients!(m, ones(Float64, length(blocks_flatten)-1))

test/test_vector.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
module TestVector
22

3+
import ..repl_output
34
using AMGX, Defer, Test, JSON, CUDA
45
using AMGX: Config, Resources, AMGXVector, dDDI, dFFI
56

@@ -9,8 +10,10 @@ using AMGX: Config, Resources, AMGXVector, dDDI, dFFI
910

1011
@scope @testset "upload/download" begin
1112
v = @! AMGXVector(r, dDDI)
13+
@test occursin("dDDI", repl_output(v))
1214
v_h = [1.0, 2.0, 3.0]
1315
AMGX.upload!(v, v_h)
16+
@test occursin("of length 3", repl_output(v))
1417
@test length(v) == 3
1518
@test AMGX.download(v) == v_h
1619
@test Vector(v) == v_h
@@ -46,6 +49,7 @@ using AMGX: Config, Resources, AMGXVector, dDDI, dFFI
4649
v = @! AMGXVector(r, dDDI)
4750
v_h = [1.0, 2.0, 3.0, 4.0]
4851
AMGX.upload!(v, v_h; block_dim = 2)
52+
@test occursin("of length 2⋅2", repl_output(v))
4953
@test length(v_h) == 4
5054
@test AMGX.vector_get_size(v) == (2, 2)
5155
v_h = [1.0, 2.0, 3.0]

0 commit comments

Comments
 (0)