@@ -1583,6 +1583,30 @@ function assemble!(B::PSparseMatrix,A::PSparseMatrix,cache)
15831583 psparse_assemble_impl!(B,A,T,cache)
15841584end
15851585
"""
    assemble!([f,]A::PSparseMatrix;kwargs...)

Assemble `A` in place. This method is the shorthand without a combination
function: it forwards to `assemble!(+,A;kwargs...)`, handing every keyword
argument through untouched, and returns whatever that call returns.
"""
assemble!(A::PSparseMatrix; kwargs...) = assemble!(+, A; kwargs...)
1592+
# In-place assembly of `A` with a user-provided combination function `f`.
# The implementation is selected by dispatching on the element type of
# `partition(A)`; `f` and the keyword arguments are passed through unchanged
# and the result of the implementation is returned as-is.
function assemble!(f, A::PSparseMatrix; kwargs...)
    part_type = eltype(partition(A))
    return psparse_assemble_impl!(f, A, part_type; kwargs...)
end
1597+
"""
    assemble!([f,]A::PSparseMatrix,cache)

Assemble `A` in place reusing a previously built `cache`. This method is the
shorthand without a combination function: it forwards to
`assemble!(+,A,cache)` and returns whatever that call returns.
"""
assemble!(A::PSparseMatrix, cache) = assemble!(+, A, cache)
1604+
# In-place assembly of `A` with combination function `f`, reusing `cache`.
# The implementation is selected by dispatching on the element type of
# `partition(A)`; its result is returned as-is.
function assemble!(f, A::PSparseMatrix, cache)
    part_type = eltype(partition(A))
    return psparse_assemble_impl!(f, A, part_type, cache)
