|  | 
| 1 |  | -@testitem "Training Example" setup=[TestModule] begin | 
| 2 |  | -    using .TestModule | 
|  | 1 | +@testmodule TrainingExampleModule begin | 
| 3 | 2 |     using Flux | 
| 4 | 3 |     using Flux: onecold, onehotbatch | 
| 5 | 4 |     using Flux.Losses: logitcrossentropy | 
| 6 | 5 |     using GraphNeuralNetworks | 
| 7 | 6 |     using MLDatasets: Cora | 
| 8 | 7 |     using Statistics, Random | 
| 9 |  | -    using CUDA | 
| 10 |  | -    CUDA.allowscalar(false) | 
| 11 |  | - | 
|  | 8 | +     | 
| 12 | 9 |     function eval_loss_accuracy(X, y, ids, model, g) | 
| 13 | 10 |         ŷ = model(g, X) | 
| 14 | 11 |         l = logitcrossentropy(ŷ[:, ids], y[:, ids]) | 
|  | 
| 21 | 18 |         η = 5.0f-3            # learning rate | 
| 22 | 19 |         epochs = 10         # number of epochs | 
| 23 | 20 |         seed = 17           # set seed > 0 for reproducibility | 
| 24 |  | -        usecuda = false     # if true use cuda (if available) | 
|  | 21 | +        use_gpu = false     # if true use gpu (if available) | 
| 25 | 22 |         nhidden = 64        # dimension of hidden features | 
| 26 | 23 |     end | 
| 27 | 24 | 
 | 
| 28 | 25 |     function train(Layer; verbose = false, kws...) | 
| 29 | 26 |         args = Args(; kws...) | 
| 30 | 27 |         args.seed > 0 && Random.seed!(args.seed) | 
| 31 | 28 | 
 | 
| 32 |  | -        if args.usecuda && CUDA.functional() | 
| 33 |  | -            device = Flux.gpu | 
| 34 |  | -            args.seed > 0 && CUDA.seed!(args.seed) | 
|  | 29 | +        if args.use_gpu | 
|  | 30 | +            device = gpu_device(force=true) | 
|  | 31 | +            Random.seed!(default_device_rng(device)) | 
| 35 | 32 |         else | 
| 36 |  | -            device = Flux.cpu | 
|  | 33 | +            device = cpu_device() | 
| 37 | 34 |         end | 
| 38 | 35 | 
 | 
| 39 | 36 |         # LOAD DATA | 
| 40 | 37 |         dataset = Cora() | 
| 41 | 38 |         classes = dataset.metadata["classes"] | 
| 42 | 39 |         g = mldataset2gnngraph(dataset) |> device | 
| 43 | 40 |         X = g.ndata.features | 
| 44 |  | -        y = onehotbatch(g.ndata.targets |> cpu, classes) |> device # remove when https://github.com/FluxML/Flux.jl/pull/1959 tagged | 
|  | 41 | +        y = onehotbatch(g.ndata.targets, classes) | 
| 45 | 42 |         train_mask = g.ndata.train_mask | 
| 46 | 43 |         test_mask = g.ndata.test_mask | 
| 47 | 44 |         ytrain = y[:, train_mask] | 
|  | 
| 78 | 75 |         return train_res, test_res | 
| 79 | 76 |     end | 
| 80 | 77 | 
 | 
| 81 |  | -    function train_many(; usecuda = false) | 
|  | 78 | +    function train_many(; use_gpu = false) | 
| 82 | 79 |         for (layer, Layer) in [ | 
| 83 | 80 |             ("GCNConv", (nin, nout) -> GCNConv(nin => nout, relu)), | 
| 84 | 81 |             ("ResGatedGraphConv", (nin, nout) -> ResGatedGraphConv(nin => nout, relu)), | 
|  | 
| 96 | 93 |             ## ("EdgeConv",(nin, nout) -> EdgeConv(Dense(2nin, nout, relu))), # Fits the training set but does not generalize well | 
| 97 | 94 |         ] | 
| 98 | 95 |             @show layer | 
| 99 |  | -            @time train_res, test_res = train(Layer; usecuda, verbose = false) | 
|  | 96 | +            @time train_res, test_res = train(Layer; use_gpu, verbose = false) | 
| 100 | 97 |             # @show train_res, test_res | 
| 101 | 98 |             @test train_res.acc > 94 | 
| 102 | 99 |             @test test_res.acc > 69 | 
| 103 | 100 |         end | 
| 104 | 101 |     end | 
|  | 102 | +end # module | 
|  | 103 | + | 
|  | 104 | +@testitem "training example" setup=[TrainingExampleModule] begin | 
|  | 105 | +    using .TrainingExampleModule | 
|  | 106 | +    train_many() | 
|  | 107 | +end | 
| 105 | 108 | 
 | 
| 106 |  | -    train_many(usecuda = false) | 
| 107 |  | -    # #TODO | 
| 108 |  | -    # if TEST_GPU | 
| 109 |  | -    #     train_many(usecuda = true) | 
| 110 |  | -    # end | 
|  | 109 | +@testitem "training example GPU" setup=[TrainingExampleModule] begin | 
|  | 110 | +    using .TrainingExampleModule | 
|  | 111 | +    train_many(use_gpu = true) | 
| 111 | 112 | end | 
|  | 113 | + | 
0 commit comments