Skip to content

Commit a94bce0

Browse files
authored
Update graph class into LOL format and adjust requirements (#35)
* Update graph class into LOL format * Update requirements.txt
1 parent 55b36f1 commit a94bce0

File tree

2 files changed

+215
-95
lines changed

2 files changed

+215
-95
lines changed

grim/imputation/networkx_graph.py

Lines changed: 213 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,175 +1,294 @@
1-
import networkx as nx
2-
import csv
3-
1+
"""
2+
Yuli Tshuva
3+
Create a data structure of directed Graph suitable for our needs with numpy arrays and dictionaries.
4+
"""
45

5-
def missing(labelA, labelB):
6-
a = list(labelA)
7-
b = list(labelB)
8-
return [x for x in b if x not in a]
6+
import csv
7+
import numpy as np
8+
import pandas as pd
9+
import gc
10+
from tqdm.auto import tqdm
11+
import time
912

1013

1114
class Graph(object):
12-
__slots__ = (
13-
"graph",
14-
"labelDict",
15-
"whole_graph",
16-
"full_loci",
17-
"nodes_plan_a",
18-
"nodes_plan_b",
19-
)
20-
2115
def __init__(self, config):
22-
self.graph = nx.DiGraph()
16+
"""
17+
Clean initiation.
18+
"""
19+
self.Edges = []
20+
self.Vertices = []
21+
self.Vertices_attributes = {}
22+
self.Neighbors_start = []
23+
24+
self.Whole_Edges = []
25+
self.Whole_Vertices = []
26+
self.Whole_Vertices_attributes = {}
27+
self.Whole_Neighbors_start = []
28+
2329
self.labelDict = {}
24-
self.whole_graph = nx.DiGraph()
30+
2531
self.full_loci = config["full_loci"]
2632
self.nodes_plan_a, self.nodes_plan_b = [], []
2733
if config["nodes_for_plan_A"]:
2834
path = "/".join(config["node_file"].split("/")[:-1])
29-
30-
# bug: dies if file doesn't exist
31-
# bug: list_f doesn't exist
3235
with open(path + "/nodes_for_plan_a.txt") as list_f:
3336
for item in list_f:
3437
self.nodes_plan_a.append(item.strip())
35-
# bug: dies if file doesn't exist
3638
with open(path + "/nodes_for_plan_b.txt") as list_f:
3739
for item in list_f:
3840
self.nodes_plan_b.append(item.strip())
39-
# self.nodes_plan_a = pickle.load(open( path + '/nodes_for_plan_a.pkl', "rb"))
40-
# self.nodes_plan_b = pickle.load(open( path + '/nodes_for_plan_b.pkl', "rb"))
4141

42-
# build graph from files of nodes and edges between nodes with top relation
4342
def build_graph(self, nodesFile, edgesFile, allEdgesFile):
43+
"""Build the graph by scanning the three files, line by line and filling up the class variables."""
4444
nodesDict = dict()
45-
# add nodes from file
4645
with open(nodesFile) as nodesfile:
4746
readNodes = csv.reader(nodesfile, delimiter=",")
4847
next(readNodes)
49-
for row in readNodes:
48+
for row in tqdm(readNodes, desc="Vertices Creation:"):
5049
if len(row) > 0:
5150
if not self.nodes_plan_a or row[2] in self.nodes_plan_a:
52-
self.graph.add_node(
53-
row[1],
54-
label=row[2],
55-
freq=list(map(float, row[3].split(";"))),
56-
)
51+
self.Vertices.append(row[1])
52+
vertex_id = len(self.Vertices) - 1
53+
self.Vertices_attributes[row[1]] = (row[2], list(map(float, row[3].split(";"))), vertex_id)
54+
5755
if not self.nodes_plan_b or row[2] in self.nodes_plan_b:
58-
self.whole_graph.add_node(
59-
row[1],
60-
label=row[2],
61-
freq=list(map(float, row[3].split(";"))),
62-
)
63-
nodesDict[row[0]] = row[1]
56+
self.Whole_Vertices.append(row[1])
57+
vertex_id = len(self.Whole_Vertices) - 1
58+
self.Whole_Vertices_attributes[row[1]] = (
59+
row[2], list(map(float, row[3].split(";"))), vertex_id)
6460

65-
nodesfile.close()
61+
nodesDict[row[0]] = row[1]
6662

67-
# add edges from file
63+
# Add edges from file
6864
with open(edgesFile) as edgesfile:
6965
readEdges = csv.reader(edgesfile, delimiter=",")
7066
next(readEdges)
71-
for row in readEdges:
67+
for row in tqdm(readEdges, desc="Edges Creation:"):
7268
if len(row) > 0:
73-
node1 = nodesDict[row[0]]
74-
node2 = nodesDict[row[1]]
75-
if node1 in self.graph and node2 in self.graph:
76-
if self.graph.nodes[node1]["label"] == self.full_loci:
77-
self.graph.add_edge(node2, node1)
69+
node1_id = row[0]
70+
node2_id = row[1]
71+
node1 = nodesDict[node1_id]
72+
node2 = nodesDict[node2_id]
73+
if node1 in self.Vertices_attributes and node2 in self.Vertices_attributes:
74+
node1_label = self.Vertices_attributes[node1][0]
75+
if node1_label == self.full_loci:
76+
self.Edges.append([node2_id, node1_id])
7877
else:
79-
self.graph.add_edge(node1, node2)
80-
81-
edgesfile.close()
78+
self.Edges.append([node1_id, node2_id])
8279

8380
# add edges from file
8481
with open(allEdgesFile) as allEdgesfile:
8582
readEdges = csv.reader(allEdgesfile, delimiter=",")
8683
next(readEdges)
87-
for row in readEdges:
84+
for row in tqdm(readEdges, "All Edges Creation:"):
8885
if len(row) > 0:
89-
node1 = nodesDict[row[0]]
90-
node2 = nodesDict[row[1]]
91-
if len(self.whole_graph.nodes[node1]["label"]) < len(
92-
self.whole_graph.nodes[node2]["label"]
93-
):
94-
connector = self.whole_graph.nodes[node2]["label"] + node1
95-
self.whole_graph.add_edge(node1, connector)
96-
self.whole_graph.add_edge(connector, node2)
86+
node1_id = row[0]
87+
node2_id = row[1]
88+
node1 = nodesDict[node1_id]
89+
node2 = nodesDict[node2_id]
90+
node1_label = self.Whole_Vertices_attributes[node1][0]
91+
node2_label = self.Whole_Vertices_attributes[node2][0]
92+
93+
if len(node1_label) < len(node2_label):
94+
# Create a connector
95+
connector = node2_label + node1
96+
97+
if connector not in self.Whole_Vertices_attributes:
98+
self.Whole_Vertices.append(connector)
99+
connector_id = len(self.Whole_Vertices) - 1
100+
self.Whole_Vertices_attributes[connector] = connector_id
101+
102+
self.Whole_Edges.append([node1_id, connector_id])
103+
else:
104+
connector_id = self.Whole_Vertices_attributes[connector]
105+
106+
# Append the connector to the whole edges array
107+
self.Whole_Edges.append([connector_id, node2_id])
108+
97109
else:
98-
connector = self.whole_graph.nodes[node1]["label"] + node2
99-
self.whole_graph.add_edge(node2, connector)
100-
self.whole_graph.add_edge(connector, node1)
110+
# Create a connector
111+
connector = node1_label + node2
112+
113+
if connector not in self.Whole_Vertices_attributes:
114+
self.Whole_Vertices.append(connector)
115+
connector_id = len(self.Whole_Vertices) - 1
116+
self.Whole_Vertices_attributes[connector] = connector_id
117+
118+
# Append the connector to the whole edges array
119+
self.Whole_Edges.append([node2_id, connector_id])
120+
self.Whole_Edges.append([connector_id, node1_id])
101121

102-
allEdgesfile.close()
103122
nodesDict.clear()
123+
del nodesDict
124+
125+
# Concat all the lists of the edges lists to a numpy array
126+
self.Edges = np.vstack(self.Edges).astype(np.uint32)
127+
self.Whole_Edges = np.vstack(self.Whole_Edges).astype(np.uint32)
128+
self.Vertices = np.array(self.Vertices, dtype=np.object_)
129+
self.Whole_Vertices = np.array(self.Whole_Vertices, dtype=np.object_)
130+
131+
# Drop duplications in edges
132+
df_e = pd.DataFrame(self.Whole_Edges)
133+
df_e.drop_duplicates(inplace=True)
134+
del self.Whole_Edges
135+
self.Whole_Edges = df_e.to_numpy()
136+
del df_e
137+
138+
# Sort the Edges arrays - Takes numpy to sort an array of size (10**8, 2) about 45 secs on Google Colab.
139+
sorted_indices = np.lexsort((self.Edges[:, 1], self.Edges[:, 0]))
140+
self.Edges = self.Edges[sorted_indices]
141+
sorted_indices = np.lexsort((self.Whole_Edges[:, 1], self.Whole_Edges[:, 0]))
142+
self.Whole_Edges = self.Whole_Edges[sorted_indices]
143+
144+
# Save memory
145+
del sorted_indices
146+
147+
# Create a list of the first appearance of a number in the 0 column in the matrix
148+
unique_values, first_occurrences_indices = np.unique(self.Edges[:, 0], return_index=True)
149+
150+
j = 0
151+
for i in range(0, self.Vertices.shape[0]):
152+
if int(unique_values[j]) == i:
153+
self.Neighbors_start.append(int(first_occurrences_indices[j]))
154+
j += 1
155+
else:
156+
try:
157+
self.Neighbors_start.append(self.Neighbors_start[-1])
158+
except: # In case of the start of the list - empty list
159+
self.Neighbors_start.append(0)
160+
161+
# Free some memory
162+
del unique_values, first_occurrences_indices
163+
164+
# Create a list of the first appearance of a number in the 0 column in the matrix
165+
unique_values, first_occurrences_indices = np.unique(self.Whole_Edges[:, 0], return_index=True)
166+
167+
j = 0
168+
for i in range(0, self.Whole_Vertices.shape[0]):
169+
if int(unique_values[j]) == i:
170+
self.Whole_Neighbors_start.append(int(first_occurrences_indices[j]))
171+
j += 1
172+
else:
173+
try:
174+
self.Whole_Neighbors_start.append(self.Whole_Neighbors_start[-1])
175+
except: # In case of the start of the list - empty list
176+
self.Whole_Neighbors_start.append(0)
177+
178+
# Free some memory
179+
del unique_values, first_occurrences_indices
180+
181+
self.Neighbors_start.append(int(len(self.Vertices)))
182+
self.Whole_Neighbors_start.append(int(len(self.Whole_Vertices)))
183+
184+
self.Neighbors_start = np.array(self.Neighbors_start, dtype=np.uint32)
185+
self.Whole_Neighbors_start = np.array(self.Whole_Neighbors_start, dtype=np.uint32)
186+
187+
# Take the first column out of the Edges arrays
188+
### Do the following to massive save of memory
189+
Edges = self.Edges[:, 1].copy()
190+
del self.Edges
191+
self.Edges = Edges
192+
193+
Whole_Edges = self.Whole_Edges[:, 1].copy()
194+
del self.Whole_Edges
195+
self.Whole_Edges = Whole_Edges
196+
197+
gc.collect()
104198

105-
# return all haplotype by specific label
106199
def haps_by_label(self, label):
107-
# cheak if already found
200+
"""Find haplotypes by their labels.
201+
Does not use the graphical features of the haplotypes.
202+
Returns a list of haplotypes.
203+
Approved."""
204+
# Check if already found
108205
if label in self.labelDict:
109206
return self.labelDict[label]
110-
# not found yet. serach and save in labelDict
207+
# If you get here, label hasn't been found yet. So, I should find it and save in labelDict.
111208
hapsList = []
112209
if not self.nodes_plan_a or label in self.nodes_plan_a:
113-
for key, key_data in self.graph.nodes(data=True):
114-
if key_data["label"] == label:
115-
hapsList.append(key)
210+
for haplotype, hap_label in self.Vertices_attributes.items():
211+
hap_label = hap_label[0]
212+
if hap_label == label:
213+
hapsList.append(haplotype)
116214
elif label in self.nodes_plan_b:
117-
for key, key_data in self.whole_graph.nodes(data=True):
118-
if key_data["label"] == label:
119-
hapsList.append(key)
215+
for haplotype, hap_label in self.Whole_Vertices_attributes.items():
216+
hap_label = hap_label[0]
217+
if hap_label == label:
218+
hapsList.append(haplotype)
120219
self.labelDict[label] = hapsList
121220
return hapsList
122221

123222
def haps_with_probs_by_label(self, label):
223+
"""Find the haplotypes just like the above function but with the haplotypes' probabilities.
224+
Does not use the graphical features of the haplotypes.
225+
Returns a dictionary of haplotype to frequency.
226+
Approved."""
124227
dictAlleles = {}
125-
126228
listLabel = self.haps_by_label(label)
127229
if not self.nodes_plan_a or label in self.nodes_plan_a:
128230
for allele in listLabel:
129-
dictAlleles[allele] = self.graph.nodes[allele]["freq"]
231+
dictAlleles[allele] = self.Vertices_attributes[allele][1]
130232
elif label in self.nodes_plan_b:
131233
for allele in listLabel:
132-
dictAlleles[allele] = self.whole_graph.nodes[allele]["freq"]
133-
234+
dictAlleles[allele] = self.Whole_Vertices_attributes[allele][1]
134235
return dictAlleles
135236

136-
# find all adj of alleleList from label 'ABCQR'
137237
def adjs_query(self, alleleList):
238+
"""A filtering query on the alleles in the graph.
239+
Does use the graph class.
240+
Returns a dictionary of haplotypes (can be partial) to frequencies.
241+
Approved."""
138242
adjDict = dict()
139243
for allele in alleleList:
140-
if allele in self.graph:
141-
allele_node = self.graph.nodes[allele]
142-
if allele_node["label"] == self.full_loci: # 'ABCQR':
143-
adjDict[allele] = allele_node["freq"]
244+
if allele in self.Vertices_attributes:
245+
allele_attributes = self.Vertices_attributes[allele][0:2]
246+
if allele_attributes[0] == self.full_loci:
247+
adjDict[allele] = allele_attributes[1]
144248
else:
145-
adjs = self.graph.adj[allele]
146-
for adj in adjs:
147-
adjDict[adj] = self.graph.nodes[adj]["freq"]
249+
allele_id = self.Vertices_attributes[allele][2]
250+
# Find the neighbors of the allele
251+
allele_neighbors = self.Vertices[
252+
self.Edges[range(self.Neighbors_start[allele_id], self.Neighbors_start[allele_id + 1])]]
253+
# The frequencies of the neighbors to the dictionary
254+
for adj in allele_neighbors:
255+
adjDict[adj] = self.Vertices_attributes[adj][1]
148256
return adjDict
149257

150-
# find all adj of alleleList by label
151258
def adjs_query_by_color(self, alleleList, labelA, labelB):
152-
# copyLabelA = labelA
259+
"""A filtering query on the alleles in the graph.
260+
Does use the graph class.
261+
Returns a dictionary of haplotypes (can be partial) to frequencies.
262+
Approved."""
153263
adjDict = dict()
154264
if labelA == labelB:
155265
return self.node_probs(alleleList, labelA)
156266

157267
for allele in alleleList:
158-
if allele in self.whole_graph:
159-
alleles = self.whole_graph.adj.get(labelB + allele, [])
268+
if allele in self.Whole_Vertices_attributes:
269+
alleles = []
270+
connector = labelB + allele
271+
272+
if connector in self.Whole_Vertices_attributes:
273+
connector_id = self.Whole_Vertices_attributes[connector]
274+
alleles = self.Whole_Vertices[self.Whole_Edges[range(self.Whole_Neighbors_start[connector_id],
275+
self.Whole_Neighbors_start[
276+
connector_id + 1])]]
277+
160278
for adj in alleles:
161-
adjDict[adj] = self.whole_graph.nodes[adj]["freq"]
279+
adjDict[adj] = self.Whole_Vertices_attributes[adj][1]
162280
return adjDict
163281

164-
# return dict of nodes and there proper freq
165282
def node_probs(self, nodes, label):
166-
nodesDict = dict()
283+
"""Get a list of haplotypes and a label,
284+
Return a dictionary of nodes and their proper frequency."""
285+
nodesDict = {}
167286
if not self.nodes_plan_b or label in self.nodes_plan_b:
168287
for node in nodes:
169-
if node in self.whole_graph:
170-
nodesDict[node] = self.whole_graph.nodes[node]["freq"]
288+
if node in self.Whole_Vertices_attributes:
289+
nodesDict[node] = self.Whole_Vertices_attributes[node][1]
171290
elif label in self.nodes_plan_a:
172291
for node in nodes:
173-
if node in self.whole_graph:
174-
nodesDict[node] = self.graph.nodes[node]["freq"]
292+
if node in self.Whole_Vertices_attributes:
293+
nodesDict[node] = self.Vertices_attributes[node][1]
175294
return nodesDict

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
cython==0.29.32
22
numpy>=1.20.2
3-
networkx==2.5.1
3+
pandas
4+
tqdm

0 commit comments

Comments
 (0)