@@ -372,3 +372,31 @@ def test_backward_compat_mixed(self):
372372 assert '7.0*f(x + 3*h_x)' in str (eqe .rhs )
373373 assert '0.5*g(x + h_x)' in str (eqe .rhs )
374374 assert 'g(x + 2*h_x)' not in str (eqe .rhs )
375+
376+ def test_backward_compat_array_of_func (self ):
377+ grid = Grid (shape = (11 , 11 , 11 ))
378+ x , _ , _ = grid .dimensions
379+ hx = x .spacing
380+
381+ f = Function (name = 'f' , grid = grid , space_order = 16 , coefficients = 'symbolic' )
382+
383+ # Define stencil coefficients.
384+ weights = Function (name = "w" , space_order = 0 , shape = (9 ,), dimensions = (x ,))
385+ wdx = [weights [0 ]]
386+ for iq in range (1 , weights .shape [0 ]):
387+ wdx .append (weights [iq ])
388+ wdx .insert (0 , weights [iq ])
389+
390+ # Plain numbers for comparison
391+ wdxn = np .random .rand (17 )
392+
393+ # Number with spacing
394+ wdxns = wdxn / hx
395+
396+ dexev = f .dx (weights = wdx ).evaluate
397+ dexevn = f .dx (weights = wdxn ).evaluate
398+ dexevns = f .dx (weights = wdxns ).evaluate
399+
400+ assert all (a .as_coefficient (1 / hx ) for a in dexevn .args )
401+ assert all (a .as_coefficient (1 / hx ) for a in dexevns .args )
402+ assert all (not a .as_coefficient (1 / hx ) for a in dexev .args )
0 commit comments