1313from  matplotlib  import  cm , colors 
1414
1515from  flamingo_tools .s3_utils  import  BUCKET_NAME , create_s3_target 
16- from  util  import  sliding_runlength_sum , frequency_mapping , prism_style ,  prism_cleanup_axes ,  SYNAPSE_DIR_ROOT 
16+ from  util  import  sliding_runlength_sum , frequency_mapping , SYNAPSE_DIR_ROOT 
1717
18- # INPUT_ROOT = "/home/pape/Work/my_projects/flamingo-tools/scripts/M_LR_000227_R/scale3" 
19- INPUT_ROOT  =  "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/frequency_mapping/M_LR_000227_R/scale3" 
20- FILE_EXTENSION  =  "png" 
18+ INPUT_ROOT  =  "/home/martin/Documents/lightsheet-cochlea/M_LR_000227_R" 
2119
2220TYPE_TO_CHANNEL  =  {
2321    "Type-Ia" : "CR" ,
@@ -93,7 +91,7 @@ def get_tonotopic_data():
9391        return  pickle .load (f )
9492
9593
96- def  _plot_colormap (vol , title , plot , save_path ):
94+ def  _plot_colormap (vol , title , plot , save_path ,  cmap = "viridis" ):
9795    # before creating the figure: 
9896    matplotlib .rcParams .update ({
9997        "font.size" : 14 ,          # base font size 
@@ -110,10 +108,16 @@ def _plot_colormap(vol, title, plot, save_path):
110108
111109    freq_min  =  np .min (np .nonzero (vol ))
112110    freq_max  =  vol .max ()
113-     norm  =  colors .Normalize (vmin = freq_min , vmax = freq_max , clip = True )
114-     cmap  =  plt .get_cmap ("viridis" )
111+     # norm = colors.Normalize(vmin=freq_min, vmax=freq_max, clip=True) 
112+     norm  =  colors .LogNorm (vmin = freq_min , vmax = freq_max , clip = True )
113+     tick_values  =  np .array ([10 , 20 , 40 , 80 ])
114+ 
115+     cmap  =  plt .get_cmap (cmap )
115116
116-     cb  =  plt .colorbar (cm .ScalarMappable (norm = norm , cmap = cmap ), cax = ax , orientation = "horizontal" )
117+     cb  =  plt .colorbar (cm .ScalarMappable (norm = norm , cmap = cmap ), cax = ax , orientation = "horizontal" ,
118+                       ticks = tick_values )
119+     cb .ax .xaxis .set_major_formatter (matplotlib .ticker .ScalarFormatter ())
120+     cb .ax .xaxis .set_minor_locator (matplotlib .ticker .NullLocator ())
117121    cb .set_label ("Frequency [kHz]" )
118122    plt .title (title )
119123    plt .tight_layout ()
@@ -127,19 +131,29 @@ def _plot_colormap(vol, title, plot, save_path):
127131    plt .close ()
128132
129133
130- def  fig_03a (save_path , plot , plot_napari ):
134+ def  fig_03a (save_path , plot , plot_napari ,  cmap = "viridis" ):
131135    path_ihc  =  os .path .join (INPUT_ROOT , "frequencies_IHC_v4c.tif" )
132136    path_sgn  =  os .path .join (INPUT_ROOT , "frequencies_SGN_v2.tif" )
133137    sgn  =  imageio .imread (path_sgn )
134138    ihc  =  imageio .imread (path_ihc )
135-     _plot_colormap (sgn , title = "Tonotopic Mapping" , plot = plot , save_path = save_path )
139+     _plot_colormap (sgn , title = "Tonotopic Mapping" , plot = plot , save_path = save_path ,  cmap = cmap )
136140
137141    # Show the image in napari for rendering. 
138142    if  plot_napari :
139143        import  napari 
144+         from  napari .utils  import  Colormap 
145+         # cmap = plt.get_cmap(cmap) 
146+         mpl_cmap  =  plt .get_cmap (cmap )
147+ 
148+         # Sample it into an array of RGBA values 
149+         colors  =  mpl_cmap (np .linspace (0 , 1 , 256 ))
150+ 
151+         # Wrap into napari Colormap 
152+         napari_cmap  =  Colormap (colors , name = f"{ cmap }  _custom" )
153+ 
140154        v  =  napari .Viewer ()
141-         v .add_image (ihc , colormap = "viridis" )
142-         v .add_image (sgn , colormap = "viridis" )
155+         v .add_image (ihc , colormap = napari_cmap )
156+         v .add_image (sgn , colormap = napari_cmap )
143157        napari .run ()
144158
145159
@@ -180,8 +194,7 @@ def fig_03c_rl(save_path, plot=False):
180194        plt .close ()
181195
182196
183- def  fig_03c_octave (tonotopic_data , save_path , plot = False , use_alias = True , trendlines = False ):
184-     prism_style ()
197+ def  fig_03c_octave (tonotopic_data , save_path , plot = False , use_alias = True ):
185198    ihc_version  =  "ihc_counts_v4c" 
186199    tables  =  glob (os .path .join (SYNAPSE_DIR_ROOT , ihc_version , "ihc_count_M_LR*.tsv" ))
187200    assert  len (tables ) ==  4 , len (tables )
@@ -207,55 +220,16 @@ def fig_03c_octave(tonotopic_data, save_path, plot=False, use_alias=True, trendl
207220    result ["x_pos" ] =  result ["octave_band" ].map (band_to_x )
208221
209222    fig , ax  =  plt .subplots (figsize = (8 , 4 ))
210-     trend_dict  =  {}
211223    for  name , grp  in  result .groupby ("cochlea" ):
212224        ax .scatter (grp ["x_pos" ], grp ["value" ], label = name , s = 60 , alpha = 0.8 )
213225
214-         if  trendlines :
215-             x_positions  =  grp ["x_pos" ]
216-             sorted_idx  =  np .argsort (x_positions )
217-             x_sorted  =  np .array (x_positions )[sorted_idx ]
218-             y_sorted  =  np .array (grp ["value" ])[sorted_idx ]
219-             trend_dict [name ] =  {"x_sorted" : x_sorted ,
220-                                 "y_sorted" : y_sorted ,
221-                                 }
222- 
223-     if  trendlines :
224-         def  get_trendline_values (trend_dict ):
225-             x_sorted  =  [trend_dict [k ]["x_sorted" ] for  k  in  trend_dict .keys ()][0 ]
226-             y_sorted_all  =  [trend_dict [k ]["y_sorted" ] for  k  in  trend_dict .keys ()]
227-             y_sorted  =  []
228-             for  num  in  range (len (x_sorted )):
229-                 y_sorted .append (np .mean ([y [num ] for  y  in  y_sorted_all ]))
230-             return  x_sorted , y_sorted 
231- 
232-         # Trendline left 
233-         x_sorted , y_sorted  =  get_trendline_values (trend_dict )
234- 
235-         trend , =  ax .plot (
236-             x_sorted ,
237-             y_sorted ,
238-             linestyle = "dotted" ,
239-             color = "grey" ,
240-             alpha = 0.7 
241-         )
242- 
243-         # trendline_legend = ax.legend(handles=[trend], loc='lower center') 
244-         # trendline_legend = ax.legend( 
245-         #     handles=[trend], 
246-         #     labels=["Trendline"], 
247-         #     loc="upper left" 
248-         # ) 
249-         # # Add the legend manually to the Axes. 
250-         # ax.add_artist(trendline_legend) 
251- 
252226    ax .set_xticks (range (len (bin_labels )))
253227    ax .set_xticklabels (bin_labels )
254228    ax .set_xlabel ("Octave band (kHz)" )
255229
256-     ax .set_ylabel ("Average Ribbon Synapse Count per IHC" , fontsize = 10 )
230+     ax .set_ylabel ("Average Ribbon Synapse Count per IHC" )
231+     ax .set_title ("Ribbon synapse count per octave band" )
257232    plt .legend (title = "Cochlea" )
258-     prism_cleanup_axes (ax )
259233
260234    if  ".png"  in  save_path :
261235        plt .savefig (save_path , bbox_inches = "tight" , pad_inches = 0.1 , dpi = png_dpi )
@@ -345,15 +319,16 @@ def main():
345319    tonotopic_data  =  get_tonotopic_data ()
346320
347321    # Panel A: Tonotopic mapping of SGNs and IHCs (rendering in napari + heatmap) 
348-     # fig_03a(save_path=os.path.join(args.figure_dir, f"fig_03a_cmap.{FILE_EXTENSION}"), 
349-     # plot=args.plot, plot_napari=True) 
322+     cmap  =  "plasma" 
323+     fig_03a (save_path = os .path .join (args .figure_dir , f"fig_03a_cmap_{ cmap }  .{ FILE_EXTENSION }  " ),
324+             plot = args .plot , plot_napari = True , cmap = cmap )
350325
351326    # Panel C: Spatial distribution of synapses across the cochlea. 
352327    # We have two options: running sum over the runlength or per octave band 
353328    # fig_03c_rl(save_path=os.path.join(args.figure_dir, f"fig_03c_runlength.{FILE_EXTENSION}"), plot=args.plot) 
354329    fig_03c_octave (tonotopic_data = tonotopic_data ,
355-                    save_path = os .path .join (args .figure_dir , f"fig_03c_octave.{ FILE_EXTENSION }  " ),
356-                    plot = args .plot ,  trendlines = True )
330+                      save_path = os .path .join (args .figure_dir , f"fig_03c_octave.{ FILE_EXTENSION }  " ),
331+                      plot = args .plot )
357332
358333    # Panel D: Spatial distribution of SGN sub-types. 
359334    # fig_03d_fraction(save_path=os.path.join(args.figure_dir, f"fig_03d_fraction.{FILE_EXTENSION}"), plot=args.plot) 
0 commit comments