Skip to content

Commit 6620f90

Browse files
skdeepmindTorax team
authored andcommitted
Plot with datatree w/o creating .nc file
PiperOrigin-RevId: 884592463
1 parent e5df787 commit 6620f90

File tree

4 files changed

+120
-68
lines changed

4 files changed

+120
-68
lines changed

docs/plotting.rst

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,16 @@ plot layout and content. Dataclass fields and defaults are as follows:
175175
- ``cols`` (int): Number of columns in the figure.
176176
- ``axes`` (tuple of ``PlotProperties``): Configuration for each subplot.
177177
See below.
178-
- ``figure_size_factor`` (float=5.0): Scaling factor for the figure size.
179-
- ``tick_fontsize`` (int=10): Font size for axis ticks.
180-
- ``axes_fontsize`` (int=10): Font size for axis labels.
181-
- ``title_fontsize`` (int=16): Font size for the figure title.
182-
- ``default_legend_fontsize`` (int=10): Default font size for legends.
183-
- ``colors`` (tuple[str, ...] = ('r', 'b', 'g', 'm', 'y', 'c')): Colors to use
184-
for plot lines. Cycles through the tuple for multiple lines.
178+
- ``font_family`` (str = 'Arial, sans-serif'): Font family for all text in the figure.
179+
- ``title_size`` (int = 16): Font size for the main figure title.
180+
- ``subplot_title_size`` (int = 12): Font size for subplot titles.
181+
- ``tick_size`` (int = 8): Font size for axis ticks.
182+
- ``nticks_time`` (int = 6): Number of x-axis ticks for the time series plots.
183+
- ``tickvals_rho`` (tuple[float, ...] = (0, 0.2, 0.4, 0.6, 0.8, 1.0)): Values of x-axis ticks for the spatial plots.
184+
- ``height`` (int | None = None): Height of the figure in pixels. If None, autosize is used.
185+
- ``legend_spacing`` (int = 10): Spacing between legend entries.
186+
- ``margin`` (dict[str, int] = {'l': 40, 'r': 40, 't': 80, 'b': 40}): Margin around the figure. A dict with keys l,
187+
r, t and b for left, right, top and bottom margins respectively.
185188

186189
The ``PlotProperties`` dataclass configures individual subplots. For example,
187190
the ``PlotProperties`` object for plotting ion and electron temperatures looks

docs/running_programmatically.rst

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,30 @@ We can then run the simulation:
3333
3434
# Example below shows how to access the fusion gain at time=2 seconds.
3535
Q_fusion_t2 = data_tree.scalars.Q_fusion.sel(time=2, method='nearest')
36+
37+
Plotting from an in-memory simulation
38+
######################################
39+
40+
If you have already run a simulation and have a ``data_tree`` in memory, you can
41+
plot it directly without saving to a file first using
42+
``torax.plot_run_from_data_tree``:
43+
44+
.. code-block:: python
45+
46+
plot_config = torax.import_module('plotting/configs/default_plot_config.py')['PLOT_CONFIG']
47+
48+
# Plot directly from the in-memory data_tree returned by run_simulation.
49+
fig = torax.plot_run_from_data_tree(plot_config, data_tree)
50+
51+
To compare two in-memory runs:
52+
53+
.. code-block:: python
54+
55+
fig = torax.plot_run_from_data_tree(plot_config, data_tree, data_tree2)
56+
57+
If you have saved the output to a ``.nc`` file and want to plot from disk,
58+
use ``torax.plot_run`` instead:
59+
60+
.. code-block:: python
61+
62+
fig = torax.plot_run(plot_config, PATH_TO_LOCAL_NC_FILE)

torax/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from torax._src.output_tools.post_processing import PostProcessedOutputs
3333
from torax._src.pedestal_model.pedestal_model_output import PedestalModelOutput
3434
from torax._src.plotting.plotruns_lib import plot_run
35+
from torax._src.plotting.plotruns_lib import plot_run_from_data_tree
3536
from torax._src.sources.source_profiles import SourceProfiles
3637
from torax._src.state import CoreProfiles
3738
from torax._src.state import CoreTransport
@@ -58,6 +59,7 @@
5859
'build_torax_config_from_file',
5960
'import_module',
6061
'plot_run',
62+
'plot_run_from_data_tree',
6163
'run_simulation',
6264
'CoreProfiles',
6365
'CoreTransport',

torax/_src/plotting/plotruns_lib.py

