forked from HekpoMaH/Neural-Bipartite-Matching
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
249 lines (214 loc) · 10.2 KB
/
utils.py
File metadata and controls
249 lines (214 loc) · 10.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
import time
import signal
import torch
import torch_geometric
from torch_geometric.data import DataLoader
import models
import flow_datasets
from hyperparameters import get_hyperparameters
from torch_cluster import random_walk
def get_mask_to_process(continue_p, batch_ids, edge_ids, debug=False):
    """Build node and edge masks for graphs that should keep processing.

    A graph continues while its continue probability exceeds 0.5; the
    per-graph decision is broadcast to its nodes via ``batch_ids`` and
    from nodes to edges via ``edge_ids``.

    Returns:
        (mask, edge_mask): boolean 1-D tensors over nodes and edges.
    """
    if debug:
        print("Getting mask processing")
        print("Continue p:", continue_p)
    node_mask = continue_p[batch_ids] > 0.5
    per_edge_mask = node_mask[edge_ids] > 0.5
    if debug:
        print("Mask:", node_mask)
    return node_mask, per_edge_mask
def get_adj_matrix(edge_index):
    """Return the dense boolean adjacency matrix for ``edge_index``."""
    dense = torch_geometric.utils.to_dense_adj(edge_index)
    return dense.squeeze().bool()
def get_adj_flow_matrix(size, edge_index, capacities):
    """Return (boolean adjacency matrix, dense capacity matrix).

    NOTE(review): ``size`` is currently unused; kept for interface
    compatibility with existing callers.
    """
    to_dense = torch_geometric.utils.to_dense_adj
    adjacency = to_dense(edge_index).squeeze().bool()
    capacity_matrix = to_dense(edge_index, edge_attr=capacities).squeeze()
    return adjacency, capacity_matrix
def flip_edge_index(edge_index):
    """Swap the source/target rows of a (2, E) edge index tensor."""
    # Fancy indexing with [1, 0] reverses the two rows in one copy.
    return edge_index[[1, 0]]
def get_true_termination(batch, x_curr, y_curr):
    """Per-graph ground-truth termination signal.

    Returns a float tensor with one entry per graph: 1.0 while the
    current state still differs anywhere from the target, 0.0 once they
    match on every node of that graph.
    """
    flags = []
    for graph in range(batch.num_graphs):
        in_graph = batch.batch == graph
        differs = ~(x_curr[in_graph] == y_curr[in_graph]).all()
        flags.append(differs.float())
    return torch.stack(flags)
def split_per_graph(batch_ids, to_split, num_graphs=None):
    """Split a flat per-node tensor into one stacked row per graph.

    Assumes every graph contributes the same number of entries, since
    the pieces are combined with ``torch.stack``.
    """
    if num_graphs is None:
        num_graphs = batch_ids.max() + 1
    per_graph = [to_split[batch_ids == g] for g in range(num_graphs)]
    return torch.stack(per_graph)
def get_graph_embedding(batch_ids, latent_nodes, GRAPH_SIZES, reduction='mean'):
    """Aggregate per-node latents into one embedding per graph.

    Args:
        batch_ids: node-to-graph assignment tensor.
        latent_nodes: (num_nodes, dim) node embeddings.
        GRAPH_SIZES: (num_graphs,) node counts per graph.
        reduction: 'mean' (default, the previous behavior) or 'sum'.
            Previously this argument was accepted but silently ignored
            (the result was always the mean); it is now honored.

    Returns:
        (num_graphs, dim) tensor of graph embeddings.

    Raises:
        ValueError: if ``reduction`` is not 'mean' or 'sum'.
    """
    graph_embs = split_per_graph(batch_ids, latent_nodes,
                                 num_graphs=len(GRAPH_SIZES)).sum(1)
    if reduction == 'mean':
        graph_embs = graph_embs / GRAPH_SIZES.unsqueeze(1)
    elif reduction != 'sum':
        raise ValueError(f"Unknown reduction: {reduction!r}")
    return graph_embs
