Skip to content

Commit db923c0

Browse files
separate GNNGraphs from GNNlib (#446)
* separate GNNGraphs from GNNlib * complete factorization * rebase
1 parent e2623eb commit db923c0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+550
-147
lines changed

GNNGraphs/Project.toml

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
name = "GNNGraphs"
2+
uuid = "aed8fd31-079b-4b5a-b342-a13352159b8c"
3+
authors = ["Carlo Lucibello and contributors"]
4+
version = "0.1.0"
5+
6+
[deps]
7+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
10+
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
11+
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
12+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
13+
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
14+
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
15+
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
16+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
17+
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
18+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
19+
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
20+
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
21+
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
22+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
23+
24+
[weakdeps]
25+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
26+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
27+
28+
[extensions]
29+
GNNGraphsCUDAExt = "CUDA"
30+
GNNGraphsSimpleWeightedGraphsExt = "SimpleWeightedGraphs"
31+
32+
[compat]
33+
Adapt = "4"
34+
CUDA = "5"
35+
ChainRulesCore = "1"
36+
Functors = "0.4.1"
37+
Graphs = "1.4"
38+
KrylovKit = "0.8"
39+
LinearAlgebra = "1"
40+
LuxDeviceUtils = "0.1.24"
41+
MLDatasets = "0.7"
42+
MLUtils = "0.4"
43+
NNlib = "0.9"
44+
NearestNeighbors = "0.4"
45+
Random = "1"
46+
SimpleWeightedGraphs = "1.4.0"
47+
SparseArrays = "1"
48+
Statistics = "1"
49+
StatsBase = "0.34"
50+
cuDNN = "1"
51+
julia = "1.9"
52+
53+
[extras]
54+
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
55+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
56+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
57+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
58+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
59+
InlineStrings = "842dd82b-1e85-43dc-bf29-5d0ee9dffc48"
60+
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
61+
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
62+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
63+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
64+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
65+
66+
[targets]
67+
test = ["Test", "Adapt", "DataFrames", "InlineStrings", "SimpleWeightedGraphs", "Zygote", "FiniteDifferences", "ChainRulesTestUtils", "MLDatasets", "CUDA", "cuDNN"]

GNNGraphs/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# GNNGraphs.jl
2+
3+
A package implementing graph types for graph deep learning.
4+
5+
This package is currently under development and may break frequentely.
6+
It is not meant for final users but for GNN libraries developers.
7+
Final user should use GraphNeuralNetworks.jl instead.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
module GNNGraphsCUDAExt
2+
3+
using CUDA
4+
using Random, Statistics, LinearAlgebra
5+
using GNNGraphs
6+
using GNNGraphs: COO_T, ADJMAT_T, SPARSE_T
7+
8+
const CUMAT_T = Union{CUDA.AnyCuMatrix, CUDA.CUSPARSE.CuSparseMatrix}
9+
10+
include("query.jl")
11+
include("transform.jl")
12+
include("utils.jl")
13+
14+
end #module

GNNlib/ext/GNNlibCUDAExt/GNNGraphs/utils.jl renamed to GNNGraphs/ext/GNNGraphsCUDAExt/utils.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ GNNGraphs.iscuarray(x::AnyCuArray) = true
33

44

55
function sort_edge_index(u::AnyCuArray, v::AnyCuArray)
6+
dev = get_device(u)
7+
cdev = cpu_device()
8+
u, v = u |> cdev, v |> cdev
69
#TODO proper cuda friendly implementation
7-
sort_edge_index(u |> Flux.cpu, v |> Flux.cpu) |> Flux.gpu
8-
end
10+
sort_edge_index(u, v) |> dev
11+
end

GNNlib/ext/GNNlibSimpleWeightedGraphsExt/GNNlibSimpleWeightedGraphsExt.jl renamed to GNNGraphs/ext/GNNGraphsSimpleWeightedGraphsExt/GNNGraphsSimpleWeightedGraphsExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
module GNNlibSimpleWeightedGraphsExt
1+
module GNNGraphsSimpleWeightedGraphsExt
22

3-
using GNNlib
43
using Graphs
4+
using GNNGraphs
55
using SimpleWeightedGraphs
66

7-
function GNNlib.GNNGraph(g::T; kws...) where
7+
function GNNGraphs.GNNGraph(g::T; kws...) where
88
{T <: Union{SimpleWeightedGraph, SimpleWeightedDiGraph}}
99
return GNNGraph(g.weights, kws...)
1010
end

GNNlib/src/GNNGraphs/GNNGraphs.jl renamed to GNNGraphs/src/GNNGraphs.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ module GNNGraphs
33
using SparseArrays
44
using Functors: @functor
55
import Graphs
6-
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
6+
using Graphs: AbstractGraph, outneighbors, inneighbors, adjacency_matrix, degree,
77
has_self_loops, is_directed
8-
import MLUtils
9-
using MLUtils: getobs, numobs, ones_like, zeros_like, batch
108
import NearestNeighbors
119
import NNlib
1210
import StatsBase
1311
import KrylovKit
1412
using ChainRulesCore
1513
using LinearAlgebra, Random, Statistics
1614
import MLUtils
15+
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch
1716
import Functors
17+
using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice
1818

1919
include("chainrules.jl") # hacks for differentiability
2020

@@ -78,7 +78,9 @@ export add_nodes,
7878
to_unidirected,
7979
random_walk_pe,
8080
remove_nodes,
81-
# from Flux
81+
ppr_diffusion,
82+
drop_nodes,
83+
# from MLUtils
8284
batch,
8385
unbatch,
8486
# from SparseArrays
@@ -101,8 +103,7 @@ include("operators.jl")
101103

102104
include("convert.jl")
103105
include("utils.jl")
104-
export sort_edge_index,
105-
color_refinement
106+
export sort_edge_index, color_refinement
106107

107108
include("gatherscatter.jl")
108109
# _gather, _scatter
File renamed without changes.

0 commit comments

Comments
 (0)