1111
1212from unyt import unyt_quantity , unyt_array , matplotlib_support
1313from unyt .exceptions import UnitConversionError
14- from numpy import log10 , linspace , logspace , array , logical_and
14+ from numpy import log10 , linspace , logspace , array , logical_and , ones
1515from matplotlib .pyplot import Axes , Figure , close
1616from yaml import safe_load
1717from typing import Union , List , Dict , Tuple
@@ -95,6 +95,8 @@ class VelociraptorPlot(object):
9595 observational_data_filenames : List [str ]
9696 observational_data_bracket_width : float
9797 observational_data_directory : str
98+ # global mask
99+ global_mask : Union [None , array ]
98100
99101 def __init__ (
100102 self ,
@@ -743,7 +745,7 @@ def _add_lines_to_axes(self, ax: Axes, x: unyt_array, y: unyt_array) -> None:
743745 return
744746
745747 def get_quantity_from_catalogue_with_mask (
746- self , quantity : str , catalogue : VelociraptorCatalogue
748+ self , quantity : str , catalogue : VelociraptorCatalogue ,
747749 ) -> unyt_array :
748750 """
749751 Get a quantity from the catalogue using the mask.
@@ -753,62 +755,48 @@ def get_quantity_from_catalogue_with_mask(
753755 # We give each dataset a custom name, that gets ruined when masking
754756 # in versions of unyt less than 2.6.0
755757 name = x .name
756-
758+
757759 if self .structure_mask is not None :
758- x = x [self .structure_mask ]
760+ # if structure_mask already set, mask and return
761+ x_mask = logical_and (self .global_mask , self .structure_mask )
762+ x = x [x_mask ]
759763 x .name = name
760- elif self .selection_mask is not None :
764+ return x
765+
766+ # allow all entries by default
767+ self .structure_mask = ones (x .shape ).astype (bool )
768+
769+ if self .selection_mask is not None :
761770 # Create mask
762771 self .structure_mask = reduce (
763772 getattr , self .selection_mask .split ("." ), catalogue
764773 ).astype (bool )
765-
766- if self .select_structure_type is not None :
767- if self .select_structure_type == self .exclude_structure_type :
768- raise AutoPlotterError (
769- f"Cannot simultaneously select and exclude structure"
770- " type {self.select_structure_type}"
771- )
772- self .structure_mask = logical_and (
773- self .structure_mask ,
774- catalogue .structure_type .structuretype
775- == self .select_structure_type ,
776- )
777-
778- elif self .exclude_structure_type is not None :
779- self .structure_mask = logical_and (
780- self .structure_mask ,
781- catalogue .structure_type .structuretype
782- != self .exclude_structure_type ,
783- )
784-
785- x = x [self .structure_mask ]
786- x .name = name
787- elif self .select_structure_type is not None :
774+ if self .select_structure_type is not None :
788775 if self .select_structure_type == self .exclude_structure_type :
789776 raise AutoPlotterError (
790777 f"Cannot simultaneously select and exclude structure"
791778 " type {self.select_structure_type}"
792779 )
793-
794- # Need to create mask
795- self . structure_mask = (
796- catalogue . structure_type . structuretype == self .select_structure_type
780+ self . structure_mask = logical_and (
781+ self . structure_mask ,
782+ catalogue . structure_type . structuretype
783+ == self .select_structure_type ,
797784 )
798-
799- x = x [self .structure_mask ]
800- x .name = name
801- elif self .exclude_structure_type is not None :
802- # Need to create mask
803- self .structure_mask = (
804- catalogue .structure_type .structuretype != self .exclude_structure_type
785+ if self .exclude_structure_type is not None :
786+ self .structure_mask = logical_and (
787+ self .structure_mask ,
788+ catalogue .structure_type .structuretype
789+ != self .exclude_structure_type ,
805790 )
791+
792+ # combine global and structure masks
793+ x_mask = logical_and (self .global_mask , self .structure_mask )
806794
807- x = x [ self . structure_mask ]
808- x . name = name
809-
795+ # apply to the unyt array of values
796+ x = x [ x_mask ]
797+ x . name = name
810798 return x
811-
799+
812800 def _make_plot_scatter (
813801 self , catalogue : VelociraptorCatalogue
814802 ) -> Tuple [Figure , Axes ]:
@@ -974,7 +962,7 @@ def _make_plot_cumulative_histogram(
974962 return fig , ax
975963
976964 def make_plot (
977- self , catalogue : VelociraptorCatalogue , directory : str , file_extension : str
965+ self , catalogue : VelociraptorCatalogue , directory : str , file_extension : str ,
978966 ):
979967 """
980968 Federates out data parsing to individual functions based on the
@@ -1058,7 +1046,9 @@ class AutoPlotter(object):
10581046 observational_data_directory : str
10591047 # Whether or not the plots were created successfully.
10601048 created_successfully : List [bool ]
1061-
1049+ # global mask
1050+ global_mask : Union [None , array ]
1051+
10621052 def __init__ (
10631053 self ,
10641054 filename : Union [str , List [str ]],
@@ -1123,14 +1113,18 @@ def parse_yaml(self):
11231113
11241114 return
11251115
1126- def link_catalogue (self , catalogue : VelociraptorCatalogue ):
1116+ def link_catalogue (self , catalogue : VelociraptorCatalogue , global_mask_tag : Union [ None , str ] ):
11271117 """
11281118 Links a catalogue with this object so that the plots
11291119 can actually be created.
11301120 """
11311121
11321122 self .catalogue = catalogue
11331123
1124+ if global_mask_tag is not None :
1125+ self .global_mask = reduce (getattr , global_mask_tag .split ("." ), catalogue )
1126+ else :
1127+ self .global_mask = True
11341128 return
11351129
11361130 def create_plots (
@@ -1150,6 +1144,7 @@ def create_plots(
11501144
11511145 for plot in self .plots :
11521146 try :
1147+ plot .global_mask = self .global_mask
11531148 plot .make_plot (
11541149 catalogue = self .catalogue ,
11551150 directory = directory ,
0 commit comments