def interrupted(_interrupted=[False], _default=[None]):
    # Poll-and-latch SIGINT detector. The mutable default arguments are
    # DELIBERATE: they act as function-level static state persisting
    # across calls (_interrupted[0] is the latched flag, _default[0] the
    # previously installed SIGINT handler). On first call — or whenever
    # someone else has re-registered SIGINT back to our saved handler —
    # the flag is reset and our handler is (re)installed.
    if _default[0] is None or signal.getsignal(signal.SIGINT) == _default[0]:
        _interrupted[0] = False
        def handle(signal, frame):
            # NOTE(review): the parameter name shadows the `signal` module
            # inside this closure.
            # A second Ctrl-C while the flag is already set is forwarded
            # to the original handler (typically raising KeyboardInterrupt).
            if _interrupted[0] and _default[0] is not None:
                _default[0](signal, frame)
            print('Interrupt!')
            _interrupted[0] = True
        # signal.signal returns the previous handler, which we save so a
        # repeated interrupt can be escalated to it.
        _default[0] = signal.signal(signal.SIGINT, handle)
    # Returns True once an interrupt has been observed since (re)install.
    return _interrupted[0]
def add_self_loops(batch):
    """Add a zero-valued self loop at every node of ``batch`` (mutates
    and returns the batch).

    (An earlier variant handled two edge-attribute channels separately;
    the current data carries a single channel, so ``edge_attr`` is
    passed through as-is.)
    """
    looped_index, looped_attr = torch_geometric.utils.add_self_loops(
        batch.edge_index, batch.edge_attr, fill_value=0)
    batch.edge_attr = looped_attr
    batch.edge_index = looped_index
    return batch
def get_sizes_and_source(batch):
    """Return per-graph node counts and each graph's source-node index.

    By convention the source is node 0 of every graph, i.e. the running
    offset of the graph within the flat batch numbering.
    """
    device = get_hyperparameters()["device"]
    graph_sizes = torch.unique(batch.batch, return_counts=True)[1].to(device)
    offsets = graph_sizes.cumsum(0) - graph_sizes
    source_nodes = offsets.clone().detach()
    return graph_sizes, source_nodes
def get_sizes_and_source_sink(batch):
    """Return per-graph node counts plus source and sink node indices.

    The source is the first node of each graph and the sink the last,
    both in the flat batch numbering.

    Returns:
        (GRAPH_SIZES, SOURCE_NODES, SINK_NODES)
    """
    # Reuse the shared helper instead of duplicating its body (the first
    # two values were previously recomputed inline, line for line).
    GRAPH_SIZES, SOURCE_NODES = get_sizes_and_source(batch)
    SINK_NODES = (GRAPH_SIZES.cumsum(0) - 1).clone().detach()
    return GRAPH_SIZES, SOURCE_NODES, SINK_NODES
def finish(x, y, batch_ids, steps, STEPS_SIZE, GRAPH_SIZES):
    """
    Returns whether it's a final iteration or not in real task
    Returns true/false value per graph (as a mask)
    N.B. Not what the network thinks
    """
    # Fetch the hyperparameters once (this was previously called twice:
    # once for the device and once for dim_target).
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    if steps == 0:
        # Step 0: all ones for every graph.
        return torch.ones(len(GRAPH_SIZES), device=DEVICE)
    if steps >= STEPS_SIZE - 1:
        # Ran out of recorded steps: all zeros for every graph.
        return torch.zeros(len(GRAPH_SIZES), device=DEVICE)
    # Slice out the step-th column of the (nodes, steps, features) data.
    step_idx = torch.tensor([steps], dtype=torch.long, device=DEVICE)
    x_curr = torch.index_select(x, 1, step_idx).squeeze(1).to(DEVICE)
    y_curr = torch.index_select(y, 1, step_idx).squeeze(1).to(DEVICE)
    noteq = ~(x_curr == y_curr)
    batches_inside = batch_ids.max() + 1
    # Reshape to (graphs, nodes_per_graph, dim_target); a graph's value
    # is 1.0 while any node feature still differs from the target.
    noteq_batched = noteq.view(batches_inside, -1, hyperparameters["dim_target"])
    true_termination = noteq_batched.any(dim=1).any(dim=-1).float()
    return true_termination
