|
4 | 4 | """Graph extraction using NLP.""" |
5 | 5 |
|
6 | 6 | from itertools import combinations |
| 7 | +from typing import Any |
7 | 8 |
|
8 | 9 | import numpy as np |
9 | 10 | import pandas as pd |
@@ -31,7 +32,6 @@ async def build_noun_graph( |
31 | 32 | text_units, text_analyzer, num_threads=num_threads, cache=cache |
32 | 33 | ) |
33 | 34 | edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights) |
34 | | - |
35 | 35 | return (nodes_df, edges_df) |
36 | 36 |
|
37 | 37 |
|
@@ -95,35 +95,28 @@ def _extract_edges( |
95 | 95 | """ |
96 | 96 | text_units_df = nodes_df.explode("text_unit_ids") |
97 | 97 | text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"}) |
| 98 | + |
98 | 99 | text_units_df = ( |
99 | | - text_units_df.groupby("text_unit_id").agg({"title": list}).reset_index() |
| 100 | + text_units_df.groupby("text_unit_id") |
| 101 | + .agg({"title": lambda x: list(x) if len(x) > 1 else np.nan}) |
| 102 | + .reset_index() |
100 | 103 | ) |
101 | | - |
102 | | - text_units_df["edges"] = text_units_df["title"].apply( |
103 | | - lambda x: list(combinations(x, 2)) |
| 104 | + text_units_df = text_units_df.dropna() |
| 105 | + titles = text_units_df["title"].tolist() |
| 106 | + all_edges: Any = [list(combinations(t, 2)) for t in titles] |
| 107 | + |
| 108 | + text_units_df = text_units_df.assign(edges=all_edges) |
| 109 | + edge_df = text_units_df.explode("edges")[["edges", "text_unit_id"]] |
| 110 | + |
| 111 | + edge_df[["source", "target"]] = edge_df["edges"].to_list() |
| 112 | + edge_df["min_source"] = edge_df[["source", "target"]].min(axis=1) |
| 113 | + edge_df["max_target"] = edge_df[["source", "target"]].max(axis=1) |
| 114 | + edge_df = edge_df.drop(columns=["source", "target"]).rename( |
| 115 | + columns={"min_source": "source", "max_target": "target"} |
104 | 116 | ) |
105 | 117 |
|
106 | | - edge_df = text_units_df.explode("edges").loc[:, ["edges", "text_unit_id"]] |
107 | | - |
108 | | - edge_df["source"] = edge_df["edges"].apply( |
109 | | - lambda x: x[0] if isinstance(x, tuple) else None |
110 | | - ) |
111 | | - edge_df["target"] = edge_df["edges"].apply( |
112 | | - lambda x: x[1] if isinstance(x, tuple) else None |
113 | | - ) |
114 | 118 | edge_df = edge_df[(edge_df.source.notna()) & (edge_df.target.notna())] |
115 | 119 | edge_df = edge_df.drop(columns=["edges"]) |
116 | | - # make sure source is always smaller than target |
117 | | - edge_df["source"], edge_df["target"] = zip( |
118 | | - *edge_df.apply( |
119 | | - lambda x: (x["source"], x["target"]) |
120 | | - if x["source"] < x["target"] |
121 | | - else (x["target"], x["source"]), |
122 | | - axis=1, |
123 | | - ), |
124 | | - strict=False, |
125 | | - ) |
126 | | - |
127 | 120 | # group by source and target, count the number of text units |
128 | 121 | grouped_edge_df = ( |
129 | 122 | edge_df.groupby(["source", "target"]).agg({"text_unit_id": list}).reset_index() |
|
0 commit comments