| 
15 | 15 |     ("mch", True),  | 
16 | 16 | ]  | 
17 | 17 | 
 
  | 
 | 18 | +arg_names_multistep = ("source", "len_timesteps")  | 
 | 19 | +arg_values_multistep = [  | 
 | 20 | +    ("mch", 6),  | 
 | 21 | +]  | 
 | 22 | + | 
 | 23 | + | 
 | 24 | +@pytest.mark.parametrize(arg_names_multistep, arg_values_multistep)  | 
 | 25 | +def test_tracking_tdating_dating_multistep(source, len_timesteps):  | 
 | 26 | +    pytest.importorskip("skimage")  | 
 | 27 | + | 
 | 28 | +    input_fields, metadata = get_precipitation_fields(  | 
 | 29 | +        0, len_timesteps, True, True, 4000, source  | 
 | 30 | +    )  | 
 | 31 | +    input_fields, __ = to_reflectivity(input_fields, metadata)  | 
 | 32 | + | 
 | 33 | +    timelist = metadata["timestamps"]  | 
 | 34 | + | 
 | 35 | +    # First half of timesteps  | 
 | 36 | +    tracks_1, cells, labels = dating(  | 
 | 37 | +        input_fields[0 : len_timesteps // 2],  | 
 | 38 | +        timelist[0 : len_timesteps // 2],  | 
 | 39 | +        mintrack=1,  | 
 | 40 | +    )  | 
 | 41 | +    # Second half of timesteps  | 
 | 42 | +    tracks_2, cells, _ = dating(  | 
 | 43 | +        input_fields[len_timesteps // 2 - 2 :],  | 
 | 44 | +        timelist[len_timesteps // 2 - 2 :],  | 
 | 45 | +        mintrack=1,  | 
 | 46 | +        start=2,  | 
 | 47 | +        cell_list=cells,  | 
 | 48 | +        label_list=labels,  | 
 | 49 | +    )  | 
 | 50 | + | 
 | 51 | +    # Since we are adding cells, number of tracks should increase  | 
 | 52 | +    assert len(tracks_1) <= len(tracks_2)  | 
 | 53 | + | 
 | 54 | +    # Tracks should be continuous in time so time difference should not exceed timestep  | 
 | 55 | +    max_track_step = max([t.time.diff().max().seconds for t in tracks_2 if len(t) > 1])  | 
 | 56 | +    timestep = np.diff(timelist).max().seconds  | 
 | 57 | +    assert max_track_step <= timestep  | 
 | 58 | + | 
 | 59 | +    # IDs of unmatched cells should increase in every timestep  | 
 | 60 | +    for prev_df, cur_df in zip(cells[:-1], cells[1:]):  | 
 | 61 | +        prev_ids = set(prev_df.ID)  | 
 | 62 | +        cur_ids = set(cur_df.ID)  | 
 | 63 | +        new_ids = list(cur_ids - prev_ids)  | 
 | 64 | +        prev_unmatched = list(prev_ids - cur_ids)  | 
 | 65 | +        if len(prev_unmatched):  | 
 | 66 | +            assert np.all(np.array(new_ids) > max(prev_unmatched))  | 
 | 67 | + | 
18 | 68 | 
 
  | 
19 | 69 | @pytest.mark.parametrize(arg_names, arg_values)  | 
20 | 70 | def test_tracking_tdating_dating(source, dry_input):  | 
 | 
0 commit comments