-using Flux
-using Flux: onecold, onehotbatch
-using Flux.Losses: logitcrossentropy
-using GraphNeuralNetworks
-using MLDatasets: Cora
-using Statistics, Random
-using CUDA
-CUDA.allowscalar(false)
+@testitem "Training Example" setup=[TestModule] begin
+    using .TestModule
+    using Flux
+    using Flux: onecold, onehotbatch
+    using Flux.Losses: logitcrossentropy
+    using GraphNeuralNetworks
+    using MLDatasets: Cora
+    using Statistics, Random
+    using CUDA
+    CUDA.allowscalar(false)
 
-function eval_loss_accuracy(X, y, ids, model, g)
-    ŷ = model(g, X)
-    l = logitcrossentropy(ŷ[:, ids], y[:, ids])
-    acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids]))
-    return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2))
-end
+    function eval_loss_accuracy(X, y, ids, model, g)
+        ŷ = model(g, X)
+        l = logitcrossentropy(ŷ[:, ids], y[:, ids])
+        acc = mean(onecold(ŷ[:, ids]) .== onecold(y[:, ids]))
+        return (loss = round(l, digits = 4), acc = round(acc * 100, digits = 2))
+    end
 
-# arguments for the `train` function
-Base.@kwdef mutable struct Args
-    η = 5.0f-3            # learning rate
-    epochs = 10         # number of epochs
-    seed = 17           # set seed > 0 for reproducibility
-    usecuda = false     # if true use cuda (if available)
-    nhidden = 64        # dimension of hidden features
-end
+    # arguments for the `train` function
+    Base.@kwdef mutable struct Args
+        η = 5.0f-3            # learning rate
+        epochs = 10         # number of epochs
+        seed = 17           # set seed > 0 for reproducibility
+        usecuda = false     # if true use cuda (if available)
+        nhidden = 64        # dimension of hidden features
+    end
 
-function train(Layer; verbose = false, kws...)
-    args = Args(; kws...)
-    args.seed > 0 && Random.seed!(args.seed)
+    function train(Layer; verbose = false, kws...)
+        args = Args(; kws...)
+        args.seed > 0 && Random.seed!(args.seed)
 
-    if args.usecuda && CUDA.functional()
-        device = Flux.gpu
-        args.seed > 0 && CUDA.seed!(args.seed)
-    else
-        device = Flux.cpu
-    end
+        if args.usecuda && CUDA.functional()
+            device = Flux.gpu
+            args.seed > 0 && CUDA.seed!(args.seed)
+        else
+            device = Flux.cpu
+        end
 
-    # LOAD DATA
-    dataset = Cora()
-    classes = dataset.metadata["classes"]
-    g = mldataset2gnngraph(dataset) |> device
-    X = g.ndata.features
-    y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged
-    train_mask = g.ndata.train_mask
-    test_mask = g.ndata.test_mask
-    ytrain = y[:, train_mask]
+        # LOAD DATA
+        dataset = Cora()
+        classes = dataset.metadata["classes"]
+        g = mldataset2gnngraph(dataset) |> device
+        X = g.ndata.features
+        y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged
+        train_mask = g.ndata.train_mask
+        test_mask = g.ndata.test_mask
+        ytrain = y[:, train_mask]
 
-    nin, nhidden, nout = size(X, 1), args.nhidden, length(classes)
+        nin, nhidden, nout = size(X, 1), args.nhidden, length(classes)
 
-    ## DEFINE MODEL
-    model = GNNChain(Layer(nin, nhidden),
-                     #  Dropout(0.5),
-                     Layer(nhidden, nhidden),
-                     Dense(nhidden, nout)) |> device
+        ## DEFINE MODEL
+        model = GNNChain(Layer(nin, nhidden),
+                         #  Dropout(0.5),
+                         Layer(nhidden, nhidden),
+                         Dense(nhidden, nout)) |> device
 
-    opt = Flux.setup(Adam(args.η), model)
+        opt = Flux.setup(Adam(args.η), model)
 
-    ## TRAINING
-    function report(epoch)
-        train = eval_loss_accuracy(X, y, train_mask, model, g)
-        test = eval_loss_accuracy(X, y, test_mask, model, g)
-        println("Epoch: $epoch   Train: $(train)   Test: $(test)")
-    end
+        ## TRAINING
+        function report(epoch)
+            train = eval_loss_accuracy(X, y, train_mask, model, g)
+            test = eval_loss_accuracy(X, y, test_mask, model, g)
+            println("Epoch: $epoch   Train: $(train)   Test: $(test)")
+        end
 
-    verbose && report(0)
-    @time for epoch in 1:(args.epochs)
-        grad = Flux.gradient(model) do model
-            ŷ = model(g, X)
-            logitcrossentropy(ŷ[:, train_mask], ytrain)
+        verbose && report(0)
+        @time for epoch in 1:(args.epochs)
+            grad = Flux.gradient(model) do model
+                ŷ = model(g, X)
+                logitcrossentropy(ŷ[:, train_mask], ytrain)
+            end
+            Flux.update!(opt, model, grad[1])
+            verbose && report(epoch)
         end
-        Flux.update!(opt, model, grad[1])
-        verbose && report(epoch)
-    end
 
-    train_res = eval_loss_accuracy(X, y, train_mask, model, g)
-    test_res = eval_loss_accuracy(X, y, test_mask, model, g)
-    return train_res, test_res
-end
+        train_res = eval_loss_accuracy(X, y, train_mask, model, g)
+        test_res = eval_loss_accuracy(X, y, test_mask, model, g)
+        return train_res, test_res
+    end
 
-function train_many(; usecuda = false)
-    for (layer, Layer) in [
-        ("GCNConv", (nin, nout) -> GCNConv(nin => nout, relu)),
-        ("ResGatedGraphConv", (nin, nout) -> ResGatedGraphConv(nin => nout, relu)),
-        ("GraphConv", (nin, nout) -> GraphConv(nin => nout, relu, aggr = mean)),
-        ("SAGEConv", (nin, nout) -> SAGEConv(nin => nout, relu)),
-        ("GATConv", (nin, nout) -> GATConv(nin => nout, relu)),
-        ("GINConv", (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr = mean)),
-        ("TransformerConv",
-         (nin, nout) -> TransformerConv(nin => nout, concat = false,
-                                        add_self_loops = true, root_weight = false,
-                                        heads = 2)),
-        ## ("ChebConv", (nin, nout) -> ChebConv(nin => nout, 2)), # not working on gpu
-        ## ("NNConv", (nin, nout) -> NNConv(nin => nout)),  # needs edge features
-        ## ("GatedGraphConv", (nin, nout) -> GatedGraphConv(nout, 2)),  # needs nin = nout
-        ## ("EdgeConv", (nin, nout) -> EdgeConv(Dense(2nin, nout, relu))), # fits the training set but does not generalize well
-    ]
-        @show layer
-        @time train_res, test_res = train(Layer; usecuda, verbose = false)
-        # @show train_res, test_res
-        @test train_res.acc > 94
-        @test test_res.acc > 69
+    function train_many(; usecuda = false)
+        for (layer, Layer) in [
+            ("GCNConv", (nin, nout) -> GCNConv(nin => nout, relu)),
+            ("ResGatedGraphConv", (nin, nout) -> ResGatedGraphConv(nin => nout, relu)),
+            ("GraphConv", (nin, nout) -> GraphConv(nin => nout, relu, aggr = mean)),
+            ("SAGEConv", (nin, nout) -> SAGEConv(nin => nout, relu)),
+            ("GATConv", (nin, nout) -> GATConv(nin => nout, relu)),
+            ("GINConv", (nin, nout) -> GINConv(Dense(nin, nout, relu), 0.01, aggr = mean)),
+            ("TransformerConv",
+             (nin, nout) -> TransformerConv(nin => nout, concat = false,
+                                            add_self_loops = true, root_weight = false,
+                                            heads = 2)),
+            ## ("ChebConv", (nin, nout) -> ChebConv(nin => nout, 2)), # not working on gpu
+            ## ("NNConv", (nin, nout) -> NNConv(nin => nout)),  # needs edge features
+            ## ("GatedGraphConv", (nin, nout) -> GatedGraphConv(nout, 2)),  # needs nin = nout
+            ## ("EdgeConv", (nin, nout) -> EdgeConv(Dense(2nin, nout, relu))), # fits the training set but does not generalize well
+        ]
+            @show layer
+            @time train_res, test_res = train(Layer; usecuda, verbose = false)
+            # @show train_res, test_res
+            @test train_res.acc > 94
+            @test test_res.acc > 69
+        end
     end
-end
 
-train_many(usecuda = false)
-if TEST_GPU
-    train_many(usecuda = true)
+    train_many(usecuda = false)
+    # #TODO
+    # if TEST_GPU
+    #     train_many(usecuda = true)
+    # end
 end
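
The `setup=[TestModule]` argument refers to a shared `@testsetup` module defined elsewhere in the test suite; its definition is not part of this diff. As a hypothetical sketch of what such a setup module can look like (the `TEST_GPU` flag and the environment-variable gate are illustrative assumptions, not taken from this commit):

    # Hypothetical sketch of a shared test setup; the real TestModule lives
    # elsewhere in the repository and may differ.
    @testsetup module TestModule

    export TEST_GPU

    # Assumed convention: gate GPU tests on an environment variable.
    const TEST_GPU = get(ENV, "TEST_GPU", "false") == "true"

    end

Names exported from the setup module become visible inside the test item through the `using .TestModule` at the top of the block.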
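
A `@testitem` block is not executed by a plain `include`; it is discovered and run by a test-item runner. A minimal sketch of driving this item by name, assuming the suite uses ReTestItems.jl (the same `@testitem`/`setup` syntax is also understood by TestItemRunner.jl):

    using ReTestItems

    # Run only the test item named "Training Example" found under test/;
    # `name` also accepts a Regex to select several items at once.
    runtests("test"; name = "Training Example")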