diff --git a/.gitignore b/.gitignore
index bfc34a2..8cb055b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
 *.egg-info
 doc/_build/
 .cache
+*.pyc
+.coverage
+.pytest_cache
diff --git a/pyasdf/asdf_data_set.py b/pyasdf/asdf_data_set.py
index 9f9e716..394f823 100644
--- a/pyasdf/asdf_data_set.py
+++ b/pyasdf/asdf_data_set.py
@@ -37,7 +37,6 @@
 import prov
 import prov.model
 
-
 # Minimum compatibility wrapper between Python 2 and 3.
 try:
     filter = itertools.ifilter
@@ -87,6 +86,7 @@ class ASDFDataSet(object):
     def __init__(self, filename, compression="gzip-3", shuffle=True,
                  debug=False, mpi=None, mode="a",
                  single_item_read_limit_in_mb=1024.0,
+                 chunk_size=None,
                  format_version=None):
         """
         :type filename: str
@@ -134,6 +134,8 @@ def __init__(self, filename, compression="gzip-3", shuffle=True,
             the interactive command line when just exploring an ASDF data
             set. There are other ways to still access data and even this
             setting can be overwritten.
+        :param chunk_size: Dataset chunk size in seconds. Set to ``True``
+            for h5py auto-chunking; ``None`` (the default) disables chunking.
+        :type chunk_size: float, bool, or None
         :type format_version: str
         :type format_version: The version of ASDF to use. If not given, it
             will use the most recent version (currently 1.0.1) if the
@@ -144,6 +146,10 @@ def __init__(self, filename, compression="gzip-3", shuffle=True,
                 "ASDF version '%s' is not supported. Supported versions: %s"
                 % (format_version, ", ".join(SUPPORTED_FORMAT_VERSIONS)))
 
+        # Dataset chunk size in seconds. Set to True for auto-chunking;
+        # None for no chunking.
+        self.chunk_size = chunk_size
+
         self.__force_mpi = mpi
         self.debug = debug
 
@@ -246,6 +252,8 @@ def __init__(self, filename, compression="gzip-3", shuffle=True,
             self.__file.create_group("Provenance")
         if "AuxiliaryData" not in self.__file and mode != "r":
             self.__file.create_group("AuxiliaryData")
+        # Check the mode first: indexing "AuxiliaryData" on a read-only
+        # file that lacks the group would raise a KeyError.
+        if mode != "r" and "References" not in self.__file["AuxiliaryData"]:
+            self.__file.create_group("AuxiliaryData/References")
 
         # Easy access to the waveforms.
         self.waveforms = StationAccessor(self)
@@ -362,6 +370,10 @@ def _provenance_group(self):
     def _auxiliary_data_group(self):
         return self.__file["AuxiliaryData"]
 
+    @property
+    def _reference_group(self):
+        return self.__file["AuxiliaryData/References"]
+
     @property
     def asdf_format_version_in_file(self):
         """
@@ -806,13 +818,25 @@ def _get_waveform(self, waveform_name, starttime=None, endtime=None):
                    self.single_item_read_limit_in_mb))
             raise ASDFValueError(msg)
 
+        return self.__extract_waveform(waveform_name, idx_start, idx_end)
+
+    def __extract_waveform(self, waveform_name, idx_start, idx_end):
         network, station, location, channel = waveform_name.split(".")[:4]
         channel = channel[:channel.find("__")]
         data = self.__file["Waveforms"]["%s.%s" % (network, station)][
             waveform_name]
 
-        tr = obspy.Trace(data=data[idx_start: idx_end])
-        tr.stats.starttime = data_starttime
+        if "mask" in data.attrs and data.attrs["mask"] != np.bool_(False):
+            _data = np.ma.masked_values(data[idx_start: idx_end],
+                                        data.attrs["mask"])
+        else:
+            _data = data[idx_start: idx_end]
+
+        tr = obspy.Trace(data=_data)
+        # The stored starttime is that of the dataset's first sample, so
+        # shift it by the slice offset.
+        tr.stats.starttime = obspy.UTCDateTime(
+            data.attrs["starttime"] * 1.0E-9) \
+            + idx_start / data.attrs["sampling_rate"]
         tr.stats.sampling_rate = data.attrs["sampling_rate"]
         tr.stats.network = network
         tr.stats.station = station
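For orientation, a minimal usage sketch of the new ``chunk_size`` argument (the file name and demo data below are placeholders, not part of this change):

.. code-block:: python

    import obspy
    from pyasdf import ASDFDataSet

    # Request 60-second HDF5 chunks for every waveform dataset;
    # chunk_size=True would ask h5py to auto-chunk, and chunk_size=None
    # (the default) keeps the current storage layout.
    ds = ASDFDataSet("example.h5", chunk_size=60.0)

    st = obspy.read()  # three short demo traces shipped with ObsPy
    ds.add_waveforms(st, tag="raw")

    # A per-call value overrides the data-set-wide default.
    ds.add_waveforms(st, tag="auto_chunked", chunk_size=True)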
@@ -1054,7 +1078,8 @@ def _get_dataset_within_tolerance(station_group, trace):
         # If this did not work - append.
         self.add_waveforms(waveform=trace, tag=tag)
 
-    def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,
+    def add_waveforms(self, waveform, tag, chunk_size=None,
+                      event_id=None, origin_id=None,
                       magnitude_id=None, focal_mechanism_id=None,
                       provenance_id=None, labels=None):
         """
@@ -1068,6 +1093,10 @@ def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,
             mandatory for all traces and facilitates identification of the
             data within one ASDF volume. The ``"raw_record"`` path is, by
             convention, reserved to raw, recorded, unprocessed data.
+        :param chunk_size: Dataset chunk size in seconds. This overrides
+            the default value specified at object instantiation.
+        :type chunk_size: float, bool, or None
         :type tag: str
         :param event_id: The event or id which the waveform is associated
             with. This is useful for recorded data if a clear association is
@@ -1165,13 +1194,19 @@ def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,
         tag = self.__parse_and_validate_tag(tag)
         waveform = self.__parse_waveform_input_and_validate(waveform)
 
+        # A per-call chunk size overrides the data-set-wide default.
+        if chunk_size is None:
+            chunk_size = self.chunk_size
+
         # Actually add the data.
         for trace in waveform:
+            if isinstance(trace.data, np.ma.masked_array):
+                self.__set_masked_array_fill_value(trace)
             # Complicated multi-step process but it enables one to use
             # parallel I/O with the same functions.
             info = self._add_trace_get_collective_information(
-                trace, tag, event_id=event_id, origin_id=origin_id,
-                magnitude_id=magnitude_id,
+                trace, tag, chunk_size=chunk_size, event_id=event_id,
+                origin_id=origin_id, magnitude_id=magnitude_id,
                 focal_mechanism_id=focal_mechanism_id,
                 provenance_id=provenance_id, labels=labels)
             if info is None:
@@ -1179,6 +1214,16 @@ def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,
             self._add_trace_write_collective_information(info)
             self._add_trace_write_independent_information(info, trace)
 
+    def __set_masked_array_fill_value(self, trace):
+        # Use the dtype's minimum as the fill value so that gaps can be
+        # reconstructed on read via the "mask" dataset attribute.
+        if trace.data.dtype.kind in ("i", "u"):
+            _info = np.iinfo
+        elif trace.data.dtype.kind == "f":
+            _info = np.finfo
+        else:
+            raise NotImplementedError("no fill value defined for dtype %s"
+                                      % trace.data.dtype)
+        trace.data.set_fill_value(_info(trace.data.dtype).min)
+
     def __parse_and_validate_tag(self, tag):
         tag = tag.strip()
         if tag.lower() == "stationxml":
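The fill-value logic above amounts to the following for a gappy trace; a small self-contained sketch (the array values are illustrative only):

.. code-block:: python

    import numpy as np
    import obspy

    # A trace with a masked gap, as produced e.g. by Stream.merge().
    data = np.ma.masked_array(np.arange(10, dtype=np.int32),
                              mask=[False] * 4 + [True] * 2 + [False] * 4)
    tr = obspy.Trace(data=data)

    # Mirrors __set_masked_array_fill_value: iinfo for integer dtypes,
    # finfo for float dtypes, and the dtype minimum as the fill value.
    info = np.iinfo if tr.data.dtype.kind in ("i", "u") else np.finfo
    tr.data.set_fill_value(info(tr.data.dtype).min)

    # On write the gap is stored as the fill value, and the fill value is
    # recorded in the dataset's "mask" attribute so that the gap can be
    # reconstructed on read with np.ma.masked_values().
    assert np.ma.filled(tr.data)[4] == np.iinfo(np.int32).min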
@@ -1225,6 +1270,314 @@ def __parse_waveform_input_and_validate(self, waveform):
             raise NotImplementedError
         return waveform
 
+    def create_reference(self, ref, starttime, endtime, net=None, sta=None,
+                         loc=None, chan=None, tag=None, overwrite=False):
+        """
+        Creates a reference for fast lookup of data segments.
+
+        :param ref: The reference label to apply.
+        :type ref: str
+        :param starttime: Start time of reference.
+        :type starttime: :class:`obspy.core.utcdatetime.UTCDateTime`
+        :param endtime: End time of reference.
+        :type endtime: :class:`obspy.core.utcdatetime.UTCDateTime`
+        :param net: Networks to create references for.
+        :type net: str, tuple
+        :param sta: Stations to create references for.
+        :type sta: str, tuple
+        :param loc: Location codes to create references for.
+        :type loc: str, tuple
+        :param chan: Channels to create references for.
+        :type chan: str, tuple
+        :param tag: Tags to create references for.
+        :type tag: str, tuple
+        :param overwrite: Overwrite existing references for this label.
+        :type overwrite: bool
+
+        This method is useful for creating subsets of a data set without
+        duplicating waveform data.
+
+        .. rubric:: Example
+
+        Consider an ASDFDataSet populated with continuous waveforms for
+        stations from two networks (AA and BB):
+
+        - AA.XXX
+        - AA.YYY
+        - AA.ZZZ
+        - BB.UUU
+        - BB.VVV
+        - BB.WWW
+
+        Suppose we want to work with event-segmented waveforms and only
+        need a one-minute window of data per channel. We can create
+        references to these windowed data segments for fast extraction:
+
+        .. code-block:: python
+
+            >>> ds.create_reference("event000001",
+            ...                     obspy.UTCDateTime("2016-01-01T01:00:00"),
+            ...                     obspy.UTCDateTime("2016-01-01T01:01:00"))
+            >>> ds.get_data_for_reference("event000001")
+            18 Trace(s) in Stream:
+            AA.XXX..HHZ | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (16 other traces)
+            ...
+            BB.WWW..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Or perhaps we only want to include data from network AA in the
+        referenced data set:
+
+        .. code-block:: python
+
+            >>> ds.create_reference("event000001",
+            ...                     obspy.UTCDateTime("2016-01-01T01:00:00"),
+            ...                     obspy.UTCDateTime("2016-01-01T01:01:00"),
+            ...                     net="AA",
+            ...                     overwrite=True)
+            >>> ds.get_data_for_reference("event000001")
+            9 Trace(s) in Stream:
+            AA.XXX..HHZ | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (7 other traces)
+            ...
+            AA.ZZZ..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Or only horizontal component data:
+
+        .. code-block:: python
+
+            >>> ds.create_reference("event000001",
+            ...                     obspy.UTCDateTime("2016-01-01T01:00:00"),
+            ...                     obspy.UTCDateTime("2016-01-01T01:01:00"),
+            ...                     chan=("HHN", "HHE"),
+            ...                     overwrite=True)
+            >>> ds.get_data_for_reference("event000001")
+            12 Trace(s) in Stream:
+            AA.XXX..HHN | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (10 other traces)
+            ...
+            BB.WWW..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Any combination of the filter arguments may be used.
+ """ + + def _coerce(obj): + if isinstance(obj, str): + obj = (obj,) + elif isinstance(obj, tuple)\ + or isinstance(obj, list)\ + or obj is None: + pass + else: + try: + if isinstance(obj, unicode): + obj = (obj,) + else: + raise(TypeError(obj)) + except NameError: + pass + return(obj) + + net = _coerce(net) + sta = _coerce(sta) + loc = _coerce(loc) + chan = _coerce(chan) + tag = _coerce(tag) + + def _predicate_net(_key): + return(net is None or _key.split(".")[0] in net) + + def _predicate_sta(_key): + return(sta is None or _key.split(".")[1] in sta) + + def _predicate_loc(_key): + return(loc is None or _key.split(".")[2] in loc) + + def _predicate_chan(_key): + return(chan is None or _key.split(".")[-1].split("__")[0] in chan) + + def _predicate_tag(_key): + return(tag is None or _key.split(".")[-1].split("__")[-1] in tag) + + def _predicate_netsta(_key): + return(_predicate_net(_key) and _predicate_sta(_key)) + + def _predicate_locchantag(_key): + return(_predicate_loc(_key) + and _predicate_chan(_key) + and _predicate_tag(_key)) + + _wf_grp = self._waveform_group + for _station_name in filter(_predicate_netsta, + self._waveform_group.keys()): + for waveform_name in filter(_predicate_locchantag, + _wf_grp[_station_name].keys()): + _net, _sta, _loc, _chan = waveform_name.split("__")[0].split(".") + + _ds = self._waveform_group["%s/%s" % (_station_name, waveform_name)] + + _ts = obspy.UTCDateTime(_ds.attrs["starttime"]*1e-9) + _samprate = _ds.attrs["sampling_rate"] + _te = _ts + len(_ds)/_samprate + if _te < starttime or _ts > endtime: + continue + + _offset = int((starttime-_ts)*_samprate) + _nsamp = int(round((endtime-starttime)*_samprate, 0)) + idx_start = _offset if _offset >= 0 else 0 + idx_end = _offset + _nsamp + + if ref not in self._reference_group: + _ref_grp = self._reference_group.create_group(ref) + else: + _ref_grp = self._reference_group[ref] + + _net = "__" if _net == "" else _net + _sta = "__" if _sta == "" else _sta + _loc = "__" if _loc == "" else _loc + _chan = "__" if _chan == "" else _chan + _handle = "/".join((_net, _sta, _loc, _chan)) + + if overwrite is True and _handle in _ref_grp: + del(_ref_grp[_handle]) + + if _handle not in _ref_grp: + _ref = _ref_grp.create_dataset(_handle, + (2,), + dtype=np.int64) + _ref.attrs["waveform_name"] = waveform_name + _ref.attrs["sampling_rate"] = _ds.attrs["sampling_rate"] + _ts = _ds.attrs["starttime"] + int(_offset/_samprate*1.e9) + _ref.attrs["starttime"] = _ts + _ref[:] = [idx_start, idx_end] + else: + print("Will not overwrite existing reference") + continue + + def get_data_for_reference(self, ref, net=None, sta=None, loc=None, + chan=None): + """ + Retrieve referenced data. + + :param ref: Reference label. + :type ref: str + :param net: Networks to retrieve referenced data for. + :type net: str, tuple + :param sta: Stations to retrieve referenced data for. + :type sta: str, tuple + :param loc: Location codes to retrieve referenced data for. + :type loc: str, tuple + :param chan: Channels to retrieve referenced data for. + :type chan: str, tuple + :returns: Referenced data. + :rtype: :class:`~obspy.core.stream.Stream` + + .. rubric:: Example + + Consider an ASDFDataSet with references pointing to event-segmented + waveforms (see :func:`create_reference`). We can retrieve data + for a particular reference label: + + .. code-block:: python + + >>> ds.get_data_for_reference("event000001") + 18 Trace(s) in Stream: + AA.XXX..HHZ | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples + ... + (16 other traces) + ... 
+    def get_data_for_reference(self, ref, net=None, sta=None, loc=None,
+                               chan=None):
+        """
+        Retrieve referenced data.
+
+        :param ref: Reference label.
+        :type ref: str
+        :param net: Networks to retrieve referenced data for.
+        :type net: str, tuple
+        :param sta: Stations to retrieve referenced data for.
+        :type sta: str, tuple
+        :param loc: Location codes to retrieve referenced data for.
+        :type loc: str, tuple
+        :param chan: Channels to retrieve referenced data for.
+        :type chan: str, tuple
+        :returns: Referenced data.
+        :rtype: :class:`~obspy.core.stream.Stream`
+
+        .. rubric:: Example
+
+        Consider an ASDFDataSet with references pointing to event-segmented
+        waveforms (see :meth:`create_reference`). We can retrieve data for
+        a particular reference label:
+
+        .. code-block:: python
+
+            >>> ds.get_data_for_reference("event000001")
+            18 Trace(s) in Stream:
+            AA.XXX..HHZ | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (16 other traces)
+            ...
+            BB.WWW..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Or for only the BB network:
+
+        .. code-block:: python
+
+            >>> ds.get_data_for_reference("event000001",
+            ...                           net="BB")
+            9 Trace(s) in Stream:
+            BB.UUU..HHZ | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (7 other traces)
+            ...
+            BB.WWW..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Or for only horizontal components:
+
+        .. code-block:: python
+
+            >>> ds.get_data_for_reference("event000001",
+            ...                           chan=("HHN", "HHE"))
+            12 Trace(s) in Stream:
+            AA.XXX..HHN | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+            ...
+            (10 other traces)
+            ...
+            BB.WWW..HHE | 2016-01-01T01:00:00.00Z ... | 100.0 Hz, 6001 samples
+
+        Any combination of the filter arguments may be used.
+        """
+        if ref not in self._reference_group:
+            raise IOError("reference does not exist: %s" % ref)
+
+        def _coerce(obj):
+            # Normalize a single code to a one-element tuple; pass
+            # tuples, lists, and None through unchanged; reject
+            # everything else.
+            if obj is None or isinstance(obj, (tuple, list)):
+                return obj
+            if isinstance(obj, str):
+                return (obj,)
+            try:
+                if isinstance(obj, unicode):  # Python 2 only.
+                    return (obj,)
+            except NameError:
+                pass
+            raise TypeError(obj)
+
+        net = _coerce(net)
+        sta = _coerce(sta)
+        loc = _coerce(loc)
+        chan = _coerce(chan)
+
+        def _predicate_net(_key):
+            return net is None or _key in net
+
+        def _predicate_sta(_key):
+            return sta is None or _key in sta
+
+        def _predicate_loc(_key):
+            return loc is None or _key in loc
+
+        def _predicate_chan(_key):
+            return chan is None or _key in chan
+
+        _st = obspy.Stream()
+        _ref_grp = self._reference_group[ref]
+        for _net in filter(_predicate_net, _ref_grp.keys()):
+            _net_grp = _ref_grp[_net]
+
+            for _sta in filter(_predicate_sta, _net_grp.keys()):
+                _sta_grp = _net_grp[_sta]
+
+                for _loc in filter(_predicate_loc, _sta_grp.keys()):
+                    _loc_grp = _sta_grp[_loc]
+
+                    for _chan in filter(_predicate_chan, _loc_grp.keys()):
+                        _ref = _loc_grp[_chan]
+
+                        waveform_name = _ref.attrs["waveform_name"]
+                        idx_start, idx_end = _ref[:]
+                        _tr = self.__extract_waveform(waveform_name,
+                                                      idx_start,
+                                                      idx_end)
+                        _st.append(_tr)
+        return _st
+
     def get_provenance_document(self, document_name):
         """
         Retrieve a provenance document with a certain name.
@@ -1295,7 +1648,7 @@ def _add_trace_write_independent_information(self, info, trace):
         :param trace:
         :return:
         """
-        self._waveform_group[info["data_name"]][:] = trace.data
+        self._waveform_group[info["data_name"]][:] = np.ma.filled(trace.data)
 
     def _add_trace_write_collective_information(self, info):
         """
@@ -1330,7 +1683,7 @@ def __get_waveform_ds_name(self, net, sta, loc, cha, start, end, tag):
             tag=tag)
 
     def _add_trace_get_collective_information(
-            self, trace, tag, event_id=None, origin_id=None,
+            self, trace, tag, chunk_size=None, event_id=None, origin_id=None,
             magnitude_id=None, focal_mechanism_id=None,
             provenance_id=None, labels=None):
         """
@@ -1354,6 +1707,11 @@ def _add_trace_get_collective_information(
             loc=trace.stats.location, cha=trace.stats.channel,
             start=trace.stats.starttime, end=trace.stats.endtime, tag=tag)
 
+        # Convert the chunk size from seconds to samples; None and True
+        # are passed straight through to h5py.
+        if chunk_size is None or chunk_size is True:
+            chunks = chunk_size
+        else:
+            chunks = (int(round(chunk_size * trace.stats.sampling_rate)),)
+
         group_name = "%s/%s" % (station_name, data_name)
         if group_name in self._waveform_group:
             msg = "Data '%s' already exists in file. Will not be added!" % \
@@ -1367,6 +1725,12 @@ def _add_trace_get_collective_information(
         else:
             fletcher32 = True
 
+        # Determine appropriate mask value.
+        if not isinstance(trace.data, np.ma.masked_array):
+            _mask = np.bool_(False)
+        else:
+            _mask = trace.data.fill_value
+
         info = {
             "station_name": station_name,
             "data_name": group_name,
@@ -1378,13 +1742,15 @@ def _add_trace_get_collective_information(
                 "compression_opts": self.__compression[1],
                 "shuffle": self.__shuffle,
                 "fletcher32": fletcher32,
-                "maxshape": (None,)
+                "maxshape": (None,),
+                "chunks": chunks
             },
             "dataset_attrs": {
                 # Starttime is the epoch time in nanoseconds.
                 "starttime": int(round(trace.stats.starttime.timestamp *
                                        1.0E9)),
-                "sampling_rate": trace.stats.sampling_rate
+                "sampling_rate": trace.stats.sampling_rate,
+                "mask": _mask
             }
         }
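For reference, the seconds-to-samples conversion performed above is simply (the numbers are illustrative):

.. code-block:: python

    # chunk_size is given in seconds, but HDF5 chunk shapes are in
    # samples: 60 s of 40 Hz data maps to one chunk of 2400 samples.
    chunk_size = 60.0
    sampling_rate = 40.0
    chunks = (int(round(chunk_size * sampling_rate)),)
    assert chunks == (2400,)

    # chunk_size=None or chunk_size=True is passed straight through to
    # h5py's create_dataset(..., chunks=...): None keeps the default
    # layout, True requests h5py auto-chunking.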
diff --git a/pyasdf/tests/test_asdf_data_set.py b/pyasdf/tests/test_asdf_data_set.py
index a9a59cb..6831a89 100644
--- a/pyasdf/tests/test_asdf_data_set.py
+++ b/pyasdf/tests/test_asdf_data_set.py
@@ -87,6 +87,56 @@ def test_waveform_tags_attribute(tmpdir):
     assert data_set.waveform_tags == expected
 
 
+def test_reference_creation(tmpdir):
+    asdf_filename = os.path.join(tmpdir.strpath, "test.h5")
+    data_path = os.path.join(data_dir, "small_sample_data_set")
+
+    data_set = ASDFDataSet(asdf_filename)
+
+    for filename in glob.glob(os.path.join(data_path, "*.mseed")):
+        data_set.add_waveforms(filename, tag="raw")
+
+    data_set.create_reference("ref1",
+                              obspy.UTCDateTime("2013-05-24T05:50:00"),
+                              obspy.UTCDateTime("2013-05-24T05:55:00"),
+                              net="AE")
+    st = data_set.get_data_for_reference("ref1")
+    assert len(st) == 3
+    for tr in st:
+        assert tr.stats.network == "AE"
+
+    data_set.create_reference("ref2",
+                              obspy.UTCDateTime("2013-05-24T05:50:00"),
+                              obspy.UTCDateTime("2013-05-24T05:55:00"),
+                              chan="BHZ")
+    st = data_set.get_data_for_reference("ref2")
+    assert len(st) == 2
+    for tr in st:
+        assert tr.stats.channel == "BHZ"
+
+    data_set.create_reference("ref3",
+                              obspy.UTCDateTime("2013-05-24T05:50:00"),
+                              obspy.UTCDateTime("2013-05-24T05:55:00"),
+                              chan=("BHN", "BHE"))
+    st = data_set.get_data_for_reference("ref3")
+    assert len(st) == 4
+    for tr in st:
+        assert tr.stats.channel in ("BHN", "BHE")
+
+    data_set.create_reference("ref4",
+                              obspy.UTCDateTime("2013-05-24T05:50:00"),
+                              obspy.UTCDateTime("2013-05-24T05:55:00"),
+                              net="TA",
+                              sta=("POKR",),
+                              chan="BHZ")
+    st = data_set.get_data_for_reference("ref4")
+    assert len(st) == 1
+    for tr in st:
+        assert tr.stats.network == "TA"
+        assert tr.stats.station == "POKR"
+        assert tr.stats.channel == "BHZ"
+
+
 def test_data_set_creation(tmpdir):
     """
     Test data set creation with a small test dataset.
@@ -144,6 +194,60 @@ def test_data_set_creation(tmpdir):
     assert cat_file == cat_asdf
 
 
+def test_masked_data_creation(tmpdir):
+    asdf_filename = os.path.join(tmpdir.strpath, "test.h5")
+    data_path = os.path.join(data_dir, "small_sample_data_set")
+
+    data_set = ASDFDataSet(asdf_filename)
+
+    filename = os.path.join(data_path, "AE.113A..BHZ.mseed")
+
+    ts1 = obspy.UTCDateTime("2013-05-24T05:40:00")
+    te1 = obspy.UTCDateTime("2013-05-24T06:00:00")
+    ts2 = obspy.UTCDateTime("2013-05-24T06:10:00")
+    te2 = obspy.UTCDateTime("2013-05-24T06:50:00")
+
+    st_file_raw = obspy.read(filename)
+
+    st_file_masked = st_file_raw.copy().trim(starttime=ts1, endtime=te1) \
+        + st_file_raw.copy().trim(starttime=ts2, endtime=te2)
+    st_file_masked.merge()
+
+    # Filtering will cast the dtype from int to float.
+    st_file_masked_filtered = st_file_masked.copy()
+    st_file_masked_filtered = st_file_masked_filtered.split()
+    st_file_masked_filtered.filter("bandpass", freqmin=0.1, freqmax=10)
+    st_file_masked_filtered.merge()
+
+    data_set.add_waveforms(st_file_masked, tag="masked")
+    data_set.add_waveforms(st_file_masked_filtered, tag="masked_filtered")
+
+    st_asdf_masked = data_set.waveforms["AE.113A"]["masked"]
+    st_asdf_masked_filtered = data_set.waveforms["AE.113A"]["masked_filtered"]
+
+    trfm = st_file_masked[0]
+    trfmf = st_file_masked_filtered[0]
+    tram = st_asdf_masked[0]
+    tramf = st_asdf_masked_filtered[0]
+
+    for tr in (trfm, trfmf):
+        del tr.stats.mseed
+        del tr.stats._format
+        del tr.stats.processing
+
+    for tr in (tram, tramf):
+        del tr.stats.asdf
+        del tr.stats._format
+
+    assert trfm.stats == tram.stats
+    assert all(trfm.data.mask == tram.data.mask)
+    assert all(trfm.data[~trfm.data.mask] == tram.data[~tram.data.mask])
+
+    assert trfmf.stats == tramf.stats
+    assert all(trfmf.data.mask == tramf.data.mask)
+    assert all(trfmf.data[~trfmf.data.mask] == tramf.data[~tramf.data.mask])
+
+
 def test_equality_checks(example_data_set):
     """
     Tests the equality operations.
@@ -3020,6 +3124,7 @@ def test_get_waveform_attributes(example_data_set):
             'event_ids': [
                 'smi:service.iris.edu/fdsnws/event/1/query?'
                 'eventid=4218658'],
+            'mask': np.bool_(False),
             'sampling_rate': 40.0,
             'starttime': 1369374000000000000},
         'AE.113A..BHN__2013-05-24T05:40:00__'
@@ -3027,6 +3132,7 @@ def test_get_waveform_attributes(example_data_set):
             'event_ids': [
                 'smi:service.iris.edu/fdsnws/event/1/query?'
                 'eventid=4218658'],
+            'mask': np.bool_(False),
             'sampling_rate': 40.0,
             'starttime': 1369374000000000000},
         'AE.113A..BHZ__2013-05-24T05:40:00__'
@@ -3034,6 +3140,7 @@ def test_get_waveform_attributes(example_data_set):
             'event_ids': [
                 'smi:service.iris.edu/fdsnws/event/1/query?'
                 'eventid=4218658'],
+            'mask': np.bool_(False),
             'sampling_rate': 40.0,
             'starttime': 1369374000000000000}
     }
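Taken together, a hedged end-to-end sketch of the masked-data round trip these tests exercise (``roundtrip.h5`` is a placeholder file name; ``BW.RJOB`` is the station id of ObsPy's bundled demo data):

.. code-block:: python

    import numpy as np
    import obspy
    from pyasdf import ASDFDataSet

    # Two copies of the demo trace separated by a ten-second gap;
    # merging them yields a trace backed by a masked array.
    tr1 = obspy.read()[0]
    tr2 = tr1.copy()
    tr2.stats.starttime = tr1.stats.endtime + 10.0
    st = obspy.Stream([tr1, tr2]).merge()
    assert isinstance(st[0].data, np.ma.MaskedArray)

    ds = ASDFDataSet("roundtrip.h5")
    ds.add_waveforms(st, tag="gappy")

    # The gap survives the round trip through the file.
    st_back = ds.waveforms["BW.RJOB"]["gappy"]
    assert isinstance(st_back[0].data, np.ma.MaskedArray)
    assert (st_back[0].data.mask == st[0].data.mask).all()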