Skip to content

Commit ed9d06e

Browse files
Switched source imports and requirements to the rdd package, and fixed some additional edge cases (#7)
Co-authored-by: amca2892 <amca2892@gmail.com>
1 parent db8cf37 commit ed9d06e

13 files changed

+133
-2645
lines changed

pages/01_Create_RDD_Count_Table.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
if SRC not in sys.path:
1010
sys.path.insert(0, SRC)
1111

12-
from src.RDDcounts import RDDCounts # noqa: E402
12+
from rdd import RDDCounts # noqa: E402
1313
from src.state_helpers import set_group # noqa: E402
1414

1515

@@ -194,7 +194,7 @@ def load_demo_file(filename):
194194
# GNPS1 requires group selection from the network data
195195
with st.spinner("📊 Fetching GNPS1 data to display available groups..."):
196196
try:
197-
from src.utils import get_gnps_task_data
197+
from rdd.utils import get_gnps_task_data
198198

199199
temp_gnps_df = get_gnps_task_data(gnps_task_id, gnps2=False)
200200
if "DefaultGroups" in temp_gnps_df.columns:

pages/02_Visualizations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if SRC not in sys.path:
77
sys.path.insert(0, SRC)
88

9-
from src.visualization import Visualizer, PlotlyBackend, MatplotlibBackend # noqa: E402
9+
from rdd.visualization import Visualizer, PlotlyBackend, MatplotlibBackend # noqa: E402
1010

1111
if "rdd" not in st.session_state:
1212
st.warning("First create an RDDCounts object.")

pages/03_PCA_Analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
if SRC not in sys.path:
77
sys.path.insert(0, SRC)
88

9-
from src.analysis import perform_pca_RDD_counts # noqa: E402
10-
from src.visualization import Visualizer, PlotlyBackend, MatplotlibBackend # noqa: E402
9+
from rdd.analysis import perform_pca_RDD_counts # noqa: E402
10+
from rdd.visualization import Visualizer, PlotlyBackend, MatplotlibBackend # noqa: E402
1111

1212
if "rdd" not in st.session_state:
1313
st.warning("First create an RDDCounts object.")

pages/04_Sankey_Diagram.py

Lines changed: 121 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
if SRC not in sys.path:
77
sys.path.insert(0, SRC)
88

9-
from src.visualization import Visualizer, PlotlyBackend # noqa: E402
9+
from rdd.visualization import Visualizer, PlotlyBackend # noqa: E402
1010

1111
st.header("Sankey Diagram")
1212

@@ -18,55 +18,158 @@
1818
rdd = st.session_state["rdd"]
1919
viz = Visualizer(PlotlyBackend()) # Sankey supported only in Plotly
2020

21+
# ── guard: need at least 2 ontology levels for Sankey ─────────────────
22+
if rdd.levels < 2:
23+
st.warning(
24+
"⚠️ Sankey diagrams require at least 2 ontology levels to visualize hierarchical flows."
25+
)
26+
st.info(
27+
f"Current RDD object has {rdd.levels} level(s). Please create an RDD count table with levels ≥ 2."
28+
)
29+
st.stop()
30+
2131
# ── user controls ──────────────────────────────────────────────────────
2232
sample_choice = st.selectbox(
2333
"Filter by sample filename (optional)",
2434
["<all samples>"] + sorted(rdd.counts["filename"].unique()),
2535
)
2636

27-
max_level = st.number_input(
28-
"Maximum hierarchy level", 1, rdd.levels, rdd.levels, step=1
29-
)
37+
max_level = st.number_input("Maximum hierarchy level", 1, rdd.levels, rdd.levels, step=1)
38+
3039

3140
def load_demo_file(filename):
3241
import io
42+
3343
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
3444
path = os.path.join(ROOT, "data", filename)
3545
with open(path, "rb") as f:
3646
file_obj = io.BytesIO(f.read())
3747
file_obj.name = filename
3848
return file_obj
3949

50+
4051
color_map_up = None
4152
if st.session_state.get("use_demo"):
42-
# Automatically use demo color map if in demo mode
53+
# Automatically use demo color map if in demo mode - clean it to 2 columns
4354
try:
44-
color_map_up = load_demo_file("sample_type_hierarchy.csv")
55+
import pandas as pd
56+
import io
57+
58+
demo_file = load_demo_file("sample_type_hierarchy.csv")
59+
color_df = pd.read_csv(demo_file, sep=";")
60+
61+
# Keep only descriptor and color_code columns
62+
if "descriptor" in color_df.columns and "color_code" in color_df.columns:
63+
color_df = color_df[["descriptor", "color_code"]]
64+
65+
# Convert back to BytesIO
66+
csv_str = color_df.to_csv(index=False, sep=";")
67+
color_map_up = io.BytesIO(csv_str.encode())
68+
color_map_up.name = "sample_type_hierarchy.csv"
4569
st.info("Demo color map loaded automatically.")
46-
except Exception:
47-
st.warning("Demo color map not found.")
70+
except Exception as e:
71+
st.warning(f"Demo color map not found: {e}")
4872
else:
49-
color_map_up = st.file_uploader(
50-
"Colour-mapping CSV (`descriptor;color_code`, optional)", type=("csv", "tsv", "txt")
73+
color_option = st.radio(
74+
"Color mapping option",
75+
["Use foodomics color mapping", "Upload custom file", "Use grayscale"],
76+
horizontal=True,
5177
)
5278

79+
if color_option == "Use foodomics color mapping":
80+
# Use the default foodomics color mapping - clean to 2 columns
81+
try:
82+
import pandas as pd
83+
import io
84+
85+
demo_file = load_demo_file("sample_type_hierarchy.csv")
86+
color_df = pd.read_csv(demo_file, sep=";")
87+
88+
# Keep only descriptor and color_code columns
89+
if "descriptor" in color_df.columns and "color_code" in color_df.columns:
90+
color_df = color_df[["descriptor", "color_code"]]
91+
92+
# Convert back to BytesIO
93+
csv_str = color_df.to_csv(index=False, sep=";")
94+
color_map_up = io.BytesIO(csv_str.encode())
95+
color_map_up.name = "sample_type_hierarchy.csv"
96+
st.info("✓ Using foodomics reference color mapping")
97+
except Exception as e:
98+
st.error(f"Could not load foodomics color mapping: {e}")
99+
elif color_option == "Upload custom file":
100+
color_map_up = st.file_uploader(
101+
"Colour-mapping file (CSV/TSV with 2 columns: descriptor and color_code)",
102+
type=("csv", "tsv", "txt"),
103+
)
104+
else:
105+
# Generate grayscale mapping automatically
106+
import io
107+
import pandas as pd
108+
109+
# Get all unique reference types from level 0
110+
unique_types = rdd.counts[rdd.counts["level"] == 0]["reference_type"].unique()
111+
n_types = len(unique_types)
112+
113+
# Generate grayscale colors
114+
grayscale_colors = [
115+
f"#{int(255 - (i * 200 / max(n_types-1, 1))):02x}" * 3 for i in range(n_types)
116+
]
117+
118+
# Create mapping dataframe
119+
color_df = pd.DataFrame({"descriptor": unique_types, "color_code": grayscale_colors})
120+
121+
# Convert to BytesIO object (with header to match expected format)
122+
csv_str = color_df.to_csv(index=False, sep=";", header=True)
123+
color_map_up = io.BytesIO(csv_str.encode())
124+
color_map_up.name = "grayscale_mapping.csv"
125+
st.info(f"✓ Generated grayscale mapping for {n_types} reference types")
126+
53127
dark_mode = st.checkbox("Dark mode")
54128

55129
# ── draw button ────────────────────────────────────────────────────────
56130
if st.button("Draw Sankey"):
57-
# persist colour map only if provided
58-
59-
colour_path = None
60-
if color_map_up:
61-
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
62-
tmp.write(color_map_up.getbuffer())
131+
if not color_map_up:
132+
st.error("⚠️ Please select a color mapping option.")
133+
st.stop()
134+
135+
# Read the selected color map (demo, foodomics, uploaded, or generated grayscale) and convert to semicolon-separated format for backend
136+
import pandas as pd
137+
import io
138+
139+
try:
140+
# Try to read with common separators
141+
color_map_up.seek(0)
142+
content = color_map_up.read().decode("utf-8")
143+
144+
# Detect separator and read
145+
if ";" in content.split("\n")[0]:
146+
sep = ";"
147+
elif "\t" in content.split("\n")[0]:
148+
sep = "\t"
149+
else:
150+
sep = ","
151+
152+
color_map_up.seek(0)
153+
color_df = pd.read_csv(color_map_up, sep=sep)
154+
155+
# Ensure we have the expected columns
156+
if len(color_df.columns) >= 2:
157+
color_df = color_df.iloc[:, :2] # Take first two columns
158+
color_df.columns = ["descriptor", "color_code"]
159+
160+
# Save as semicolon-separated for backend
161+
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as tmp:
162+
color_df.to_csv(tmp, sep=";", index=False)
63163
colour_path = tmp.name
164+
except Exception as e:
165+
st.error(f"Error reading color mapping file: {e}")
166+
st.stop()
64167

65168
fig = viz.plot_sankey(
66169
rdd,
67-
color_mapping_file=colour_path, # may be None
170+
color_mapping_file=colour_path,
68171
max_hierarchy_level=max_level or None,
69172
filename_filter=None if sample_choice == "<all samples>" else sample_choice,
70173
dark_mode=dark_mode,
71174
)
72-
st.plotly_chart(fig, use_container_width=True)
175+
st.plotly_chart(fig, use_container_width=True)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ scikit-bio>=0.7.1 # CLR transformation
1212

1313
# ───────── GNPS data access ─────────
1414
gnpsdata @ git+https://github.com/Wang-Bioinformatics-Lab/GNPSDataPackage.git@f4ca8d9b7fab87823179b122b7e4a0a0b62e9d65
15+
gnps-rdd @ git+https://github.com/bittremieuxlab/gnps-rdd.git@review/manuscript-revisions
1516

1617
# ───────── optional / indirect ─────────
1718
# pyarrow is auto-installed by pandas ≥2 for fast parquet/feather,

0 commit comments

Comments
 (0)