Skip to content

Commit 83e360e

Browse files
committed
update pick dicts into arguments style
1 parent 685682c commit 83e360e

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

eqnet/data/seismic_trace.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,8 @@ def __init__(
398398
min_snr=3.0,
399399
stack_event=False,
400400
stack_strategy='same_sensor',
401+
picks_dict=None,
402+
events_dict=None,
401403
stack_noise=False,
402404
flip_polarity=False,
403405
drop_channel=False,
@@ -503,8 +505,10 @@ def __init__(
503505
elif self.data_path is not None:
504506
self.base_dir = data_path
505507

506-
self.picks_dict = pd.read_csv(os.path.join(self.base_dir, "picks_train.csv"), usecols=['event_id', 'station_id', 'snr', 'phase_status', 'instrument'])
507-
self.events_dict = pd.read_csv(os.path.join(self.base_dir, "events_train.csv"), usecols=['event_id', 'magnitude', 'event_time', 'depth_km'])
508+
picks_dict = os.path.join(self.base_dir, "picks_train.csv") if (picks_dict is None) else picks_dict
509+
events_dict = os.path.join(self.base_dir, "events_train.csv") if (events_dict is None) else events_dict
510+
self.picks_dict = pd.read_csv(picks_dict, usecols=['event_id', 'station_id', 'snr', 'phase_status', 'instrument'])
511+
self.events_dict = pd.read_csv(events_dict, usecols=['event_id', 'magnitude', 'event_time', 'depth_km'])
508512
self.events_dict['year'] = self.events_dict['event_time'].apply(lambda x: int(x[:4]))
509513
self.picks_dict['snr'] = self.picks_dict['snr'].apply(lambda x: np.array([float(number) for number in x.split()[1:-1]]))
510514
temp = self.picks_dict.groupby('event_id')['snr'].apply(lambda x: np.concatenate(x.values)).reset_index()

train.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,8 @@ def main(args):
379379
format=args.format, # ["h5", "hf"]
380380
training=True,
381381
stack_event=args.stack_event,
382+
picks_dict=args.picks_dict,
383+
events_dict=args.events_dict,
382384
stack_noise=args.stack_noise,
383385
flip_polarity=args.flip_polarity,
384386
drop_channel=args.drop_channel,
@@ -716,6 +718,8 @@ def get_args_parser(add_help=True):
716718
parser.add_argument("--test-label-list", default="+", type=None, help="test label path")
717719
parser.add_argument("--test-noise-list", default="+", type=None, help="test noise list")
718720
parser.add_argument("--test-hdf5-file", default=None, type=str, help="hdf5 file for testing")
721+
parser.add_argument("--picks-dict", default=None, type=str, help="picks dictionary for training augmentation")
722+
parser.add_argument("--events-dict", default=None, type=str, help="events dictionary for training augmentation")
719723
parser.add_argument("--dataset", default="", type=str, help="dataset name")
720724
parser.add_argument("--model", default="phasenet_das", type=str, help="model name")
721725
parser.add_argument("--backbone", default="unet", type=str, help="model backbone")

0 commit comments

Comments
 (0)