66if SRC not in sys .path :
77 sys .path .insert (0 , SRC )
88
9- from src .visualization import Visualizer , PlotlyBackend # noqa: E402
9+ from rdd .visualization import Visualizer , PlotlyBackend # noqa: E402
1010
1111st .header ("Sankey Diagram" )
1212
1818rdd = st .session_state ["rdd" ]
1919viz = Visualizer (PlotlyBackend ()) # Sankey supported only in Plotly
2020
21+ # ── guard: need at least 2 ontology levels for Sankey ─────────────────
22+ if rdd .levels < 2 :
23+ st .warning (
24+ "⚠️ Sankey diagrams require at least 2 ontology levels to visualize hierarchical flows."
25+ )
26+ st .info (
27+ f"Current RDD object has { rdd .levels } level(s). Please create an RDD count table with levels ≥ 2."
28+ )
29+ st .stop ()
30+
2131# ── user controls ──────────────────────────────────────────────────────
2232sample_choice = st .selectbox (
2333 "Filter by sample filename (optional)" ,
2434 ["<all samples>" ] + sorted (rdd .counts ["filename" ].unique ()),
2535)
2636
27- max_level = st .number_input (
28- "Maximum hierarchy level" , 1 , rdd .levels , rdd .levels , step = 1
29- )
37+ max_level = st .number_input ("Maximum hierarchy level" , 1 , rdd .levels , rdd .levels , step = 1 )
38+
3039
3140def load_demo_file (filename ):
3241 import io
42+
3343 ROOT = os .path .abspath (os .path .join (os .path .dirname (__file__ ), ".." ))
3444 path = os .path .join (ROOT , "data" , filename )
3545 with open (path , "rb" ) as f :
3646 file_obj = io .BytesIO (f .read ())
3747 file_obj .name = filename
3848 return file_obj
3949
50+
4051color_map_up = None
4152if st .session_state .get ("use_demo" ):
42- # Automatically use demo color map if in demo mode
53+ # Automatically use demo color map if in demo mode - clean it to 2 columns
4354 try :
44- color_map_up = load_demo_file ("sample_type_hierarchy.csv" )
55+ import pandas as pd
56+ import io
57+
58+ demo_file = load_demo_file ("sample_type_hierarchy.csv" )
59+ color_df = pd .read_csv (demo_file , sep = ";" )
60+
61+ # Keep only descriptor and color_code columns
62+ if "descriptor" in color_df .columns and "color_code" in color_df .columns :
63+ color_df = color_df [["descriptor" , "color_code" ]]
64+
65+ # Convert back to BytesIO
66+ csv_str = color_df .to_csv (index = False , sep = ";" )
67+ color_map_up = io .BytesIO (csv_str .encode ())
68+ color_map_up .name = "sample_type_hierarchy.csv"
4569 st .info ("Demo color map loaded automatically." )
46- except Exception :
47- st .warning ("Demo color map not found. " )
70+ except Exception as e :
71+ st .warning (f "Demo color map not found: { e } " )
4872else :
49- color_map_up = st .file_uploader (
50- "Colour-mapping CSV (`descriptor;color_code`, optional)" , type = ("csv" , "tsv" , "txt" )
73+ color_option = st .radio (
74+ "Color mapping option" ,
75+ ["Use foodomics color mapping" , "Upload custom file" , "Use grayscale" ],
76+ horizontal = True ,
5177 )
5278
79+ if color_option == "Use foodomics color mapping" :
80+ # Use the default foodomics color mapping - clean to 2 columns
81+ try :
82+ import pandas as pd
83+ import io
84+
85+ demo_file = load_demo_file ("sample_type_hierarchy.csv" )
86+ color_df = pd .read_csv (demo_file , sep = ";" )
87+
88+ # Keep only descriptor and color_code columns
89+ if "descriptor" in color_df .columns and "color_code" in color_df .columns :
90+ color_df = color_df [["descriptor" , "color_code" ]]
91+
92+ # Convert back to BytesIO
93+ csv_str = color_df .to_csv (index = False , sep = ";" )
94+ color_map_up = io .BytesIO (csv_str .encode ())
95+ color_map_up .name = "sample_type_hierarchy.csv"
96+ st .info ("✓ Using foodomics reference color mapping" )
97+ except Exception as e :
98+ st .error (f"Could not load foodomics color mapping: { e } " )
99+ elif color_option == "Upload custom file" :
100+ color_map_up = st .file_uploader (
101+ "Colour-mapping file (CSV/TSV with 2 columns: descriptor and color_code)" ,
102+ type = ("csv" , "tsv" , "txt" ),
103+ )
104+ else :
105+ # Generate grayscale mapping automatically
106+ import io
107+ import pandas as pd
108+
109+ # Get all unique reference types from level 0
110+ unique_types = rdd .counts [rdd .counts ["level" ] == 0 ]["reference_type" ].unique ()
111+ n_types = len (unique_types )
112+
113+ # Generate grayscale colors
114+ grayscale_colors = [
115+ f"#{ int (255 - (i * 200 / max (n_types - 1 , 1 ))):02x} " * 3 for i in range (n_types )
116+ ]
117+
118+ # Create mapping dataframe
119+ color_df = pd .DataFrame ({"descriptor" : unique_types , "color_code" : grayscale_colors })
120+
121+ # Convert to BytesIO object (with header to match expected format)
122+ csv_str = color_df .to_csv (index = False , sep = ";" , header = True )
123+ color_map_up = io .BytesIO (csv_str .encode ())
124+ color_map_up .name = "grayscale_mapping.csv"
125+ st .info (f"✓ Generated grayscale mapping for { n_types } reference types" )
126+
53127dark_mode = st .checkbox ("Dark mode" )
54128
55129# ── draw button ────────────────────────────────────────────────────────
56130if st .button ("Draw Sankey" ):
57- # persist colour map only if provided
58-
59- colour_path = None
60- if color_map_up :
61- with tempfile .NamedTemporaryFile (delete = False , suffix = ".csv" ) as tmp :
62- tmp .write (color_map_up .getbuffer ())
131+ if not color_map_up :
132+ st .error ("⚠️ Please select a color mapping option." )
133+ st .stop ()
134+
135+ # Read uploaded file and convert to semicolon-separated format for backend
136+ import pandas as pd
137+ import io
138+
139+ try :
140+ # Try to read with common separators
141+ color_map_up .seek (0 )
142+ content = color_map_up .read ().decode ("utf-8" )
143+
144+ # Detect separator and read
145+ if ";" in content .split ("\n " )[0 ]:
146+ sep = ";"
147+ elif "\t " in content .split ("\n " )[0 ]:
148+ sep = "\t "
149+ else :
150+ sep = ","
151+
152+ color_map_up .seek (0 )
153+ color_df = pd .read_csv (color_map_up , sep = sep )
154+
155+ # Ensure we have the expected columns
156+ if len (color_df .columns ) >= 2 :
157+ color_df = color_df .iloc [:, :2 ] # Take first two columns
158+ color_df .columns = ["descriptor" , "color_code" ]
159+
160+ # Save as semicolon-separated for backend
161+ with tempfile .NamedTemporaryFile (delete = False , suffix = ".csv" , mode = "w" ) as tmp :
162+ color_df .to_csv (tmp , sep = ";" , index = False )
63163 colour_path = tmp .name
164+ except Exception as e :
165+ st .error (f"Error reading color mapping file: { e } " )
166+ st .stop ()
64167
65168 fig = viz .plot_sankey (
66169 rdd ,
67- color_mapping_file = colour_path , # may be None
170+ color_mapping_file = colour_path ,
68171 max_hierarchy_level = max_level or None ,
69172 filename_filter = None if sample_choice == "<all samples>" else sample_choice ,
70173 dark_mode = dark_mode ,
71174 )
72- st .plotly_chart (fig , use_container_width = True )
175+ st .plotly_chart (fig , use_container_width = True )
0 commit comments