11# -*- coding: utf-8 -*-
22
3- import numpy as np
43import datetime
4+
5+ import numpy as np
56import pytest
6- import pysteps
7- from pysteps import cascade , blending
87
8+ import pysteps
9+ from pysteps import blending , cascade
910
1011steps_arg_values = [
1112 (1 , 3 , 4 , 8 , None , None , False , "spn" , True , 4 , False , False , 0 , False ),
1415 (1 , 3 , 4 , 8 , None , "mean" , False , "spn" , True , 4 , False , False , 0 , False ),
1516 (1 , 3 , 4 , 8 , None , "mean" , False , "spn" , True , 4 , False , False , 0 , True ),
1617 (1 , 3 , 4 , 8 , None , "cdf" , False , "spn" , True , 4 , False , False , 0 , False ),
18+ (1 , [1 , 2 , 3 ], 4 , 8 , None , "cdf" , False , "spn" , True , 4 , False , False , 0 , False ),
1719 (1 , 3 , 4 , 8 , "incremental" , "cdf" , False , "spn" , True , 4 , False , False , 0 , False ),
1820 (1 , 3 , 4 , 6 , "incremental" , "cdf" , False , "bps" , True , 4 , False , False , 0 , False ),
1921 (1 , 3 , 4 , 6 , "incremental" , "cdf" , False , "bps" , False , 4 , False , False , 0 , False ),
4244 (1 , 3 , 6 , 8 , None , None , False , "spn" , True , 6 , False , True , 80 , False ),
4345 (5 , 3 , 5 , 6 , "incremental" , "cdf" , False , "spn" , False , 5 , True , False , 80 , True ),
4446 (5 , 3 , 5 , 6 , "obs" , "mean" , False , "spn" , False , 5 , True , True , 80 , False ),
47+ (5 , [1 , 2 , 3 ], 5 , 6 , "obs" , "mean" , False , "spn" , False , 5 , True , True , 80 , False ),
48+ (5 , [1 , 3 ], 5 , 6 , "obs" , "mean" , False , "spn" , False , 5 , True , True , 80 , False ),
4549]
4650
4751steps_arg_names = (
4852 "n_models" ,
49- "n_timesteps " ,
53+ "timesteps " ,
5054 "n_ens_members" ,
5155 "n_cascade_levels" ,
5256 "mask_method" ,
6569@pytest .mark .parametrize (steps_arg_names , steps_arg_values )
6670def test_steps_blending (
6771 n_models ,
68- n_timesteps ,
72+ timesteps ,
6973 n_ens_members ,
7074 n_cascade_levels ,
7175 mask_method ,
@@ -85,7 +89,14 @@ def test_steps_blending(
8589 # The input data
8690 ###
8791 # Initialise dummy NWP data
88- nwp_precip = np .zeros ((n_models , n_timesteps + 1 , 200 , 200 ))
92+ if not isinstance (timesteps , int ):
93+ n_timesteps = len (timesteps )
94+ last_timestep = timesteps [- 1 ]
95+ else :
96+ n_timesteps = timesteps
97+ last_timestep = timesteps
98+
99+ nwp_precip = np .zeros ((n_models , last_timestep + 1 , 200 , 200 ))
89100
90101 if not zero_nwp :
91102 for n_model in range (n_models ):
@@ -250,7 +261,7 @@ def test_steps_blending(
250261 precip_models = nwp_precip_decomp ,
251262 velocity = radar_velocity ,
252263 velocity_models = nwp_velocity ,
253- timesteps = n_timesteps ,
264+ timesteps = timesteps ,
254265 timestep = 5.0 ,
255266 issuetime = datetime .datetime .strptime ("202112012355" , "%Y%m%d%H%M" ),
256267 n_ens_members = n_ens_members ,
0 commit comments