
Commit 949453e

biubiu0906tyy2064131211xy-Ji authored [Model] EGT (#238)
* EGT model * add egt * add egt --------- Co-authored-by: Tan Yingying <[email protected]> Co-authored-by: Xingyuan Ji <[email protected]>
1 parent 0871e6c commit 949453e

24 files changed, +3443 -7 lines changed

examples/egt/__init__.py

Whitespace-only changes.

examples/egt/config.yaml

Lines changed: 23 additions & 0 deletions

scheme: pcqm4mv2
model_name: egt_90m
distributed: false        # set to true to enable multi-GPU distributed training
batch_size: 512           # batch size on a single GPU; with multiple GPUs, set this to the total batch size divided by the number of GPUs
model_height: 24
node_width: 768
edge_width: 64
num_heads: 32
num_epochs: 1
max_lr: 0.0001
attn_dropout: 0.3
lr_warmup_steps: 200000
lr_total_steps: 1000000
node_ffn_multiplier: 1.0
edge_ffn_multiplier: 1.0
upto_hop: 16
dataloader_workers: 1     # number of data-loading workers
scale_degree: true
num_virtual_nodes: 4
svd_random_neg: true
mixed_precision: true     # enable mixed-precision training
use_adaptive_sparse: true # enable adaptive sparsification
sparse_alpha: 0.5         # sparsification strength coefficient
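
For reference, a minimal sketch of how a config like this could be loaded by a training script (the use of PyYAML and the file path are assumptions for illustration, not part of this commit):

import yaml

# Load the EGT training configuration (file path is illustrative).
with open('examples/egt/config.yaml') as f:
    config = yaml.safe_load(f)

# Values keep their YAML types: booleans, ints, floats, and strings.
print(config['model_name'])       # 'egt_90m'
print(config['batch_size'])       # 512
print(config['mixed_precision'])  # True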

examples/egt/data/dataset_base.py

Lines changed: 106 additions & 0 deletions

import numpy as np
import tensorlayerx as tlx
from tqdm import tqdm
from pathlib import Path


class DatasetBase:
    def __init__(self,
                 dataset_name,
                 split,
                 cache_dir=None,
                 load_cache_if_exists=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.dataset_name = dataset_name
        self.split = split
        self.cache_dir = cache_dir

        self.is_cached = False
        if load_cache_if_exists:
            self.cache(verbose=0, must_exist=True)

    @property
    def record_tokens(self):
        raise NotImplementedError

    def read_record(self, token):
        raise NotImplementedError

    def __len__(self):
        return len(self.record_tokens)

    def __getitem__(self, index):
        token = self.record_tokens[index]
        try:
            return self._records[token]
        except AttributeError:
            # First access: the in-memory record cache does not exist yet.
            record = self.read_record(token)
            self._records = {token: record}
            return record
        except KeyError:
            # Record not read yet; read it and memoize.
            record = self.read_record(token)
            self._records[token] = record
            return record

    def read_all_records(self, verbose=1):
        self._records = {}
        if verbose:
            print(f'Reading all {self.split} records...', flush=True)
            for token in tqdm(self.record_tokens):
                self._records[token] = self.read_record(token)
        else:
            for token in self.record_tokens:
                self._records[token] = self.read_record(token)

    def get_cache_path(self, path=None):
        if path is None:
            path = self.cache_dir
        base_path = (Path(path)/self.dataset_name)/self.split
        base_path.mkdir(parents=True, exist_ok=True)
        return base_path

    def cache_load_and_save(self, base_path, op, verbose):
        tokens_path = base_path/'tokens.npy'
        records_path = base_path/'records.npy'

        if op == 'load':
            self._record_tokens = tlx.files.load_npy_to_any(
                name=str(tokens_path))
            self._records = tlx.files.load_npy_to_any(name=str(records_path))
        elif op == 'save':
            if tokens_path.exists() and records_path.exists() \
                    and hasattr(self, '_record_tokens') and hasattr(self, '_records'):
                return
            self.read_all_records(verbose=verbose)
            tlx.files.save_any_to_npy(
                save_dict=self.record_tokens, name=str(tokens_path))
            tlx.files.save_any_to_npy(
                save_dict=self._records, name=str(records_path))
        else:
            raise ValueError(f'Unknown operation: {op}')

    def cache(self, path=None, verbose=1, must_exist=False):
        if self.is_cached:
            return

        base_path = self.get_cache_path(path)
        try:
            if verbose:
                print(
                    f'Trying to load {self.split} cache from disk...', flush=True)
            self.cache_load_and_save(base_path, 'load', verbose)
            if verbose:
                print(f'Loaded {self.split} cache from disk.', flush=True)
        except FileNotFoundError:
            if must_exist:
                return

            if verbose:
                print(f'{self.split} cache does not exist! Caching...', flush=True)
            self.cache_load_and_save(base_path, 'save', verbose)
            if verbose:
                print(f'Saved {self.split} cache to disk.', flush=True)

        self.is_cached = True
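
As a rough sketch of how DatasetBase is meant to be subclassed (the class name, record keys, and cache directory below are made up for illustration), a dataset only needs to provide record_tokens and read_record; cache() then handles writing and reloading tokens.npy / records.npy:

import numpy as np

class ToyDataset(DatasetBase):
    """Hypothetical subclass, used only to illustrate the DatasetBase contract."""

    @property
    def record_tokens(self):
        # One token per record; here simply the integers 0..9.
        return list(range(10))

    def read_record(self, token):
        # Build the record dict for a single token.
        return {'num_nodes': token + 2,
                'target': np.float32(token)}

ds = ToyDataset(dataset_name='toy', split='train',
                cache_dir='cache_data', load_cache_if_exists=False)
ds.cache(verbose=1)    # reads all records, then saves tokens.npy / records.npy
print(len(ds), ds[0])  # records are now served from the in-memory cache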

examples/egt/data/graph_dataset.py

Lines changed: 97 additions & 0 deletions

import tensorlayerx as tlx
import numpy as np

from .dataset_base import DatasetBase

from .stack_with_pad import stack_with_pad
from collections import defaultdict
from numba.typed import List


class GraphDataset(DatasetBase):
    def __init__(self,
                 num_nodes_key='num_nodes',
                 edges_key='edges',
                 node_features_key='node_features',
                 edge_features_key='edge_features',
                 node_mask_key='node_mask',
                 targets_key='target',
                 include_node_mask=True,
                 **kwargs):
        super().__init__(**kwargs)
        self.num_nodes_key = num_nodes_key
        self.edges_key = edges_key
        self.node_features_key = node_features_key
        self.edge_features_key = edge_features_key
        self.node_mask_key = node_mask_key
        self.targets_key = targets_key
        self.include_node_mask = include_node_mask

    def __getitem__(self, index):
        item = super().__getitem__(index)
        if self.include_node_mask:
            item = item.copy()
            item[self.node_mask_key] = np.ones((item[self.num_nodes_key],), dtype=np.uint8)
        return item

    def _calculate_max_nodes(self):
        self._max_nodes = self[0][self.num_nodes_key]
        self._max_nodes_index = 0
        for i in range(1, super().__len__()):
            graph = super().__getitem__(i)
            cur_nodes = graph[self.num_nodes_key]
            if cur_nodes > self._max_nodes:
                self._max_nodes = cur_nodes
                self._max_nodes_index = i

    @property
    def max_nodes(self):
        try:
            return self._max_nodes
        except AttributeError:
            self._calculate_max_nodes()
            return self._max_nodes

    @property
    def max_nodes_index(self):
        try:
            return self._max_nodes_index
        except AttributeError:
            self._calculate_max_nodes()
            return self._max_nodes_index

    def cache_load_and_save(self, base_path, op, verbose):
        super().cache_load_and_save(base_path, op, verbose)
        max_nodes_path = base_path / 'max_nodes_data.npy'

        if op == 'load':
            # torch.load replaced with tlx.files.load_npy_to_any
            max_nodes_data = tlx.files.load_npy_to_any(name=str(max_nodes_path))
            self._max_nodes = max_nodes_data['max_nodes']
            self._max_nodes_index = max_nodes_data['max_nodes_index']
        elif op == 'save':
            if verbose:
                print(f'Calculating {self.split} max nodes...', flush=True)
            max_nodes_data = {'max_nodes': self.max_nodes,
                              'max_nodes_index': self.max_nodes_index}
            # torch.save replaced with tlx.files.save_any_to_npy
            tlx.files.save_any_to_npy(save_dict=max_nodes_data, name=str(max_nodes_path))
        else:
            raise ValueError(f'Unknown operation: {op}')

    def max_batch(self, batch_size, collate_fn):
        return collate_fn([self.__getitem__(self.max_nodes_index)] * batch_size)


def graphdata_collate(batch):
    batch_data = defaultdict(List)
    for elem in batch:
        for k, v in elem.items():
            batch_data[k].append(v)

    # torch.from_numpy replaced with tlx.convert_to_tensor
    out = {k: tlx.convert_to_tensor(stack_with_pad(dat))
           for k, dat in batch_data.items()}
    return out
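
A small sketch of what graphdata_collate produces for a list of records (the two toy records below are made up; real records come from the dataset classes in this directory):

import numpy as np

# Two graphs with different node counts; arrays are zero-padded to a common shape.
batch = [
    {'node_features': np.ones((2, 9), dtype=np.int16),
     'node_mask': np.ones((2,), dtype=np.uint8)},
    {'node_features': np.ones((3, 9), dtype=np.int16),
     'node_mask': np.ones((3,), dtype=np.uint8)},
]

out = graphdata_collate(batch)
# node_features -> shape (2, 3, 9), node_mask -> shape (2, 3);
# padded positions are zeros, which the node_mask marks as absent nodes.
print(out['node_features'].shape, out['node_mask'].shape)
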
examples/egt/data/stack_with_pad.py

Lines changed: 91 additions & 0 deletions

import numpy as np
import numba as nb


@nb.njit
def stack_with_pad_4d(inputs):
    num_elem = len(inputs)
    ms_0, ms_1, ms_2, ms_3 = inputs[0].shape

    # Maximum extent along each axis across the batch.
    for i in range(1, num_elem):
        is_0, is_1, is_2, is_3 = inputs[i].shape
        ms_0 = max(is_0, ms_0)
        ms_1 = max(is_1, ms_1)
        ms_2 = max(is_2, ms_2)
        ms_3 = max(is_3, ms_3)

    stacked_shape = (num_elem, ms_0, ms_1, ms_2, ms_3)
    stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype)

    # Copy each element into the corner of its slot; the rest stays zero.
    for i, elem in enumerate(inputs):
        stacked[i][:elem.shape[0], :elem.shape[1], :elem.shape[2], :elem.shape[3]] = elem
    return stacked


@nb.njit
def stack_with_pad_3d(inputs):
    num_elem = len(inputs)
    ms_0, ms_1, ms_2 = inputs[0].shape

    for i in range(1, num_elem):
        is_0, is_1, is_2 = inputs[i].shape
        ms_0 = max(is_0, ms_0)
        ms_1 = max(is_1, ms_1)
        ms_2 = max(is_2, ms_2)

    stacked_shape = (num_elem, ms_0, ms_1, ms_2)
    stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype)

    for i, elem in enumerate(inputs):
        stacked[i][:elem.shape[0], :elem.shape[1], :elem.shape[2]] = elem
    return stacked


@nb.njit
def stack_with_pad_2d(inputs):
    num_elem = len(inputs)
    ms_0, ms_1 = inputs[0].shape

    for i in range(1, num_elem):
        is_0, is_1 = inputs[i].shape
        ms_0 = max(is_0, ms_0)
        ms_1 = max(is_1, ms_1)

    stacked_shape = (num_elem, ms_0, ms_1)
    stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype)

    for i, elem in enumerate(inputs):
        stacked[i][:elem.shape[0], :elem.shape[1]] = elem
    return stacked


@nb.njit
def stack_with_pad_1d(inputs):
    num_elem = len(inputs)
    ms_0 = inputs[0].shape[0]

    for i in range(1, num_elem):
        is_0 = inputs[i].shape[0]
        ms_0 = max(is_0, ms_0)

    stacked_shape = (num_elem, ms_0)
    stacked = np.zeros(stacked_shape, dtype=inputs[0].dtype)

    for i, elem in enumerate(inputs):
        stacked[i][:elem.shape[0]] = elem
    return stacked


def stack_with_pad(inputs):
    # Dispatch on rank: scalars are stacked directly; arrays are zero-padded
    # to the maximum shape in the batch before stacking.
    shape_rank = np.ndim(inputs[0])
    if shape_rank == 0:
        return np.stack(inputs)
    if shape_rank == 1:
        return stack_with_pad_1d(inputs)
    elif shape_rank == 2:
        return stack_with_pad_2d(inputs)
    elif shape_rank == 3:
        return stack_with_pad_3d(inputs)
    elif shape_rank == 4:
        return stack_with_pad_4d(inputs)
    else:
        raise ValueError('Only tensors of up to 4 dimensions are supported')
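
To make the padding behavior concrete, a small example (the input arrays are arbitrary; a plain Python list also works here, although recent numba versions may warn about reflected lists, which is why graphdata_collate above uses numba.typed.List):

import numpy as np

a = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.int16)   # shape (2, 3)
b = np.array([[7, 8]], dtype=np.int16)      # shape (1, 2)

out = stack_with_pad([a, b])
# out.shape == (2, 2, 3); b is zero-padded on the right and at the bottom:
# [[[1 2 3]
#   [4 5 6]]
#  [[7 8 0]
#   [0 0 0]]]
print(out.shape)
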
Lines changed: 73 additions & 0 deletions

import numpy as np
import numba as nb

from .graph_dataset import GraphDataset

NODE_FEATURES_OFFSET = 128
EDGE_FEATURES_OFFSET = 8


@nb.njit
def floyd_warshall(A):
    # All-pairs shortest-path lengths on an unweighted adjacency matrix.
    # Unreachable pairs are initialized to the sentinel distance 510.
    n = A.shape[0]
    D = np.zeros((n, n), dtype=np.int16)

    for i in range(n):
        for j in range(n):
            if i == j:
                pass
            elif A[i, j] == 0:
                D[i, j] = 510
            else:
                D[i, j] = 1

    for k in range(n):
        for i in range(n):
            for j in range(n):
                old_dist = D[i, j]
                new_dist = D[i, k] + D[k, j]
                if new_dist < old_dist:
                    D[i, j] = new_dist
    return D


@nb.njit
def preprocess_data(num_nodes, edges, node_feats, edge_feats):
    # Shift each categorical feature column into its own embedding-index range.
    node_feats = node_feats + np.arange(1, node_feats.shape[-1]*NODE_FEATURES_OFFSET+1,
                                        NODE_FEATURES_OFFSET, dtype=np.int16)
    edge_feats = edge_feats + np.arange(1, edge_feats.shape[-1]*EDGE_FEATURES_OFFSET+1,
                                        EDGE_FEATURES_OFFSET, dtype=np.int16)

    # Dense adjacency matrix A and edge-feature matrix E.
    A = np.zeros((num_nodes, num_nodes), dtype=np.int16)
    E = np.zeros((num_nodes, num_nodes, edge_feats.shape[-1]), dtype=np.int16)
    for k in range(edges.shape[0]):
        i, j = edges[k, 0], edges[k, 1]
        A[i, j] = 1
        E[i, j] = edge_feats[k]

    D = floyd_warshall(A)
    return node_feats, D, E


class StructuralDataset(GraphDataset):
    def __init__(self,
                 distance_matrix_key='distance_matrix',
                 feature_matrix_key='feature_matrix',
                 **kwargs):
        super().__init__(**kwargs)
        self.distance_matrix_key = distance_matrix_key
        self.feature_matrix_key = feature_matrix_key

    def __getitem__(self, index):
        item = super().__getitem__(index)

        num_nodes = int(item[self.num_nodes_key])
        edges = item.pop(self.edges_key)
        node_feats = item.pop(self.node_features_key)
        edge_feats = item.pop(self.edge_features_key)

        node_feats, dist_mat, edge_feats_mat = preprocess_data(num_nodes, edges, node_feats, edge_feats)
        item[self.node_features_key] = node_feats
        item[self.distance_matrix_key] = dist_mat
        item[self.feature_matrix_key] = edge_feats_mat

        return item
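
A tiny worked example of the structural preprocessing (the 3-node path graph and all feature values are made up):

import numpy as np

# Path graph 0-1-2, with each undirected edge stored in both directions.
edges = np.array([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=np.int16)
node_feats = np.zeros((3, 1), dtype=np.int16)   # one categorical node feature per node
edge_feats = np.zeros((4, 1), dtype=np.int16)   # one categorical edge feature per edge

nf, D, E = preprocess_data(3, edges, node_feats, edge_feats)

# Node features are shifted into per-column index ranges, so every entry becomes 1.
# D is the shortest-path distance matrix:
#   [[0 1 2]
#    [1 0 1]
#    [2 1 0]]
print(nf.ravel(), D, sep='\n')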