def get_input(batch, EPSILON, train, x_curr, last_output):
    """Select the network input for the next step.

    Always returns ``last_output``; the other parameters are kept for
    interface compatibility. Sanity-checks that ``x_curr`` carries no
    gradient before it is (elsewhere) fed back into the model.
    """
    assert not x_curr.requires_grad
    assert not x_curr[:, 0].requires_grad
    return last_output
def get_print_info(augmenting_path_network):
    """Gather every statistic needed by the progress printout.

    Returns a 14-tuple: the four validation losses, their sum, the four
    accuracy metrics, the four broken-invariant collections, and the
    number of broken invariants.
    """
    net = augmenting_path_network
    loss_dist, loss_pred, loss_term, findmin = net.get_validation_losses()
    mean_step, final_step, tnr, subtract_acc = net.get_validation_accuracies()
    total_loss = loss_dist + loss_pred + loss_term
    (broken_invariants, broken_reachabilities,
     broken_flows, broken_all) = net.get_broken_invariants()
    return (loss_dist, loss_pred, loss_term, findmin, total_loss,
            mean_step, final_step, tnr, subtract_acc,
            broken_invariants, broken_reachabilities, broken_flows,
            broken_all, len(broken_invariants))
def iterate_over(processor, optimizer=None, test=False):
    # Run one full epoch over every algorithm managed by the processor.
    # Training mode updates weights per round; otherwise validation (or
    # test, when test=True) statistics are accumulated.
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    BATCH_SIZE = hyperparameters["batch_size"]
    # Fresh iterator per algorithm: shuffled train split when training,
    # otherwise the val split (or the test split when test=True).
    for algorithm in processor.algorithms.values():
        if processor.training:
            algorithm.iterator = iter(DataLoader(algorithm.train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=False, num_workers=8))
        else:
            algorithm.iterator = iter(DataLoader(algorithm.test_dataset if test
                                                 else algorithm.val_dataset, batch_size=BATCH_SIZE,
                                                 shuffle=False, drop_last=False, num_workers=8))
    if not processor.training:
        for algorithm in processor.algorithms.values():
            algorithm.zero_validation_stats()
    try:
        # Phase 1 — lock-step: each round draws one batch from EVERY
        # algorithm, then (in training) performs a single weight update.
        while True:
            for algorithm in processor.algorithms.values():
                batch = next(algorithm.iterator)
                batch.to(DEVICE)
                EPS_I = 0
                # NOTE(review): `start` is set but never used.
                start = time.time()
                # Gradients are only tracked while training.
                with torch.set_grad_enabled(processor.training):
                    output = algorithm.process(batch, EPS_I)
                if not processor.training:
                    algorithm.update_validation_stats(batch, output)
            if processor.training:
                processor.update_weights(optimizer)
            # Ctrl-C (see interrupted()) cleanly ends the epoch early.
            if interrupted():
                break
    except StopIteration: # datasets should be the same size
        pass
    # Phase 2 — drain: if the datasets were NOT the same size, finish
    # off whatever batches remain in each algorithm's iterator.
    for algorithm in processor.algorithms.values(): # for when they are not
        if not processor.training:
            algorithm.zero_tracking_losses_and_statistics()
        try:
            while True:
                batch = next(algorithm.iterator)
                batch.to(DEVICE)
                EPS_I = 0
                # NOTE(review): `start` is set but never used here either.
                start = time.time()
                with torch.set_grad_enabled(processor.training):
                    output = algorithm.process(batch, EPS_I)
                if not processor.training:
                    algorithm.update_validation_stats(batch, output)
                if processor.training:
                    processor.update_weights(optimizer)
        except StopIteration:
            pass
