@@ -520,23 +520,24 @@ def test_definition(self, xp):
520520
521521 # default dtype varies across backends
522522
523- y = 9 * fft .fftfreq (9 , xp = xp )
523+ wrapped_xp = array_namespace (x )
524+ y = 9 * fft .fftfreq (9 , xp = wrapped_xp )
524525 xp_assert_close (y , x , check_dtype = False , check_namespace = True )
525526
526- y = 9 * xp .pi * fft .fftfreq (9 , xp .pi , xp = xp )
527+ y = 9 * xp .pi * fft .fftfreq (9 , xp .pi , xp = wrapped_xp )
527528 xp_assert_close (y , x , check_dtype = False )
528529
529- y = 10 * fft .fftfreq (10 , xp = xp )
530+ y = 10 * fft .fftfreq (10 , xp = wrapped_xp )
530531 xp_assert_close (y , x2 , check_dtype = False )
531532
532- y = 10 * xp .pi * fft .fftfreq (10 , xp .pi , xp = xp )
533+ y = 10 * xp .pi * fft .fftfreq (10 , xp .pi , xp = wrapped_xp )
533534 xp_assert_close (y , x2 , check_dtype = False )
534535
535536 def test_device (self , xp ):
536537 xp_test = array_namespace (xp .empty (0 ))
537538 devices = get_xp_devices (xp )
538539 for d in devices :
539- y = fft .fftfreq (9 , xp = xp , device = d )
540+ y = fft .fftfreq (9 , xp = xp_test , device = d )
540541 x = xp_test .empty (0 , device = d )
541542 assert xp_device (y ) == xp_device (x )
542543
@@ -552,23 +553,23 @@ def test_definition(self, xp):
552553 x2 = xp .asarray ([0 , 1 , 2 , 3 , 4 , 5 ], dtype = xp .float64 )
553554
554555 # default dtype varies across backends
555-
556- y = 9 * fft .rfftfreq (9 , xp = xp )
556+ wrapped_xp = array_namespace ( x )
557+ y = 9 * fft .rfftfreq (9 , xp = wrapped_xp )
557558 xp_assert_close (y , x , check_dtype = False , check_namespace = True )
558559
559- y = 9 * xp .pi * fft .rfftfreq (9 , xp .pi , xp = xp )
560+ y = 9 * xp .pi * fft .rfftfreq (9 , xp .pi , xp = wrapped_xp )
560561 xp_assert_close (y , x , check_dtype = False )
561562
562- y = 10 * fft .rfftfreq (10 , xp = xp )
563+ y = 10 * fft .rfftfreq (10 , xp = wrapped_xp )
563564 xp_assert_close (y , x2 , check_dtype = False )
564565
565- y = 10 * xp .pi * fft .rfftfreq (10 , xp .pi , xp = xp )
566+ y = 10 * xp .pi * fft .rfftfreq (10 , xp .pi , xp = wrapped_xp )
566567 xp_assert_close (y , x2 , check_dtype = False )
567568
568569 def test_device (self , xp ):
569570 xp_test = array_namespace (xp .empty (0 ))
570571 devices = get_xp_devices (xp )
571572 for d in devices :
572- y = fft .rfftfreq (9 , xp = xp , device = d )
573+ y = fft .rfftfreq (9 , xp = xp_test , device = d )
573574 x = xp_test .empty (0 , device = d )
574575 assert xp_device (y ) == xp_device (x )
0 commit comments