@@ -12,7 +12,7 @@ We create a dataset consisting of multiple random graphs and associated data features
 
 ```julia
 using GNNLux, Lux, Statistics, MLUtils, Random
-using Zygote, Optimizers
+using Zygote, Optimisers
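+# Optimisers.jl provides gradient-based optimisation rules such as the Adam used below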
 
 all_graphs = GNNGraph[]
 
@@ -22,61 +22,49 @@ for _ in 1:1000
             gdata = (; y = randn(Float32)))   # Regression target
     push!(all_graphs, g)
 end
-```
 
-### Model building
+train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at = 0.8)  # 80/20 train/test split
 
-We concisely define our model as a [`GNNLux.GNNChain`](@ref) containing two graph convolutional layers. If CUDA is available, our model will live on the gpu.
+# g = rand_graph(10, 40, ndata = (; x = randn(Float32, 16, 10)), gdata = (; y = randn(Float32)))
 
-```julia
-device = CUDA.functional() ? Lux.gpu_device() : Lux.cpu_device()
 rng = Random.default_rng()
 
 model = GNNChain(GCNConv(16 => 64),
                 x -> relu.(x),
                 GCNConv(64 => 64, relu),
-                GlobalMeanPool(),   # Aggregate node-wise features into graph-wise features
+                x -> mean(x, dims = 2),   # Aggregate node-wise features into graph-wise features
                 Dense(64, 1))
 
 ps, st = LuxCore.setup(rng, model)
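+# ps holds the trainable parameters, st the non-trainable layer states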
-```
-
-### Training
 
-
-```julia
-train_graphs, test_graphs = MLUtils.splitobs(all_graphs, at = 0.8)
-
-train_loader = MLUtils.DataLoader(train_graphs,
-                batchsize = 32, shuffle = true, collate = true)
-test_loader = MLUtils.DataLoader(test_graphs,
-                batchsize = 32, shuffle = false, collate = true)
-
-for epoch in 1:100
-    for g in train_loader
-        g = g |> device
-        grad = gradient(model -> loss(model, g), model)
-        Flux.update!(opt, model, grad[1])
-    end
-
-    @info (; epoch, train_loss = loss(model, train_loader), test_loss = loss(model, test_loader))
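+# Lux's training loop expects a loss with the signature
+# (model, ps, st, data) -> (loss, updated_st, stats); custom_loss adapts
+# MSELoss to that interface for (graph, features, target) tuples.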
+function custom_loss(model, ps, st, data)
+    g, x, y = data
+    ŷ, st = model(g, x, ps, st)
+    # Return (loss, updated state, stats), re-wrapping the layer states
+    return MSELoss()(ŷ, y), (layers = st,), (;)
 end
 
-function train_model!(model, ps, st, train_loader)
-    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.001f0))
 
-    for iter in 1:1000
-        for g in train_loader
-            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), MSELoss(),
-                ((g, g.x)..., g.y), train_state)
-            if iter % 100 == 1 || iter == 1000
-                @info " Iteration: %04d \t Loss: %10.9g\n"
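+# TrainState bundles the model, parameters, states, and optimiser rule; each
+# single_train_step! call computes gradients via Zygote and applies one Adam update.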
+function train_model!(model, ps, st, train_graphs, test_graphs)
+    train_state = Lux.Training.TrainState(model, ps, st, Adam(0.0001f0))
+    loss = 0.0  # keep the most recent training loss in scope for logging
+    for iter in 1:100
+        for g in train_graphs
+            _, loss, _, train_state = Lux.Training.single_train_step!(AutoZygote(), custom_loss, (g, g.x, g.y), train_state)
+        end
+        if iter % 10 == 0 || iter == 100
+            # Evaluate on the test graphs with states switched to test mode
+            st_ = Lux.testmode(train_state.states)
+            test_loss = 0.0
+            for g in test_graphs
+                ŷ, st_ = model(g, g.x, train_state.parameters, st_)
+                st_ = (layers = st_,)
+                test_loss += MSELoss()(ŷ, g.y)
             end
+            test_loss = test_loss / length(test_graphs)
+            @info (; iter, loss, test_loss)
         end
     end
 
     return model, ps, st
 end
 
-train_model!(model, ps, st, train_loader)
-```
+train_model!(model, ps, st, train_graphs, test_graphs)
+```