1515"""Utilities for plotting outputs of Torax runs.
1616
1717Public API:
18- plot_run: Main entry point. Loads data and returns a plotly Figure.
18+ plot_run: Main entry point. Loads data from file and returns a plotly Figure.
19+ plot_run_from_data_tree: Plots from an in-memory xr.DataTree.
1920 PlotData: Data container exposing all output variables as attributes.
2021 FigureProperties: Configuration for the overall figure layout.
2122 PlotProperties: Configuration for an individual subplot.
@@ -315,62 +316,67 @@ def available_variables(self) -> set[str]:
315316 return attrs
316317
317318
318- def load_data (filename : str ) -> PlotData :
319- """Loads an xr.Dataset from a file, handling coordinate name changes."""
320-
321- data_tree = output .load_state_file (filename )
319+ def _transform_data (ds : xr .Dataset ) -> xr .Dataset :
320+ """Transforms dataset variables to plotting units."""
321+ # TODO(b/414755419)
322+ ds = ds .copy ()
323+
324+ transformations = {
325+ output .J_TOROIDAL_TOTAL : 1e6 , # A/m^2 to MA/m^2
326+ output .J_TOROIDAL_OHMIC : 1e6 , # A/m^2 to MA/m^2
327+ output .J_TOROIDAL_BOOTSTRAP : 1e6 , # A/m^2 to MA/m^2
328+ output .J_TOROIDAL_EXTERNAL : 1e6 , # A/m^2 to MA/m^2
329+ 'j_generic_current' : 1e6 , # A/m^2 to MA/m^2
330+ output .I_BOOTSTRAP : 1e6 , # A to MA
331+ output .IP_PROFILE : 1e6 , # A to MA
332+ output .IP : 1e6 , # A to MA
333+ 'j_ecrh' : 1e6 , # A/m^2 to MA/m^2
334+ 'p_icrh_i' : 1e6 , # W/m^3 to MW/m^3
335+ 'p_icrh_e' : 1e6 , # W/m^3 to MW/m^3
336+ 'p_generic_heat_i' : 1e6 , # W/m^3 to MW/m^3
337+ 'p_generic_heat_e' : 1e6 , # W/m^3 to MW/m^3
338+ 'p_ecrh_e' : 1e6 , # W/m^3 to MW/m^3
339+ 'p_alpha_i' : 1e6 , # W/m^3 to MW/m^3
340+ 'p_alpha_e' : 1e6 , # W/m^3 to MW/m^3
341+ 'p_ohmic_e' : 1e6 , # W/m^3 to MW/m^3
342+ 'p_bremsstrahlung_e' : 1e6 , # W/m^3 to MW/m^3
343+ 'p_cyclotron_radiation_e' : 1e6 , # W/m^3 to MW/m^3
344+ 'p_impurity_radiation_e' : 1e6 , # W/m^3 to MW/m^3
345+ 'ei_exchange' : 1e6 , # W/m^3 to MW/m^3
346+ 'P_ohmic_e' : 1e6 , # W to MW
347+ 'P_aux_total' : 1e6 , # W to MW
348+ 'P_alpha_total' : 1e6 , # W to MW
349+ 'P_bremsstrahlung_e' : 1e6 , # W to MW
350+ 'P_cyclotron_e' : 1e6 , # W to MW
351+ 'P_ecrh' : 1e6 , # W to MW
352+ 'P_radiation_e' : 1e6 , # W to MW
353+ 'I_ecrh' : 1e6 , # A to MA
354+ 'I_aux_generic' : 1e6 , # A to MA
355+ 'W_thermal_total' : 1e6 , # J to MJ
356+ output .N_E : 1e20 , # m^-3 to 10^{20} m^-3
357+ output .N_I : 1e20 , # m^-3 to 10^{20} m^-3
358+ output .N_IMPURITY : 1e20 , # m^-3 to 10^{20} m^-3
359+ }
322360
323- def _transform_data (ds : xr .Dataset ):
324- """Transforms data in-place to the desired units."""
325- # TODO(b/414755419)
326- ds = ds .copy ()
327-
328- transformations = {
329- output .J_TOROIDAL_TOTAL : 1e6 , # A/m^2 to MA/m^2
330- output .J_TOROIDAL_OHMIC : 1e6 , # A/m^2 to MA/m^2
331- output .J_TOROIDAL_BOOTSTRAP : 1e6 , # A/m^2 to MA/m^2
332- output .J_TOROIDAL_EXTERNAL : 1e6 , # A/m^2 to MA/m^2
333- 'j_generic_current' : 1e6 , # A/m^2 to MA/m^2
334- output .I_BOOTSTRAP : 1e6 , # A to MA
335- output .IP_PROFILE : 1e6 , # A to MA
336- output .IP : 1e6 , # A to MA
337- 'j_ecrh' : 1e6 , # A/m^2 to MA/m^2
338- 'p_icrh_i' : 1e6 , # W/m^3 to MW/m^3
339- 'p_icrh_e' : 1e6 , # W/m^3 to MW/m^3
340- 'p_generic_heat_i' : 1e6 , # W/m^3 to MW/m^3
341- 'p_generic_heat_e' : 1e6 , # W/m^3 to MW/m^3
342- 'p_ecrh_e' : 1e6 , # W/m^3 to MW/m^3
343- 'p_alpha_i' : 1e6 , # W/m^3 to MW/m^3
344- 'p_alpha_e' : 1e6 , # W/m^3 to MW/m^3
345- 'p_ohmic_e' : 1e6 , # W/m^3 to MW/m^3
346- 'p_bremsstrahlung_e' : 1e6 , # W/m^3 to MW/m^3
347- 'p_cyclotron_radiation_e' : 1e6 , # W/m^3 to MW/m^3
348- 'p_impurity_radiation_e' : 1e6 , # W/m^3 to MW/m^3
349- 'ei_exchange' : 1e6 , # W/m^3 to MW/m^3
350- 'P_ohmic_e' : 1e6 , # W to MW
351- 'P_aux_total' : 1e6 , # W to MW
352- 'P_alpha_total' : 1e6 , # W to MW
353- 'P_bremsstrahlung_e' : 1e6 , # W to MW
354- 'P_cyclotron_e' : 1e6 , # W to MW
355- 'P_ecrh' : 1e6 , # W to MW
356- 'P_radiation_e' : 1e6 , # W to MW
357- 'I_ecrh' : 1e6 , # A to MA
358- 'I_aux_generic' : 1e6 , # A to MA
359- 'W_thermal_total' : 1e6 , # J to MJ
360- output .N_E : 1e20 , # m^-3 to 10^{20} m^-3
361- output .N_I : 1e20 , # m^-3 to 10^{20} m^-3
362- output .N_IMPURITY : 1e20 , # m^-3 to 10^{20} m^-3
363- }
361+ for var_name , scale in transformations .items ():
362+ if var_name in ds :
363+ ds [var_name ] /= scale
364364
365- for var_name , scale in transformations .items ():
366- if var_name in ds :
367- ds [var_name ] /= scale
365+ return ds
368366
369- return ds
370367
368+ def _data_tree_to_plot_data (data_tree : xr .DataTree ) -> PlotData :
369+ """Converts an xr.DataTree to a PlotData object with unit transformations."""
371370 return PlotData (xr .map_over_datasets (_transform_data , data_tree ))
372371
373372
373+ def load_data (filename : str ) -> PlotData :
374+ """Loads an xr.Dataset from a file, handling coordinate name changes."""
375+
376+ data_tree = output .load_state_file (filename )
377+ return _data_tree_to_plot_data (data_tree )
378+
379+
374380def _get_file_path (outfile : str ) -> str :
375381 """Gets the absolute path to the file."""
376382 possible_paths = [outfile ]
@@ -399,23 +405,38 @@ def plot_run(
399405 outfile2 : str | None = None ,
400406 interactive : bool = True ,
401407) -> go .Figure :
402- """Plots a single run or comparison of two runs."""
408+ """Plots a single run or comparison of two runs from output files."""
409+ fig_title = plot_config .figure_title or _get_title_from_paths (
410+ outfile , outfile2
411+ )
412+
403413 outfile = _get_file_path (outfile )
404414 outfile2 = _get_file_path (outfile2 ) if outfile2 else None
405415
406- plotdata1 = load_data (outfile )
407- plotdata2 = load_data (outfile2 ) if outfile2 else None
416+ data_tree = output .load_state_file (outfile )
417+ data_tree2 = output .load_state_file (outfile2 ) if outfile2 else None
418+ return plot_run_from_data_tree (
419+ plot_config , data_tree , data_tree2 , interactive , fig_title
420+ )
421+
408422
409- # Prepare list of datasets to check, associating them with their filenames
410- # for clearer errors
411- datasets_to_check = [(plotdata1 , outfile )]
423+ def plot_run_from_data_tree (
424+ plot_config : FigureProperties ,
425+ data_tree : xr .DataTree ,
426+ data_tree2 : xr .DataTree | None = None ,
427+ interactive : bool = True ,
428+ fig_title : str = 'Torax Simulation Results' ,
429+ ) -> go .Figure :
430+ """Plots a single run or comparison of two runs from in-memory DataTrees."""
431+ plotdata1 = _data_tree_to_plot_data (data_tree )
432+ plotdata2 = _data_tree_to_plot_data (data_tree2 ) if data_tree2 else None
433+
434+ datasets_to_check = [(plotdata1 , 'data_tree' )]
412435 if plotdata2 is not None :
413- datasets_to_check .append ((plotdata2 , outfile2 ))
436+ datasets_to_check .append ((plotdata2 , 'data_tree2' ))
414437
415438 for plotdata , filename in datasets_to_check :
416- # Get the set of valid keys for this specific dataset
417439 available_vars = plotdata .available_variables ()
418-
419440 for cfg in plot_config .axes :
420441 for attr in cfg .attrs :
421442 if attr not in available_vars :
@@ -424,8 +445,7 @@ def plot_run(
424445 f'output file: { filename } '
425446 )
426447
427- title = plot_config .figure_title or _get_title_from_paths (outfile , outfile2 )
428- fig = create_plotly_figure (plot_config , plotdata1 , plotdata2 , title )
448+ fig = create_plotly_figure (plot_config , plotdata1 , plotdata2 , fig_title )
429449 if interactive :
430450 fig .show ()
431451
0 commit comments