Lines changed: 81 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
"""Utilities for plotting outputs of Torax runs.
1616
1717
Public API:
18-
plot_run: Main entry point. Loads data and returns a plotly Figure.
18+
plot_run: Main entry point. Loads data from file and returns a plotly Figure.
19+
plot_run_from_data_tree: Plots from an in-memory xr.DataTree.
1920
PlotData: Data container exposing all output variables as attributes.
2021
FigureProperties: Configuration for the overall figure layout.
2122
PlotProperties: Configuration for an individual subplot.
@@ -315,62 +316,67 @@ def available_variables(self) -> set[str]:
315316
return attrs
316317

317318

318-
def load_data(filename: str) -> PlotData:
319-
"""Loads an xr.Dataset from a file, handling coordinate name changes."""
320-
321-
data_tree = output.load_state_file(filename)
319+
def _transform_data(ds: xr.Dataset) -> xr.Dataset:
320+
"""Transforms dataset variables to plotting units."""
321+
# TODO(b/414755419)
322+
ds = ds.copy()
323+
324+
transformations = {
325+
output.J_TOROIDAL_TOTAL: 1e6, # A/m^2 to MA/m^2
326+
output.J_TOROIDAL_OHMIC: 1e6, # A/m^2 to MA/m^2
327+
output.J_TOROIDAL_BOOTSTRAP: 1e6, # A/m^2 to MA/m^2
328+
output.J_TOROIDAL_EXTERNAL: 1e6, # A/m^2 to MA/m^2
329+
'j_generic_current': 1e6, # A/m^2 to MA/m^2
330+
output.I_BOOTSTRAP: 1e6, # A to MA
331+
output.IP_PROFILE: 1e6, # A to MA
332+
output.IP: 1e6, # A to MA
333+
'j_ecrh': 1e6, # A/m^2 to MA/m^2
334+
'p_icrh_i': 1e6, # W/m^3 to MW/m^3
335+
'p_icrh_e': 1e6, # W/m^3 to MW/m^3
336+
'p_generic_heat_i': 1e6, # W/m^3 to MW/m^3
337+
'p_generic_heat_e': 1e6, # W/m^3 to MW/m^3
338+
'p_ecrh_e': 1e6, # W/m^3 to MW/m^3
339+
'p_alpha_i': 1e6, # W/m^3 to MW/m^3
340+
'p_alpha_e': 1e6, # W/m^3 to MW/m^3
341+
'p_ohmic_e': 1e6, # W/m^3 to MW/m^3
342+
'p_bremsstrahlung_e': 1e6, # W/m^3 to MW/m^3
343+
'p_cyclotron_radiation_e': 1e6, # W/m^3 to MW/m^3
344+
'p_impurity_radiation_e': 1e6, # W/m^3 to MW/m^3
345+
'ei_exchange': 1e6, # W/m^3 to MW/m^3
346+
'P_ohmic_e': 1e6, # W to MW
347+
'P_aux_total': 1e6, # W to MW
348+
'P_alpha_total': 1e6, # W to MW
349+
'P_bremsstrahlung_e': 1e6, # W to MW
350+
'P_cyclotron_e': 1e6, # W to MW
351+
'P_ecrh': 1e6, # W to MW
352+
'P_radiation_e': 1e6, # W to MW
353+
'I_ecrh': 1e6, # A to MA
354+
'I_aux_generic': 1e6, # A to MA
355+
'W_thermal_total': 1e6, # J to MJ
356+
output.N_E: 1e20, # m^-3 to 10^{20} m^-3
357+
output.N_I: 1e20, # m^-3 to 10^{20} m^-3
358+
output.N_IMPURITY: 1e20, # m^-3 to 10^{20} m^-3
359+
}
322360

323-
def _transform_data(ds: xr.Dataset):
324-
"""Transforms data in-place to the desired units."""
325-
# TODO(b/414755419)
326-
ds = ds.copy()
327-
328-
transformations = {
329-
output.J_TOROIDAL_TOTAL: 1e6, # A/m^2 to MA/m^2
330-
output.J_TOROIDAL_OHMIC: 1e6, # A/m^2 to MA/m^2
331-
output.J_TOROIDAL_BOOTSTRAP: 1e6, # A/m^2 to MA/m^2
332-
output.J_TOROIDAL_EXTERNAL: 1e6, # A/m^2 to MA/m^2
333-
'j_generic_current': 1e6, # A/m^2 to MA/m^2
334-
output.I_BOOTSTRAP: 1e6, # A to MA
335-
output.IP_PROFILE: 1e6, # A to MA
336-
output.IP: 1e6, # A to MA
337-
'j_ecrh': 1e6, # A/m^2 to MA/m^2
338-
'p_icrh_i': 1e6, # W/m^3 to MW/m^3
339-
'p_icrh_e': 1e6, # W/m^3 to MW/m^3
340-
'p_generic_heat_i': 1e6, # W/m^3 to MW/m^3
341-
'p_generic_heat_e': 1e6, # W/m^3 to MW/m^3
342-
'p_ecrh_e': 1e6, # W/m^3 to MW/m^3
343-
'p_alpha_i': 1e6, # W/m^3 to MW/m^3
344-
'p_alpha_e': 1e6, # W/m^3 to MW/m^3
345-
'p_ohmic_e': 1e6, # W/m^3 to MW/m^3
346-
'p_bremsstrahlung_e': 1e6, # W/m^3 to MW/m^3
347-
'p_cyclotron_radiation_e': 1e6, # W/m^3 to MW/m^3
348-
'p_impurity_radiation_e': 1e6, # W/m^3 to MW/m^3
349-
'ei_exchange': 1e6, # W/m^3 to MW/m^3
350-
'P_ohmic_e': 1e6, # W to MW
351-
'P_aux_total': 1e6, # W to MW
352-
'P_alpha_total': 1e6, # W to MW
353-
'P_bremsstrahlung_e': 1e6, # W to MW
354-
'P_cyclotron_e': 1e6, # W to MW
355-
'P_ecrh': 1e6, # W to MW
356-
'P_radiation_e': 1e6, # W to MW
357-
'I_ecrh': 1e6, # A to MA
358-
'I_aux_generic': 1e6, # A to MA
359-
'W_thermal_total': 1e6, # J to MJ
360-
output.N_E: 1e20, # m^-3 to 10^{20} m^-3
361-
output.N_I: 1e20, # m^-3 to 10^{20} m^-3
362-
output.N_IMPURITY: 1e20, # m^-3 to 10^{20} m^-3
363-
}
361+
for var_name, scale in transformations.items():
362+
if var_name in ds:
363+
ds[var_name] /= scale
364364

365-
for var_name, scale in transformations.items():
366-
if var_name in ds:
367-
ds[var_name] /= scale
365+
return ds
368366

369-
return ds
370367

368+
def _data_tree_to_plot_data(data_tree: xr.DataTree) -> PlotData:
369+
"""Converts an xr.DataTree to a PlotData object with unit transformations."""
371370
return PlotData(xr.map_over_datasets(_transform_data, data_tree))
372371

373372

373+
def load_data(filename: str) -> PlotData:
374+
"""Loads an xr.Dataset from a file, handling coordinate name changes."""
375+
376+
data_tree = output.load_state_file(filename)
377+
return _data_tree_to_plot_data(data_tree)
378+
379+
374380
def _get_file_path(outfile: str) -> str:
375381
"""Gets the absolute path to the file."""
376382
possible_paths = [outfile]
@@ -399,23 +405,38 @@ def plot_run(
399405
outfile2: str | None = None,
400406
interactive: bool = True,
401407
) -> go.Figure:
402-
"""Plots a single run or comparison of two runs."""
408+
"""Plots a single run or comparison of two runs from output files."""
409+
fig_title = plot_config.figure_title or _get_title_from_paths(
410+
outfile, outfile2
411+
)
412+
403413
outfile = _get_file_path(outfile)
404414
outfile2 = _get_file_path(outfile2) if outfile2 else None
405415

406-
plotdata1 = load_data(outfile)
407-
plotdata2 = load_data(outfile2) if outfile2 else None
416+
data_tree = output.load_state_file(outfile)
417+
data_tree2 = output.load_state_file(outfile2) if outfile2 else None
418+
return plot_run_from_data_tree(
419+
plot_config, data_tree, data_tree2, interactive, fig_title
420+
)
421+
408422

409-
# Prepare list of datasets to check, associating them with their filenames
410-
# for clearer errors
411-
datasets_to_check = [(plotdata1, outfile)]
423+
def plot_run_from_data_tree(
424+
plot_config: FigureProperties,
425+
data_tree: xr.DataTree,
426+
data_tree2: xr.DataTree | None = None,
427+
interactive: bool = True,
428+
fig_title: str = 'Torax Simulation Results',
429+
) -> go.Figure:
430+
"""Plots a single run or comparison of two runs from in-memory DataTrees."""
431+
plotdata1 = _data_tree_to_plot_data(data_tree)
432+
plotdata2 = _data_tree_to_plot_data(data_tree2) if data_tree2 else None
433+
434+
datasets_to_check = [(plotdata1, 'data_tree')]
412435
if plotdata2 is not None:
413-
datasets_to_check.append((plotdata2, outfile2))
436+
datasets_to_check.append((plotdata2, 'data_tree2'))
414437

415438
for plotdata, filename in datasets_to_check:
416-
# Get the set of valid keys for this specific dataset
417439
available_vars = plotdata.available_variables()
418-
419440
for cfg in plot_config.axes:
420441
for attr in cfg.attrs:
421442
if attr not in available_vars:
@@ -424,8 +445,7 @@ def plot_run(
424445
f'output file: {filename}'
425446
)
426447

427-
title = plot_config.figure_title or _get_title_from_paths(outfile, outfile2)
428-
fig = create_plotly_figure(plot_config, plotdata1, plotdata2, title)
448+
fig = create_plotly_figure(plot_config, plotdata1, plotdata2, fig_title)
429449
if interactive:
430450
fig.show()
431451

0 commit comments

Comments
 (0)