@@ -427,7 +427,6 @@ def display_gradients_simply(
427427 shot_ids : tuple [int , ...] = (0 ,),
428428 figsize : float = 5 ,
429429 fill_area : bool = True ,
430- show_signal : bool = True ,
431430 uni_signal : str | None = "gray" ,
432431 uni_gradient : str | None = None ,
433432 subfigure : plt .Figure | None = None ,
@@ -447,10 +446,6 @@ def display_gradients_simply(
447446 Fills the area under the curve for improved visibility and
448447 representation of the integral, aka trajectory.
449448 The default is `True`.
450- show_signal : bool, optional
451- Show an additional illustration of the signal as
452- the modulated distance to the center.
453- The default is `True`.
454449 uni_signal : str or None, optional
455450 Define whether the signal should be represented by a
456451 unique color given as argument or just by the default
@@ -471,13 +466,13 @@ def display_gradients_simply(
471466 Axes of the figure.
472467 """
473468 # Setup figure and labels
474- Nd = trajectory .shape [- 1 ]
469+ nb_axes = trajectory .shape [- 1 ] + 1
475470 if subfigure is None :
476- fig = plt .figure (figsize = (figsize , figsize * ( Nd + show_signal ) / Nd ))
471+ fig = plt .figure (figsize = (figsize , figsize * nb_axes / ( nb_axes - 1 ) ))
477472 else :
478473 fig = subfigure
479- axes = fig .subplots (Nd + show_signal , 1 )
480- for i , ax in enumerate (axes [:Nd ]):
474+ axes = fig .subplots (nb_axes , 1 )
475+ for i , ax in enumerate (axes [:nb_axes - 1 ]):
481476 ax .set_ylabel ("G{}" .format (["x" , "y" , "z" ][i ]), fontsize = displayConfig .fontsize )
482477 axes [- 1 ].set_xlabel ("Time" , fontsize = displayConfig .fontsize )
483478
@@ -489,50 +484,42 @@ def display_gradients_simply(
489484
490485 # Plot the curves for each axis
491486 gradients = np .diff (trajectory , axis = 1 )
492- vmax = 1.1 * np .max (np .abs (gradients [shot_ids , ...]))
493- x_axis = np .arange (gradients .shape [1 ])
487+ vmax = 1.1 * np .max (np .linalg .norm (gradients [shot_ids , ...], axis = - 1 , ord = 1 ))
488+ for ax in axes [:- 1 ]:
489+ ax .set_ylim ((- vmax , vmax ))
490+ axes [- 1 ].set_ylim (- 0.1 * vmax , vmax )
491+
492+ time_axis = np .arange (gradients .shape [1 ])
494493 colors = displayConfig .get_colorlist ()
495494 for j , s_id in enumerate (shot_ids ):
496- for i , ax in enumerate (axes [:Nd ]):
497- ax .set_ylim ((- vmax , vmax ))
498- color = (
499- uni_gradient
500- if uni_gradient is not None
501- else colors [j % displayConfig .nb_colors ]
502- )
503- ax .plot (x_axis , gradients [s_id , ..., i ], color = color )
495+ color = (
496+ uni_gradient
497+ if uni_gradient is not None
498+ else colors [j % displayConfig .nb_colors ]
499+ )
500+
501+ # Set each axis individually
502+ for i , ax in enumerate (axes [:- 1 ]):
503+ ax .plot (time_axis , gradients [s_id , ..., i ], color = color )
504504 if fill_area :
505505 ax .fill_between (
506- x_axis ,
506+ time_axis ,
507507 gradients [s_id , ..., i ],
508508 alpha = displayConfig .alpha ,
509509 color = color ,
510510 )
511511
512- # Return axes alone
513- if not show_signal :
514- return axes
515-
516- # Show signal as modulated distance to center
517- distances = np .linalg .norm (trajectory [shot_ids , 1 :- 1 ], axis = - 1 )
518- distances = np .tile (distances .reshape ((len (shot_ids ), - 1 , 1 )), (1 , 1 , 10 ))
519- signal = 1 - distances .reshape ((len (shot_ids ), - 1 )) / np .max (distances )
520- signal = (
521- signal * np .exp (2j * np .pi * figsize / 100 * np .arange (signal .shape [1 ]))
522- ).real
523- signal = signal * np .abs (signal ) ** 3
524-
525- colors = displayConfig .get_colorlist ()
526- # Show signal for each requested shot
527- axes [- 1 ].set_ylim ((- 1 , 1 ))
528- axes [- 1 ].set_ylabel ("Signal" , fontsize = displayConfig .fontsize )
529- for j in range (len (shot_ids )):
530- color = (
531- uni_signal
532- if (uni_signal is not None )
533- else colors [j % displayConfig .nb_colors ]
534- )
535- axes [- 1 ].plot (np .arange (signal .shape [1 ]), signal [j ], color = color )
512+ # Set the norm axis if requested
513+ gradient_norm = np .linalg .norm (gradients [s_id ], axis = - 1 )
514+ axes [- 1 ].set_ylabel ("|G|" , fontsize = displayConfig .fontsize )
515+ axes [- 1 ].plot (gradient_norm , color = color )
516+ if fill_area :
517+ axes [- 1 ].fill_between (
518+ time_axis ,
519+ gradient_norm ,
520+ alpha = displayConfig .alpha ,
521+ color = color ,
522+ )
536523 return axes
537524
538525
@@ -541,7 +528,7 @@ def display_gradients(
541528 shot_ids : tuple [int , ...] = (0 ,),
542529 figsize : float = 5 ,
543530 fill_area : bool = True ,
544- show_signal : bool = True ,
531+ show_norm : bool = True ,
545532 uni_signal : str | None = "gray" ,
546533 uni_gradient : str | None = None ,
547534 subfigure : plt .Figure | plt .Axes | None = None ,
@@ -567,7 +554,7 @@ def display_gradients(
567554 Fills the area under the curve for improved visibility and
568555 representation of the integral, aka trajectory.
569556 The default is `True`.
570- show_signal : bool, optional
557+ show_norm : bool, optional
571558 Show an additional illustration of the signal as
572559 the modulated distance to the center.
573560 The default is `True`.
@@ -619,7 +606,7 @@ def display_gradients(
619606 shot_ids ,
620607 figsize ,
621608 fill_area ,
622- show_signal ,
609+ show_norm ,
623610 uni_signal ,
624611 uni_gradient ,
625612 subfigure ,
@@ -633,7 +620,7 @@ def display_gradients(
633620 fontsize = displayConfig .small_fontsize ,
634621 )
635622 axes [- 1 ].set_xlabel ("Time (ms)" , fontsize = displayConfig .small_fontsize )
636- if show_signal :
623+ if show_norm :
637624 axes [- 1 ].set_ylabel ("Signal (a.u.)" , fontsize = displayConfig .small_fontsize )
638625
639626 # Update axis ticks with rescaled values
@@ -642,7 +629,7 @@ def display_gradients(
642629 if ax == axes [- 1 ]:
643630 ax .xaxis .set_tick_params (labelbottom = True )
644631 ticks = ax .get_xticks ()
645- scale = (0.1 if (show_signal and ax == axes [- 1 ]) else 1 ) * raster_time
632+ scale = (0.1 if (show_norm and ax == axes [- 1 ]) else 1 ) * raster_time
646633 locator = mticker .FixedLocator (ticks )
647634 formatter = mticker .FixedFormatter (np .around (scale * ticks , 2 ))
648635 ax .xaxis .set_major_locator (locator )
@@ -663,7 +650,7 @@ def display_gradients(
663650 scale = 1e3 * scale # Convert from T/m to mT/m
664651 locator = mticker .FixedLocator (ticks )
665652 formatter = mticker .FixedFormatter (np .around (scale * ticks , 1 ))
666- if not show_signal or ax != axes [- 1 ]:
653+ if not show_norm or ax != axes [- 1 ]:
667654 ax .yaxis .set_major_locator (locator )
668655 ax .yaxis .set_major_formatter (formatter )
669656
0 commit comments