Skip to content

Commit ef2e109

Browse files
fix: fix fetching img_path in vqa_generator
1 parent aa87906 commit ef2e109

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

graphgen/models/generator/vqa_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ async def generate(
8080
for node in nodes:
8181
node_data = node[1]
8282
if "images" in node_data and node_data["images"]:
83-
img_path = node_data["images"]
83+
img_path = node_data["images"]["img_path"]
8484
for qa in qa_pairs.values():
8585
qa["img_path"] = img_path
8686
result.update(qa_pairs)

graphgen/models/partitioner/anchor_bfs_partitioner.py

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import random
22
from collections import deque
3-
from typing import Any, List, Literal, Set
3+
from typing import Any, List, Literal, Set, Tuple
44

55
from graphgen.bases import BaseGraphStorage
66
from graphgen.bases.datatypes import Community
@@ -55,36 +55,9 @@ async def partition(
5555
for seed_node in seeds:
5656
if seed_node in used_n:
5757
continue
58-
59-
comm_n: List[str] = []
60-
comm_e: List[tuple[str, str]] = []
61-
queue: deque[tuple[str, Any]] = deque([(NODE_UNIT, seed_node)])
62-
cnt = 0
63-
64-
while queue and cnt < max_units_per_community:
65-
k, it = queue.popleft()
66-
67-
if k == NODE_UNIT:
68-
if it in used_n:
69-
continue
70-
used_n.add(it)
71-
comm_n.append(it)
72-
cnt += 1
73-
for nei in adj[it]:
74-
e_key = frozenset((it, nei))
75-
if e_key not in used_e:
76-
queue.append((EDGE_UNIT, e_key))
77-
else: # EDGE_UNIT
78-
if it in used_e:
79-
continue
80-
used_e.add(it)
81-
u, v = it
82-
comm_e.append((u, v))
83-
cnt += 1
84-
for n in it:
85-
if n not in used_n:
86-
queue.append((NODE_UNIT, n))
87-
58+
comm_n, comm_e = await self._grow_community(
59+
seed_node, adj, max_units_per_community, used_n, used_e
60+
)
8861
if comm_n or comm_e:
8962
communities.append(
9063
Community(id=len(communities), nodes=comm_n, edges=comm_e)
@@ -105,3 +78,51 @@ async def _pick_anchor_ids(
10578
if self.anchor_type.lower() in node_type:
10679
anchor_ids.add(node_id)
10780
return anchor_ids
81+
82+
@staticmethod
async def _grow_community(
    seed: str,
    adj: dict[str, List[str]],
    max_units: int,
    used_n: set[str],
    used_e: set[frozenset[str]],
) -> Tuple[List[str], List[Tuple[str, str]]]:
    """
    Grow a community from the seed node using BFS.

    Nodes and edges are both treated as "units": visiting a node enqueues
    its not-yet-used incident edges, and visiting an edge enqueues its
    not-yet-used endpoints. Growth stops when the queue is exhausted or
    when ``max_units`` units have been collected.

    ``used_n`` and ``used_e`` are updated in place with every unit that is
    added to the community. Units still waiting in the queue when the
    budget runs out are not marked as used.

    :param seed: seed node id
    :param adj: adjacency list; must contain an entry for every node that
        can be reached (a missing node raises ``KeyError``)
    :param max_units: maximum number of units (nodes + edges) in the community
    :param used_n: set of used node ids (mutated in place)
    :param used_e: set of used edge keys (mutated in place)
    :return: (list of node ids, list of edge tuples)
    """
    comm_n: List[str] = []
    comm_e: List[Tuple[str, str]] = []
    queue: deque[tuple[str, Any]] = deque([(NODE_UNIT, seed)])
    cnt = 0

    while queue and cnt < max_units:
        k, it = queue.popleft()

        if k == NODE_UNIT:
            # The same node can be enqueued by several edges; take it once.
            if it in used_n:
                continue
            used_n.add(it)
            comm_n.append(it)
            cnt += 1
            for nei in adj[it]:
                # Undirected edge key: (a, b) and (b, a) collapse together.
                e_key = frozenset((it, nei))
                if e_key not in used_e:
                    queue.append((EDGE_UNIT, e_key))
        else:  # EDGE_UNIT
            # The same edge can be enqueued from both of its endpoints.
            if it in used_e:
                continue
            used_e.add(it)
            # A self-loop yields a one-element frozenset, so a plain
            # ``u, v = it`` would raise ValueError. Taking the first and
            # last element handles both cases: (u, v) for a normal edge
            # and (u, u) for a self-loop.
            endpoints = tuple(it)
            u, v = endpoints[0], endpoints[-1]
            comm_e.append((u, v))
            cnt += 1
            for n in it:
                if n not in used_n:
                    queue.append((NODE_UNIT, n))

    return comm_n, comm_e

0 commit comments

Comments
 (0)