Skip to content

Commit 5bf82a5

Browse files
committed
do_forward -> train_step and test_step; fix some typos
1 parent 7b74d97 commit 5bf82a5

File tree

8 files changed

+89
-59
lines changed

8 files changed

+89
-59
lines changed

README.md

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -109,25 +109,32 @@ Otherwise you can also use `model.show('model')` or `model.show('train')` to sho
109109
NOTE: you should install texttable first.
110110

111111
## Visualization
112+
NOTE: you must install [SciencePlots](https://github.com/garrettj403/SciencePlots) package for a better preview.
112113
+ Accuracy
113114
```python
114115
import matplotlib.pyplot as plt
115-
plt.plot(his.history['acc'])
116-
plt.plot(his.history['val_acc'])
117-
plt.legend(['Accuracy', 'Val Accuracy'])
118-
plt.xlabel('Epochs')
119-
plt.show()
116+
with plt.style.context(['science', 'no-latex']):
117+
plt.plot(his.history['acc'])
118+
plt.plot(his.history['val_acc'])
119+
plt.legend(['Train Accuracy', 'Val Accuracy'])
120+
plt.ylabel('Accuracy')
121+
plt.xlabel('Epochs')
122+
plt.autoscale(tight=True)
123+
plt.show()
120124
```
121125
![visualization](https://github.com/EdisonLeeeee/GraphGallery/blob/master/imgs/visualization_acc.png)
122126

123127
+ Loss
124128
```python
125129
import matplotlib.pyplot as plt
126-
plt.plot(his.history['loss'])
127-
plt.plot(his.history['val_loss'])
128-
plt.legend(['Loss', 'Val Loss'])
129-
plt.xlabel('Epochs')
130-
plt.show()
130+
with plt.style.context(['science', 'no-latex']):
131+
plt.plot(his.history['loss'])
132+
plt.plot(his.history['val_loss'])
133+
plt.legend(['Train Loss', 'Val Loss'])
134+
plt.ylabel('Loss')
135+
plt.xlabel('Epochs')
136+
plt.autoscale(tight=True)
137+
plt.show()
131138
```
132139
![visualization](https://github.com/EdisonLeeeee/GraphGallery/blob/master/imgs/visualization_loss.png)
133140

graphgallery/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@
2323
from graphgallery import data
2424

2525

26-
__version__ = '0.1.9'
26+
__version__ = '0.1.10'
2727

2828
__all__ = ['graphgallery', 'nn', 'utils', 'sequence', 'data', '__version__']

graphgallery/nn/models/basemodel.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import random
33
import logging
44

5+
import os.path as osp
56
import numpy as np
67
import tensorflow as tf
78
import scipy.sparse as sp
89

910
from graphgallery import config, check_and_convert, asintarr, Bunch
1011
from graphgallery.utils.type_check import is_list_like
1112
from graphgallery.utils.misc import print_table
13+
from graphgallery.data.utils import makedirs
14+
1215

1316

1417
class BaseModel:
@@ -97,8 +100,8 @@ def __init__(self, adj, x, labels=None, device="CPU:0", seed=None, name=None, **
97100
self.__sparse = is_adj_sparse
98101

99102
# log path
100-
self.weight_dir = "/tmp/weight"
101-
self.weight_path = f"{self.weight_dir}/{name}_weights"
103+
self.weight_dir = osp.expanduser(osp.normpath("/tmp/weight"))
104+
self.weight_path = osp.join(self.weight_dir, f"{name}_weights")
102105

103106
# data types, default: `float32` and `int64`
104107
self.floatx = config.floatx()
@@ -170,20 +173,19 @@ def _check_inputs(self, adj, x):
170173
raise RuntimeError(f"The adjacency matrix should be N by N square matrix.")
171174
return adj, x
172175

173-
@property
174-
def model(self):
175-
return self.__model
176-
177-
################### TODO: This may cause ERROR #############
178176
def __getattr__(self, attr):
177+
################### TODO: This may cause ERROR #############
179178
try:
180179
return self.__dict__[attr]
181180
except KeyError:
182181
if hasattr(self.model, attr):
183182
return getattr(self.model, attr)
184183
raise AttributeError(f"'{self.name}' and '{self.name}.model' objects have no attribute '{attr}'")
185184

186-
185+
@property
186+
def model(self):
187+
return self.__model
188+
187189
@model.setter
188190
def model(self, m):
189191
# Back up
@@ -194,8 +196,8 @@ def model(self, m):
194196

195197

196198
def save(self, path=None, as_model=False):
197-
if not os.path.exists(self.weight_dir):
198-
os.makedirs(self.weight_dir)
199+
if not osp.exists(self.weight_dir):
200+
makedirs(self.weight_dir)
199201
logging.log(logging.WARNING, f"Make Directory in {self.weight_dir}")
200202

201203
if path is None:
@@ -216,6 +218,7 @@ def save(self, path=None, as_model=False):
216218
def load(self, path=None, as_model=False):
217219
if not path:
218220
path = self.weight_path
221+
219222
if not path.endswith('.h5'):
220223
path += '.h5'
221224
if as_model:

graphgallery/nn/models/semisupervised/sbvat.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -128,13 +128,6 @@ def build(self, hiddens=[16], activations=['relu'], dropouts=[0.5],
128128
self.epsilon = epsilon # Norm length for (virtual) adversarial training
129129
self.n_power_iterations = n_power_iterations # Number of power iterations
130130

131-
# def propagation(self, x, adj, training=True):
132-
# h = x
133-
# for layer in self.GCN_layers:
134-
# h = self.dropout_layer(h, training=training)
135-
# h = layer([h, adj])
136-
# return h
137-
138131
def propagation(self, x, adj, training=True):
139132
h = x
140133
for dropout_layer, GCN_layer in zip(self.dropout_layers, self.GCN_layers[:-1]):
@@ -145,7 +138,7 @@ def propagation(self, x, adj, training=True):
145138
return h
146139

147140
@tf.function
148-
def do_train_forward(self, sequence):
141+
def train_step(self, sequence):
149142

150143
with tf.device(self.device):
151144
self.train_metric.reset_states()
@@ -171,7 +164,7 @@ def do_train_forward(self, sequence):
171164
return loss, self.train_metric.result()
172165

173166
@tf.function
174-
def do_test_forward(self, sequence):
167+
def test_step(self, sequence):
175168

176169
with tf.device(self.device):
177170
self.test_metric.reset_states()
@@ -186,14 +179,6 @@ def do_test_forward(self, sequence):
186179

187180
return loss, self.test_metric.result()
188181

189-
def do_forward(self, sequence, training=True):
190-
if training:
191-
loss, accuracy = self.do_train_forward(sequence)
192-
else:
193-
loss, accuracy = self.do_test_forward(sequence)
194-
195-
return loss.numpy(), accuracy.numpy()
196-
197182
def virtual_adversarial_loss(self, x, adj, logit, adv_mask):
198183
d = tf.random.normal(shape=tf.shape(x), dtype=self.floatx)
199184

graphgallery/nn/models/semisupervised/semi_supervised_model.py

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def train_v1(self, idx_train, idx_val=None,
220220
if self.do_before_train:
221221
self.do_before_train()
222222

223-
loss, accuracy = self.do_forward(train_data)
223+
loss, accuracy = self.train_step(train_data)
224224
train_data.on_epoch_end()
225225

226226
history.add_results(loss, 'loss')
@@ -230,7 +230,7 @@ def train_v1(self, idx_train, idx_val=None,
230230
if self.do_before_validation:
231231
self.do_before_validation()
232232

233-
val_loss, val_accuracy = self.do_forward(val_data, training=False)
233+
val_loss, val_accuracy = self.test_step(val_data)
234234

235235
history.add_results(val_loss, 'val_loss')
236236
history.add_results(val_accuracy, 'val_acc')
@@ -409,7 +409,7 @@ def train(self, idx_train, idx_val=None,
409409
self.do_before_train()
410410

411411
callbacks.on_train_batch_begin(0)
412-
loss, accuracy = self.do_forward(train_data)
412+
loss, accuracy = self.train_step(train_data)
413413
train_data.on_epoch_end()
414414

415415
training_logs = {'loss': loss, 'acc': accuracy}
@@ -419,7 +419,7 @@ def train(self, idx_train, idx_val=None,
419419
if self.do_before_validation:
420420
self.do_before_validation()
421421

422-
val_loss, val_accuracy = self.do_forward(val_data, training=False)
422+
val_loss, val_accuracy = self.test_step(val_data)
423423
training_logs.update({'val_loss': val_loss, 'val_acc': val_accuracy})
424424

425425
callbacks.on_epoch_end(epoch, training_logs)
@@ -584,7 +584,7 @@ def train_v2(self, idx_train, idx_val=None,
584584
self.do_before_train()
585585

586586
callbacks.on_train_batch_begin(0)
587-
loss, accuracy = self.do_forward(train_data)
587+
loss, accuracy = self.train_step(train_data)
588588
train_data.on_epoch_end()
589589

590590
training_logs = {'loss': loss, 'acc': accuracy}
@@ -594,7 +594,7 @@ def train_v2(self, idx_train, idx_val=None,
594594
if self.do_before_validation:
595595
self.do_before_validation()
596596

597-
val_loss, val_accuracy = self.do_forward(val_data, training=False)
597+
val_loss, val_accuracy = self.test_step(val_data)
598598
training_logs.update({'val_loss': val_loss, 'val_acc': val_accuracy})
599599

600600
callbacks.on_epoch_end(epoch, training_logs)
@@ -650,16 +650,15 @@ def test(self, index, **kwargs):
650650
if self.do_before_test:
651651
self.do_before_test(**kwargs)
652652

653-
loss, accuracy = self.do_forward(test_data, training=False)
653+
loss, accuracy = self.test_step(test_data)
654654

655655
return loss, accuracy
656656

657-
def do_forward(self, sequence, training=True):
657+
def train_step(self, sequence):
658658
"""
659659
Forward propagation for the input `sequence`. This method will be called
660-
in `train` and `test`, you can rewrite it for you customized training/testing
661-
process. If you want to specify your customized data during traing/testing/predicting,
662-
you can implement a sub-class of `graphgallery.NodeSequence`, wich is iterable
660+
in `train`. If you want to specify your customized data during training/testing/predicting,
661+
you can implement a subclass of `graphgallery.NodeSequence`, which is iterable
663662
and yields `inputs` and `labels` in each iteration.
664663
665664
@@ -672,8 +671,6 @@ def do_forward(self, sequence, training=True):
672671
----------
673672
sequence: `graphgallery.NodeSequence`
674673
The input `sequence`.
675-
trainng (Boolean, optional):
676-
Indicating training or test procedure. (default: :obj:`True`)
677674
678675
Return:
679676
----------
@@ -684,20 +681,49 @@ def do_forward(self, sequence, training=True):
684681
685682
"""
686683
model = self.model
684+
model.reset_metrics()
687685

688-
if training:
689-
forward_fn = model.train_on_batch
690-
else:
691-
forward_fn = model.test_on_batch
686+
with tf.device(self.device):
687+
for inputs, labels in sequence:
688+
loss, accuracy = model.train_on_batch(x=inputs, y=labels, reset_metrics=False)
689+
690+
return loss, accuracy
691+
692+
def test_step(self, sequence):
693+
"""
694+
Forward propagation for the input `sequence`. This method will be called
695+
in `test`. If you want to specify your customized data during training/testing/predicting,
696+
you can implement a subclass of `graphgallery.NodeSequence`, which is iterable
697+
and yields `inputs` and `labels` in each iteration.
698+
699+
700+
Note:
701+
----------
702+
You must compile your model before training/testing/predicting.
703+
Use `model.build()`.
692704
705+
Arguments:
706+
----------
707+
sequence: `graphgallery.NodeSequence`
708+
The input `sequence`.
709+
710+
Return:
711+
----------
712+
loss: Float scalar
713+
Output loss of forward propagation.
714+
accuracy: Float scalar
715+
Output accuracy of prediction.
716+
717+
"""
718+
model = self.model
693719
model.reset_metrics()
694720

695721
with tf.device(self.device):
696722
for inputs, labels in sequence:
697-
loss, accuracy = forward_fn(x=inputs, y=labels, reset_metrics=False)
723+
loss, accuracy = model.test_on_batch(x=inputs, y=labels, reset_metrics=False)
698724

699725
return loss, accuracy
700-
726+
701727
def predict(self, index, **kwargs):
702728
"""
703729
Predict the output probability for the `index` of nodes.

graphgallery/nn/models/semisupervised/sgc.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,19 @@ def preprocess(self, adj, x):
7070
if self.norm_x:
7171
x = normalize_x(x, norm=self.norm_x)
7272

73-
# InvalidArgumentError: Cannot use GPU when output.shape[1] * nnz(a) > 2^31 [Op:SparseTensorDenseMatMul]
74-
with tf.device(self.device):
73+
74+
# To avoid this tensorflow error in large dataset:
75+
# InvalidArgumentError: Cannot use GPU when output.shape[1] * nnz(a) > 2^31 [Op:SparseTensorDenseMatMul]
76+
if self.n_features*adj.nnz>2**31:
77+
device = "CPU"
78+
else:
79+
device = self.device
80+
81+
with tf.device(device):
7582
x, adj = astensors([x, adj])
7683
x = SGConvolution(order=self.order)([x, adj])
84+
85+
with tf.device(self.device):
7786
self.x_norm, self.adj_norm = x, adj
7887

7988
def build(self, lr=0.2, l2_norms=[5e-5], use_bias=True):

imgs/visualization_acc.png

-4.64 KB
Loading

imgs/visualization_loss.png

-4.32 KB
Loading

0 commit comments

Comments
 (0)