@@ -11,6 +11,20 @@ Base.@kwdef mutable struct AMGXMatrix <: AMGXObject
1111 return m
1212 end
1313end
14+ function Base. show (io:: IO , mime:: MIME"text/plain" , m:: AMGXMatrix )
15+ invoke (show, Tuple{IO, MIME " text/plain" , AMGXObject}, io, mime, m)
16+ m. handle == C_NULL && return
17+ n, block_dims = matrix_get_size (m)
18+ if n != = 0
19+ nz = amgx_nnz (m)
20+ if block_dims == (1 , 1 )
21+ print (io, " of size $n ×$n with $nz stored entries" )
22+ else
23+ block_dim_x, block_dim_y = block_dims
24+ print (io, " of size $n ⋅$block_dim_x ×$n ⋅$block_dim_y with $nz stored block entries" )
25+ end
26+ end
27+ end
1428get_api_destroy_call (:: Type{AMGXMatrix} ) = API. AMGX_matrix_destroy
1529function dec_refcount_parents (m:: AMGXMatrix )
1630 dec_refcount! (m. resources)
@@ -82,18 +96,18 @@ function upload!(matrix::AMGXMatrix, cu_matrix::CUDA.CUSPARSE.CuSparseMatrixCSR)
8296 upload! (matrix, cu_matrix. rowPtr, cu_matrix. colVal, cu_matrix. nzVal)
8397end
8498
85- function get_size (matrix:: AMGXMatrix )
99+ function matrix_get_size (matrix:: AMGXMatrix )
86100 n_ptr, block_dim_x_ptr, block_dim_y_ptr = Ref {Cint} (), Ref {Cint} (), Ref {Cint} ()
87101 @checked API. AMGX_matrix_get_size (matrix. handle, n_ptr, block_dim_x_ptr, block_dim_y_ptr)
88102 return Int (n_ptr[]), (Int (block_dim_x_ptr[]), Int (block_dim_y_ptr[]))
89103end
90104function Base. size (matrix:: AMGXMatrix )
91- n, block_dims = get_size (matrix)
105+ n, block_dims = matrix_get_size (matrix)
92106 return n * block_dims[1 ], n * block_dims[2 ]
93107end
94108
95109function replace_coefficients! (m:: AMGXMatrix , data:: VectorOrCuVector{T} , diag_data:: Union{VectorOrCuVector{T}, Nothing} = nothing ) where {T <: Union{Float64, Float32} }
96- n, block_dims = get_size (m)
110+ n, block_dims = matrix_get_size (m)
97111 _amgx_nnz = amgx_nnz (m)
98112 if length (data) != _amgx_nnz * prod (block_dims)
99113 throw (ArgumentError (" can not change the number of nnz entries in `replace_coefficients!`" ))
@@ -125,6 +139,6 @@ function amgx_nnz(matrix)
125139end
126140
127141function SparseArrays. nnz (matrix:: AMGXMatrix )
128- _, block_dims = get_size (matrix)
142+ _, block_dims = matrix_get_size (matrix)
129143 return amgx_nnz (matrix) * prod (block_dims)
130144end
0 commit comments