@@ -3210,52 +3210,56 @@ def test_mean_default_dtype(self):
32103210 # TODO FIXME: This is a bad test
32113211 f (data )
32123212
3213- @pytest .mark .slow
3214- def test_mean_custom_dtype (self ):
3213+ @pytest .mark .parametrize (
3214+ "input_dtype" ,
3215+ (
3216+ "bool" ,
3217+ "uint16" ,
3218+ "int8" ,
3219+ "int64" ,
3220+ "float16" ,
3221+ "float32" ,
3222+ "float64" ,
3223+ "complex64" ,
3224+ "complex128" ,
3225+ ),
3226+ )
3227+ @pytest .mark .parametrize (
3228+ "sum_dtype" ,
3229+ (
3230+ "bool" ,
3231+ "uint16" ,
3232+ "int8" ,
3233+ "int64" ,
3234+ "float16" ,
3235+ "float32" ,
3236+ "float64" ,
3237+ "complex64" ,
3238+ "complex128" ,
3239+ ),
3240+ )
3241+ @pytest .mark .parametrize ("axis" , [None , ()])
3242+ def test_mean_custom_dtype (self , input_dtype , sum_dtype , axis ):
32153243 # Test the ability to provide your own output dtype for a mean.
32163244
3217- # We try multiple axis combinations even though axis should not matter.
3218- axes = [None , 0 , 1 , [], [0 ], [1 ], [0 , 1 ]]
3219- idx = 0
3220- for input_dtype in map (str , ps .all_types ):
3221- x = matrix (dtype = input_dtype )
3222- for sum_dtype in map (str , ps .all_types ):
3223- axis = axes [idx % len (axes )]
3224- # If the inner sum cannot be created, it will raise a
3225- # TypeError.
3226- try :
3227- mean_var = x .mean (dtype = sum_dtype , axis = axis )
3228- except TypeError :
3229- pass
3230- else :
3231- # Executed if no TypeError was raised
3232- if sum_dtype in discrete_dtypes :
3233- assert mean_var .dtype == "float64" , (mean_var .dtype , sum_dtype )
3234- else :
3235- assert mean_var .dtype == sum_dtype , (mean_var .dtype , sum_dtype )
3236- if (
3237- "complex" in input_dtype or "complex" in sum_dtype
3238- ) and input_dtype != sum_dtype :
3239- continue
3240- f = function ([x ], mean_var )
3241- data = np .random .random ((3 , 4 )) * 10
3242- data = data .astype (input_dtype )
3243- # TODO FIXME: This is a bad test
3244- f (data )
3245- # Check that we can take the gradient, when implemented
3246- if "complex" in mean_var .dtype :
3247- continue
3248- try :
3249- grad (mean_var .sum (), x , disconnected_inputs = "ignore" )
3250- except NotImplementedError :
3251- # TrueDiv does not seem to have a gradient when
3252- # the numerator is complex.
3253- if mean_var .dtype in complex_dtypes :
3254- pass
3255- else :
3256- raise
3245+ x = matrix (dtype = input_dtype )
3246+ # If the inner sum cannot be created, it will raise a TypeError.
3247+ mean_var = x .mean (dtype = sum_dtype , axis = axis )
3248+ if sum_dtype in discrete_dtypes :
3249+ assert mean_var .dtype == "float64" , (mean_var .dtype , sum_dtype )
3250+ else :
3251+ assert mean_var .dtype == sum_dtype , (mean_var .dtype , sum_dtype )
32573252
3258- idx += 1
3253+ f = function ([x ], mean_var , mode = "FAST_COMPILE" )
3254+ data = np .ones ((2 , 1 )).astype (input_dtype )
3255+ if axis != ():
3256+ expected_res = np .array (2 ).astype (sum_dtype ) / 2
3257+ else :
3258+ expected_res = data
3259+ np .testing .assert_allclose (f (data ), expected_res )
3260+
3261+ if "complex" not in mean_var .dtype :
3262+ grad (mean_var .sum (), x , disconnected_inputs = "ignore" )
32593263
32603264 def test_mean_precision (self ):
32613265 # Check that the default accumulator precision is sufficient
0 commit comments