end
1609+
# Generic fallback: reached when no method specialized on the type argument
# exists. Always throws an `ErrorException`.
psparse_assemble_impl(A, ::Type, rows; kwargs...) = error(" Case not implemented yet")
@@ -1755,6 +1779,132 @@ function psparse_assemble_impl(
17551779 end
17561780end
17571781
# Out-of-place assembly for local matrices of type `AbstractSparseMatrix`.
#
# Entries of `A` whose row is owned by another part are sent to that owner
# (as global row id, global column id and value); the entries in own rows are
# merged with the received ones into a new local matrix, and the result is
# wrapped in a `PSparseMatrix` with row partition `rows` and flagged as
# assembled.
#
# Arguments
# - `A`: the matrix to assemble; `partition(A)`, `axes(A,1)` and `axes(A,2)`
#   are queried on it.
# - `rows`: row partition of the result.
# - `reuse`: `Val` flag. When `val_parameter(reuse)` is true the block
#   evaluates to `(B, cache)`, otherwise to `B` alone.
# - `assembly_neighbors_options_cols`: splatted as keyword arguments into
#   `assembly_neighbors(cols_fa; ...)`.
#
# The function evaluates to the result of the `@fake_async` block.
#
# The nested helpers deliberately use local names that do not occur in the
# enclosing function. In Julia an assignment inside a nested function to a
# name that is a local of the enclosing function writes to the enclosing
# variable, which forces it to be boxed and lets the helper clobber it.
function psparse_assemble_impl(
    A,
    ::Type{<:AbstractSparseMatrix},
    rows;
    reuse=Val(false),
    assembly_neighbors_options_cols=(;))

    # Pack the stored entries living in rows owned by other parts into
    # per-destination send buffers.
    function setup_cache_snd(A,parts_snd,rows_sa,cols_sa)
        local_to_owner_row = local_to_owner(rows_sa)
        local_to_global_row = local_to_global(rows_sa)
        local_to_global_col = local_to_global(cols_sa)
        me = part_id(rows_sa)
        # Position of each destination part in `parts_snd`.
        owner_to_p = Dict(( owner=>i for (i,owner) in enumerate(parts_snd) ))
        ptrs = zeros(Int32,length(parts_snd)+1)
        # First pass: count how many entries go to each destination.
        for (i,_,_) in nziterator(A)
            owner = local_to_owner_row[i]
            if owner != me
                ptrs[owner_to_p[owner]+1] += 1
            end
        end
        # Presumably converts the per-destination counts into 1-based
        # offsets; the total number of entries is then `ptrs[end]-1`.
        length_to_ptrs!(ptrs)
        Tv = eltype(A)
        ndata = ptrs[end]-1
        snd_row_gids = zeros(Int,ndata)
        snd_col_gids = zeros(Int,ndata)
        snd_vals = zeros(Tv,ndata)
        snd_nzids = zeros(Int32,ndata)
        # Second pass: fill the buffers, using `ptrs` as running cursors.
        for (k,(i,j,v)) in enumerate(nziterator(A))
            owner = local_to_owner_row[i]
            if owner != me
                p = ptrs[owner_to_p[owner]]
                snd_row_gids[p] = local_to_global_row[i]
                snd_col_gids[p] = local_to_global_col[j]
                snd_vals[p] = v
                # Position of the entry in the iteration order of `A`;
                # used later to refresh the values without re-packing.
                snd_nzids[p] = k
                ptrs[owner_to_p[owner]] += 1
            end
        end
        # Undo the cursor advance performed by the second pass.
        rewind_ptrs!(ptrs)
        # All four jagged arrays share the same `ptrs` object.
        (;
            I_snd = JaggedArray(snd_row_gids,ptrs),
            J_snd = JaggedArray(snd_col_gids,ptrs),
            V_snd = JaggedArray(snd_vals,ptrs),
            k_snd = JaggedArray(snd_nzids,ptrs),
            parts_snd,
        )
    end
    # Bundle the received buffers and allocate storage for the positions of
    # the received entries in the assembled matrix (filled in later).
    function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
        k_rcv_data = zeros(Int32,length(I_rcv.data))
        k_rcv = JaggedArray(k_rcv_data,I_rcv.ptrs)
        (;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
    end
    # Build the COO triplets of the assembled local matrix: the entries of
    # `A` in own rows followed by the received entries. Rows are local ids
    # with respect to `rows_sa`; columns are global ids at this stage.
    # NOTE(review): row ids stay in the local numbering of `rows_sa` while
    # the final matrix is sized with the result row partition — presumably
    # both numberings agree on own rows; confirm against callers.
    function setup_own_triplets(A,cache_rcv,rows_sa,cols_sa)
        local_to_own_rows = local_to_own(rows_sa)
        I_sa, J_sa, V_sa = findnz(A)
        # Positions (in `findnz` order) of the entries located in own rows.
        k_own_to_sa = findall(i -> !iszero(local_to_own_rows[i]), I_sa)

        I_own = view(I_sa,k_own_to_sa)
        rcv_rows = cache_rcv.I_rcv.data
        # In-place conversion: the cached receive buffer now holds local ids.
        map_global_to_local!(rcv_rows,rows_sa)
        all_rows = vcat(I_own,rcv_rows)

        J_own = view(J_sa,k_own_to_sa)
        rcv_cols = cache_rcv.J_rcv.data
        map_local_to_global!(J_own,cols_sa)
        all_cols = vcat(J_own,rcv_cols)

        V_own = view(V_sa,k_own_to_sa)
        rcv_vals = cache_rcv.V_rcv.data
        all_vals = vcat(V_own,rcv_vals)

        # The column ids are returned twice on purpose: as part of the
        # triplets and on their own, so the caller can extend the column
        # partition with them before they are converted to local ids.
        (all_rows,all_cols,all_vals), all_cols, k_own_to_sa
    end
    # Compress the triplets into the final local matrix and precompute, for
    # every own and received entry, its position in the stored values.
    function finalize_values(A,rows_fa,cols_fa,cache_snd,cache_rcv,triplets,aux)
        coo_rows, coo_cols, coo_vals = triplets
        k_own_to_sa = aux
        k_rcv = cache_rcv.k_rcv.data
        map_global_to_local!(coo_cols,cols_fa)
        vals = compresscoo(typeof(A),coo_rows,coo_cols,coo_vals,length(rows_fa),length(cols_fa))

        # Entries of `A` outside own rows keep a zero in `k_sa`.
        k_sa = zeros(Int32,nnz(A))
        n_own = length(k_own_to_sa)
        I_own = view(coo_rows,1:n_own)
        J_own = view(coo_cols,1:n_own)
        k_own = view(k_sa,k_own_to_sa)
        precompute_nzindex!(k_own,vals,I_own,J_own)

        n_tot = length(coo_rows)
        I_tail = view(coo_rows,n_own+1:n_tot)
        J_tail = view(coo_cols,n_own+1:n_tot)
        precompute_nzindex!(k_rcv,vals,I_tail,J_tail)

        vals, (;k_sa,cache_snd...,cache_rcv...)
    end
    rows_sa = partition(axes(A,1))
    cols_sa = partition(axes(A,2))
    cols = map(remove_ghost,cols_sa)
    parts_snd, parts_rcv = assembly_neighbors(rows_sa)
    cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows_sa,cols_sa)
    I_snd = map(i->i.I_snd,cache_snd)
    J_snd = map(i->i.J_snd,cache_snd)
    V_snd = map(i->i.V_snd,cache_snd)
    graph = ExchangeGraph(parts_snd,parts_rcv)
    t_I = exchange(I_snd,graph)
    t_J = exchange(J_snd,graph)
    t_V = exchange(V_snd,graph)
    @fake_async begin
        I_rcv = fetch(t_I)
        J_rcv = fetch(t_J)
        V_rcv = fetch(t_V)
        cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
        triplets,J,aux = map(setup_own_triplets,partition(A),cache_rcv,rows_sa,cols_sa) |> tuple_of_arrays
        J_owner = find_owner(cols_sa,J)
        rows_fa = rows
        cols_fa = map(union_ghost,cols,J,J_owner)
        # Called for its effect on `cols_fa` only; the result is discarded.
        assembly_neighbors(cols_fa;assembly_neighbors_options_cols...)
        vals_fa, cache = map(finalize_values,partition(A),rows_fa,cols_fa,cache_snd,cache_rcv,triplets,aux) |> tuple_of_arrays
        assembled = true
        B = PSparseMatrix(vals_fa,rows_fa,cols_fa,assembled)
        if !val_parameter(reuse)
            B
        else
            B, cache
        end
    end
end
1907+
# Generic fallback: reached when no method specialized on the type argument
# exists. Always throws an `ErrorException`.
psparse_assemble_impl!(B, A, ::Type, cache) = error(" case not implemented")
@@ -1815,6 +1965,177 @@ function psparse_assemble_impl!(B,A,::Type{<:AbstractSplitMatrix},cache)
18151965 end
18161966end
18171967
# Refresh the assembled matrix `B` from `A` reusing `cache`.
#
# The values of `A` that have to be sent are copied into the cached send
# buffers and exchanged; meanwhile the local values of `A` are written into
# `B` through `cache.k_sa`. Once the exchange has finished, the received
# values are added to the stored values of `B` at the cached positions.
# The function evaluates to the result of the `@fake_async` block, whose
# last expression is `B`.
#
# The helper locals carry a `_data`/`_nonzeros` suffix so that they do not
# share a name with `V_snd`/`V_rcv` of the enclosing function. In Julia an
# assignment in a nested function to a name that is local to the enclosing
# function writes to that enclosing variable, which boxes it and lets the
# helper overwrite it.
function psparse_assemble_impl!(B,A,::Type{<:AbstractSparseMatrix},cache)
    # Copy the current values of `A` into the send buffer.
    function setup_snd(A,cache)
        V_snd_data = cache.V_snd.data
        k_snd_data = cache.k_snd.data
        A_nonzeros = nonzeros(A)
        for p in eachindex(k_snd_data)
            k = k_snd_data[p]
            V_snd_data[p] = A_nonzeros[k]
        end
    end
    # Write the local contribution of `A` into `B`.
    function setup_sa(B,A,cache)
        setcoofast!(B,nonzeros(A),cache.k_sa)
    end
    # Accumulate the received values into `B`.
    function setup_rcv(B,cache)
        V_rcv_data = cache.V_rcv.data
        k_rcv_data = cache.k_rcv.data
        B_nonzeros = nonzeros(B)
        for p in eachindex(k_rcv_data)
            k = k_rcv_data[p]
            B_nonzeros[k] += V_rcv_data[p]
        end
    end
    map(setup_snd,partition(A),cache)
    parts_snd = map(i->i.parts_snd,cache)
    parts_rcv = map(i->i.parts_rcv,cache)
    V_snd = map(i->i.V_snd,cache)
    V_rcv = map(i->i.V_rcv,cache)
    graph = ExchangeGraph(parts_snd,parts_rcv)
    t = exchange!(V_rcv,V_snd,graph)
    # Done between starting the exchange and waiting for it.
    map(setup_sa,partition(B),partition(A),cache)
    @fake_async begin
        wait(t)
        map(setup_rcv,partition(B),cache)
        B
    end
end
2004+
# Generic fallback of the in-place variant without cache: reached when no
# method specialized on the type argument exists. Always throws an
# `ErrorException`.
psparse_assemble_impl!(f, A, ::Type; kwargs...) = error(" case not implemented")
2008+
# In-place assembly for local matrices of type `AbstractSparseMatrix`.
#
# Entries of `A` whose row is owned by another part are sent to that owner
# (as global row id, global column id and value). Every received entry is
# located in the local matrix with `nzindex` and combined with the stored
# value as `f(stored, received)`; the sparsity pattern of `A` is not changed.
#
# Arguments
# - `f`: binary function combining the stored value with the received one.
# - `A`: the matrix to assemble; it is modified in place.
# - `reuse`: `Val` flag. When `val_parameter(reuse)` is true the block
#   evaluates to `(A, cache)`, otherwise to `A` alone.
#
# The function evaluates to the result of the `@fake_async` block.
#
# The nested helpers deliberately use local names that do not occur in the
# enclosing function. In Julia an assignment inside a nested function to a
# name that is a local of the enclosing function writes to the enclosing
# variable, which forces it to be boxed and lets the helper clobber it.
function psparse_assemble_impl!(
    f,
    A,
    ::Type{<:AbstractSparseMatrix};
    reuse=Val(false))

    # Pack the stored entries living in rows owned by other parts into
    # per-destination send buffers.
    function setup_cache_snd(A,parts_snd,rows,cols)
        local_to_owner_row = local_to_owner(rows)
        local_to_global_row = local_to_global(rows)
        local_to_global_col = local_to_global(cols)
        me = part_id(rows)
        # Position of each destination part in `parts_snd`.
        owner_to_p = Dict(( owner=>i for (i,owner) in enumerate(parts_snd) ))
        ptrs = zeros(Int32,length(parts_snd)+1)
        # First pass: count how many entries go to each destination.
        for (i,_,_) in nziterator(A)
            owner = local_to_owner_row[i]
            if owner != me
                ptrs[owner_to_p[owner]+1] += 1
            end
        end
        # Presumably converts the per-destination counts into 1-based
        # offsets; the total number of entries is then `ptrs[end]-1`.
        length_to_ptrs!(ptrs)
        Tv = eltype(A)
        ndata = ptrs[end]-1
        snd_row_gids = zeros(Int,ndata)
        snd_col_gids = zeros(Int,ndata)
        snd_vals = zeros(Tv,ndata)
        snd_nzids = zeros(Int32,ndata)
        # Second pass: fill the buffers, using `ptrs` as running cursors.
        for (k,(i,j,v)) in enumerate(nziterator(A))
            owner = local_to_owner_row[i]
            if owner != me
                p = ptrs[owner_to_p[owner]]
                snd_row_gids[p] = local_to_global_row[i]
                snd_col_gids[p] = local_to_global_col[j]
                snd_vals[p] = v
                # Position of the entry in the iteration order of `A`;
                # used later to refresh the values without re-packing.
                snd_nzids[p] = k
                ptrs[owner_to_p[owner]] += 1
            end
        end
        # Undo the cursor advance performed by the second pass.
        rewind_ptrs!(ptrs)
        # All four jagged arrays share the same `ptrs` object.
        (;
            I_snd = JaggedArray(snd_row_gids,ptrs),
            J_snd = JaggedArray(snd_col_gids,ptrs),
            V_snd = JaggedArray(snd_vals,ptrs),
            k_snd = JaggedArray(snd_nzids,ptrs),
            parts_snd,
        )
    end
    # Bundle the received buffers and allocate storage for the positions of
    # the received entries in `A` (filled in by `finalize_values!`).
    function setup_cache_rcv(I_rcv,J_rcv,V_rcv,parts_rcv)
        k_rcv_data = zeros(Int32,length(I_rcv.data))
        k_rcv = JaggedArray(k_rcv_data,I_rcv.ptrs)
        (;I_rcv,J_rcv,V_rcv,k_rcv,parts_rcv)
    end
    # Combine the received values into `A` and record where each one landed.
    function finalize_values!(A,rows,cols,cache_snd,cache_rcv)
        I_rcv_data = cache_rcv.I_rcv.data
        J_rcv_data = cache_rcv.J_rcv.data
        V_rcv_data = cache_rcv.V_rcv.data
        k_rcv_data = cache_rcv.k_rcv.data
        A_nonzeros = nonzeros(A)
        # In-place conversion: the cached receive buffers now hold local ids.
        map_global_to_local!(I_rcv_data,rows)
        map_global_to_local!(J_rcv_data,cols)
        for p in eachindex(k_rcv_data)
            i = I_rcv_data[p]
            j = J_rcv_data[p]
            k = nzindex(A,i,j)
            # Every received entry must already be stored in `A`. The check
            # sits in a `@boundscheck` block, so it may be elided when the
            # caller runs under `@inbounds`.
            @boundscheck @assert k > 0 " The sparsity pattern of the ghost layer is inconsistent"
            k_rcv_data[p] = k
            A_nonzeros[k] = f(A_nonzeros[k],V_rcv_data[p])
        end
        (;cache_snd...,cache_rcv...)
    end
    rows = partition(axes(A,1))
    cols = partition(axes(A,2))
    parts_snd, parts_rcv = assembly_neighbors(rows)
    cache_snd = map(setup_cache_snd,partition(A),parts_snd,rows,cols)
    I_snd = map(i->i.I_snd,cache_snd)
    J_snd = map(i->i.J_snd,cache_snd)
    V_snd = map(i->i.V_snd,cache_snd)
    graph = ExchangeGraph(parts_snd,parts_rcv)
    t_I = exchange(I_snd,graph)
    t_J = exchange(J_snd,graph)
    t_V = exchange(V_snd,graph)
    @fake_async begin
        I_rcv = fetch(t_I)
        J_rcv = fetch(t_J)
        V_rcv = fetch(t_V)
        cache_rcv = map(setup_cache_rcv,I_rcv,J_rcv,V_rcv,parts_rcv)
        cache = map(finalize_values!,partition(A),rows,cols,cache_snd,cache_rcv)
        if !val_parameter(reuse)
            A
        else
            A, cache
        end
    end
end
2101+
# Generic fallback of the in-place variant with cache: reached when no method
# specialized on the type argument exists. Always throws an `ErrorException`.
psparse_assemble_impl!(f::Function, A, ::Type, cache) = error(" case not implemented")
2105+
# Repeat the in-place assembly of `A` reusing `cache`.
#
# The values of `A` at the cached send positions are copied into the send
# buffers and exchanged. Once the exchange has finished, each received value
# is combined with the stored value at its cached position as
# `f(stored, received)`. The function evaluates to the result of the
# `@fake_async` block, whose last expression is `A`.
function psparse_assemble_impl!(f::Function,A,::Type{<:AbstractSparseMatrix},cache)
    # Copy the current values of `A` into the send buffer.
    function setup_snd(A,cache)
        src = nonzeros(A)
        dst = cache.V_snd.data
        for (p,k) in pairs(cache.k_snd.data)
            dst[p] = src[k]
        end
    end
    # Combine the received values into `A`.
    function setup_rcv(A,cache)
        stored = nonzeros(A)
        incoming = cache.V_rcv.data
        for (p,k) in pairs(cache.k_rcv.data)
            stored[k] = f(stored[k],incoming[p])
        end
    end
    map(setup_snd,partition(A),cache)
    snd_parts = map(c->c.parts_snd,cache)
    rcv_parts = map(c->c.parts_rcv,cache)
    snd_buffers = map(c->c.V_snd,cache)
    rcv_buffers = map(c->c.V_rcv,cache)
    t = exchange!(rcv_buffers,snd_buffers,ExchangeGraph(snd_parts,rcv_parts))
    @fake_async begin
        wait(t)
        map(setup_rcv,partition(A),cache)
        A
    end
end
2138+
18182139"""
18192140 consistent(A::PSparseMatrix,rows;kwargs...)
18202141"""
0 commit comments