def load_algorithms(algorithms, processor, use_ints):
    """Instantiate the network for each named algorithm and register it
    with the processor.

    Args:
        algorithms: iterable of algorithm names ("AugmentingPath", "BFS").
        processor: shared processor the networks attach to.
        use_ints: whether the AugmentingPath network uses the bit-encoded
            integer representation (enables dim_bits).

    Raises:
        ValueError: for an unrecognized algorithm name. (Previously an
            unknown name either crashed with NameError, or — worse —
            silently re-registered the network built in an earlier loop
            iteration under the wrong name.)
    """
    hyperparameters = get_hyperparameters()
    DEVICE = hyperparameters["device"]
    DIM_LATENT = hyperparameters["dim_latent"]
    DIM_NODES_BFS = hyperparameters["dim_nodes_BFS"]
    DIM_NODES_AugmentingPath = hyperparameters["dim_nodes_AugmentingPath"]
    DIM_EDGES = hyperparameters["dim_edges"]
    DIM_EDGES_BFS = hyperparameters["dim_edges_BFS"]
    DIM_BITS = hyperparameters["dim_bits"] if use_ints else None
    for algorithm in algorithms:
        if algorithm == "AugmentingPath":
            algo_net = models.AugmentingPathNetwork(DIM_LATENT, DIM_NODES_AugmentingPath, DIM_EDGES, processor, flow_datasets.SingleIterationDataset, './all_iter', bias=hyperparameters["bias"], use_ints=use_ints, bits_size=DIM_BITS).to(DEVICE)
        elif algorithm == "BFS":
            algo_net = models.BFSNetwork(DIM_LATENT, DIM_NODES_BFS, DIM_EDGES_BFS, processor, flow_datasets.BFSSingleIterationDataset, './bfs').to(DEVICE)
        else:
            raise ValueError(f"Unknown algorithm: {algorithm!r}")
        processor.add_algorithm(algo_net, algorithm)
def integer2bit(integer, num_bits=8):
    # Credit: https://github.com/KarenUllrich/pytorch-binary-converter/blob/master/binary_converter.py
    """Expand an integer tensor into its binary digits, MSB first.

    Args:
        integer: torch.Tensor of integer values.
        num_bits: number of bits of precision to emit. Default: 8.

    Returns:
        Tensor with one extra trailing dimension of size ``num_bits``
        holding the bits (most significant first).
    """
    dtype = integer.type()
    # Exponents num_bits-1 .. 0, cast to the input's dtype.
    exponents = -torch.arange(-(num_bits - 1), 1).type(dtype)
    exponents = exponents.repeat(integer.shape + (1,))
    # Shift each value right by every exponent (as a float division)...
    shifted = integer.unsqueeze(-1) / 2 ** exponents
    # ...then floor and take mod 2 to pick out each binary digit.
    return (shifted - (shifted % 1)) % 2
# Big-endian bit weights for 8-bit values, kept on the configured device.
_POWERS_OF_2 = torch.tensor([128, 64, 32, 16, 8, 4, 2, 1], device=get_hyperparameters()["device"])
def bit2integer(bit_logits):
    """Decode 8 bit-logits per row back into float integer values.

    A logit > 0 counts as a set bit; bits are most significant first
    (matching integer2bit with num_bits=8).
    """
    set_bits = (bit_logits > 0).long()
    weighted = set_bits * _POWERS_OF_2
    return weighted.sum(dim=1).float()
def create_inv_edge_index(batch_size, size, edge_index):
    """Build a dense (N, N) lookup from (src, dst) node pair to edge id,
    where N = batch_size * size; -100 marks "no such edge".

    Vectorized: one advanced-indexing assignment replaces the previous
    O(E) Python loop over edges. (With duplicate edges the previous loop
    kept the last index; edge lists here are presumably unique — confirm.)

    Returns:
        LongTensor of shape (N, N).
    """
    num_nodes = batch_size * size
    iei = torch.full((num_nodes, num_nodes), -100, dtype=torch.long)
    iei[edge_index[0], edge_index[1]] = torch.arange(edge_index.shape[1])
    return iei
def get_print_format():
    """Return the multi-line format template for the per-epoch printout.

    Callers fill it (via str.format) with, in order: mean step accuracy,
    last step accuracy, mincap TNR, subtract accuracy, the five losses
    (dist, pred, term, findmin, total), four (broken, total) count pairs,
    and the patience counter.
    """
    fmt = """
==========================
Mean step acc: {:.4f} Last step acc: {:.4f}
Mincap TNR: {:.4f}
Subtract accuracy: {:.4f}
loss-(dist,pred,term,findmin,total): {:.4f} {:.4f} {:.4f} {:.4f} {:.4f}
broken_invariants: {:2d}/{:3d}
broken_all: {:2d}/{:3d}
broken_reachabilities: {:2d}/{:3d}
broken_flows: {:2d}/{:3d}
patience: {}
===============
"""
    return fmt