Skip to content

Commit 7292277

Browse files
Add grid interpolation support to Function class with from_grid() method
Co-authored-by: Gui-FernandesBR <[email protected]>
1 parent 170e89c commit 7292277

File tree

2 files changed

+407
-2
lines changed

2 files changed

+407
-2
lines changed

rocketpy/mathutils/function.py

Lines changed: 256 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
LinearNDInterpolator,
2323
NearestNDInterpolator,
2424
RBFInterpolator,
25+
RegularGridInterpolator,
2526
)
2627

2728
from rocketpy.plots.plot_helpers import show_or_save_plot
@@ -43,6 +44,7 @@
4344
"spline": 3,
4445
"shepard": 4,
4546
"rbf": 5,
47+
"linear_grid": 6,
4648
}
4749
EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2}
4850

@@ -449,6 +451,37 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab
449451

450452
self._interpolation_func = rbf_interpolation
451453

454+
elif interpolation == 6: # linear_grid (RegularGridInterpolator)
455+
# For grid interpolation, the actual interpolator is stored separately
456+
# This function is a placeholder that should not be called directly
457+
# since __get_value_opt_grid is used instead
458+
if hasattr(self, '_grid_interpolator'):
459+
def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument
460+
return self._grid_interpolator(x)
461+
self._interpolation_func = grid_interpolation
462+
else:
463+
# Fallback to shepard if grid interpolator not available
464+
warnings.warn(
465+
"Grid interpolator not found, falling back to shepard interpolation"
466+
)
467+
def shepard_fallback(x, x_min, x_max, x_data, y_data, _):
468+
# pylint: disable=unused-argument
469+
arg_qty, arg_dim = x.shape
470+
result = np.empty(arg_qty)
471+
x = x.reshape((arg_qty, 1, arg_dim))
472+
sub_matrix = x_data - x
473+
distances_squared = np.sum(sub_matrix**2, axis=2)
474+
zero_distances = np.where(distances_squared == 0)
475+
valid_indexes = np.ones(arg_qty, dtype=bool)
476+
valid_indexes[zero_distances[0]] = False
477+
weights = distances_squared[valid_indexes] ** (-1.5)
478+
numerator_sum = np.sum(y_data * weights, axis=1)
479+
denominator_sum = np.sum(weights, axis=1)
480+
result[valid_indexes] = numerator_sum / denominator_sum
481+
result[~valid_indexes] = y_data[zero_distances[1]]
482+
return result
483+
self._interpolation_func = shepard_fallback
484+
452485
else:
453486
raise ValueError(f"Interpolation {interpolation} method not recognized.")
454487

@@ -635,6 +668,64 @@ def __get_value_opt_nd(self, *args):
635668

636669
return result
637670

671+
def __get_value_opt_grid(self, *args):
672+
"""Evaluate the Function using RegularGridInterpolator for structured grids.
673+
674+
Parameters
675+
----------
676+
args : tuple
677+
Values where the Function is to be evaluated. Must match the number
678+
of dimensions of the grid.
679+
680+
Returns
681+
-------
682+
result : scalar or ndarray
683+
Value of the Function at the specified points.
684+
"""
685+
# Check if we have the grid interpolator
686+
if not hasattr(self, '_grid_interpolator'):
687+
raise RuntimeError(
688+
"Grid interpolator not initialized. Use from_grid() to create "
689+
"a Function with grid interpolation."
690+
)
691+
692+
# Convert args to appropriate format for RegularGridInterpolator
693+
# RegularGridInterpolator expects points as (N, ndim) array
694+
if len(args) != self.__dom_dim__:
695+
raise ValueError(
696+
f"Expected {self.__dom_dim__} arguments but got {len(args)}"
697+
)
698+
699+
# Handle single point evaluation
700+
point = np.array(args).reshape(1, -1)
701+
702+
# Handle extrapolation based on the extrapolation setting
703+
if self.__extrapolation__ == "constant":
704+
# Clamp point to grid boundaries for constant extrapolation
705+
for i, axis in enumerate(self._grid_axes):
706+
point[0, i] = np.clip(point[0, i], axis[0], axis[-1])
707+
result = self._grid_interpolator(point)
708+
elif self.__extrapolation__ == "zero":
709+
# Check if point is outside bounds
710+
outside_bounds = False
711+
for i, axis in enumerate(self._grid_axes):
712+
if point[0, i] < axis[0] or point[0, i] > axis[-1]:
713+
outside_bounds = True
714+
break
715+
if outside_bounds:
716+
result = np.array([0.0])
717+
else:
718+
result = self._grid_interpolator(point)
719+
else:
720+
# Natural or other extrapolation - use interpolator directly
721+
result = self._grid_interpolator(point)
722+
723+
# Return scalar for single evaluation
724+
if result.size == 1:
725+
return float(result[0])
726+
727+
return result
728+
638729
def __determine_1d_domain_bounds(self, lower, upper):
639730
"""Determine domain bounds for 1-D function discretization.
640731
@@ -3891,11 +3982,11 @@ def __validate_interpolation(self, interpolation):
38913982
elif self.__dom_dim__ > 1:
38923983
if interpolation is None:
38933984
interpolation = "shepard"
3894-
if interpolation.lower() not in ["shepard", "linear", "rbf"]:
3985+
if interpolation.lower() not in ["shepard", "linear", "rbf", "linear_grid"]:
38953986
warnings.warn(
38963987
(
38973988
"Interpolation method set to 'shepard'. The methods "
3898-
"'linear', 'shepard' and 'rbf' are supported for "
3989+
"'linear', 'shepard', 'rbf' and 'linear_grid' are supported for "
38993990
"multiple dimensions."
39003991
),
39013992
)
@@ -3950,6 +4041,169 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
39504041
"extrapolation": self.__extrapolation__,
39514042
}
39524043

4044+
@classmethod
4045+
def from_grid(cls, grid_data, axes, inputs=None, outputs=None,
4046+
interpolation="linear_grid", extrapolation="constant", **kwargs):
4047+
"""Creates a Function from N-dimensional grid data.
4048+
4049+
This method is designed for structured grid data, such as CFD simulation
4050+
results where values are computed on a regular grid. It uses
4051+
scipy.interpolate.RegularGridInterpolator for efficient interpolation.
4052+
4053+
Parameters
4054+
----------
4055+
grid_data : ndarray
4056+
N-dimensional array containing the function values on the grid.
4057+
For example, for a 3D function Cd(M, Re, α), this would be a 3D array
4058+
where grid_data[i, j, k] = Cd(M[i], Re[j], α[k]).
4059+
axes : list of ndarray
4060+
List of 1D arrays defining the grid points along each axis.
4061+
Each array should be sorted in ascending order.
4062+
For example: [M_axis, Re_axis, alpha_axis].
4063+
inputs : list of str, optional
4064+
Names of the input variables. If None, generic names will be used.
4065+
For example: ['Mach', 'Reynolds', 'Alpha'].
4066+
outputs : str, optional
4067+
Name of the output variable. For example: 'Cd'.
4068+
interpolation : str, optional
4069+
Interpolation method. Default is 'linear_grid'.
4070+
Currently only 'linear_grid' is supported for grid data.
4071+
extrapolation : str, optional
4072+
Extrapolation behavior. Default is 'constant', which clamps to edge values.
4073+
'constant': Use nearest edge value for out-of-bounds points.
4074+
'zero': Return zero for out-of-bounds points.
4075+
**kwargs : dict, optional
4076+
Additional arguments passed to the Function constructor.
4077+
4078+
Returns
4079+
-------
4080+
Function
4081+
A Function object using RegularGridInterpolator for evaluation.
4082+
4083+
Examples
4084+
--------
4085+
>>> import numpy as np
4086+
>>> # Create 3D drag coefficient data
4087+
>>> mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
4088+
>>> reynolds = np.array([1e5, 5e5, 1e6])
4089+
>>> alpha = np.array([0.0, 2.0, 4.0, 6.0])
4090+
>>> # Create a simple drag coefficient function
4091+
>>> M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing='ij')
4092+
>>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A
4093+
>>> # Create Function object
4094+
>>> cd_func = Function.from_grid(
4095+
... cd_data,
4096+
... [mach, reynolds, alpha],
4097+
... inputs=['Mach', 'Reynolds', 'Alpha'],
4098+
... outputs='Cd'
4099+
... )
4100+
>>> # Evaluate at a point
4101+
>>> cd_func(1.2, 3e5, 3.0)
4102+
4103+
Notes
4104+
-----
4105+
- Grid data must be on a regular (structured) grid.
4106+
- For unstructured data, use the regular Function constructor with
4107+
scattered points.
4108+
- Extrapolation with 'constant' mode uses the nearest edge values,
4109+
which is appropriate for aerodynamic coefficients where extrapolation
4110+
beyond the data range should be avoided.
4111+
"""
4112+
# Validate inputs
4113+
if not isinstance(grid_data, np.ndarray):
4114+
grid_data = np.array(grid_data)
4115+
4116+
if not isinstance(axes, (list, tuple)):
4117+
raise ValueError("axes must be a list or tuple of 1D arrays")
4118+
4119+
# Ensure all axes are numpy arrays
4120+
axes = [np.array(axis) if not isinstance(axis, np.ndarray) else axis
4121+
for axis in axes]
4122+
4123+
# Check dimensions match
4124+
if len(axes) != grid_data.ndim:
4125+
raise ValueError(
4126+
f"Number of axes ({len(axes)}) must match grid_data dimensions "
4127+
f"({grid_data.ndim})"
4128+
)
4129+
4130+
# Check each axis matches corresponding grid dimension
4131+
for i, axis in enumerate(axes):
4132+
if len(axis) != grid_data.shape[i]:
4133+
raise ValueError(
4134+
f"Axis {i} has {len(axis)} points but grid dimension {i} "
4135+
f"has {grid_data.shape[i]} points"
4136+
)
4137+
4138+
# Set default inputs if not provided
4139+
if inputs is None:
4140+
inputs = [f"x{i}" for i in range(len(axes))]
4141+
elif len(inputs) != len(axes):
4142+
raise ValueError(
4143+
f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})"
4144+
)
4145+
4146+
# Create a new Function instance
4147+
func = cls.__new__(cls)
4148+
4149+
# Initialize basic attributes
4150+
func.source = None # Will be set to indicate grid source
4151+
func.__inputs__ = inputs
4152+
func.__outputs__ = outputs if outputs is not None else "f"
4153+
func.__interpolation__ = interpolation
4154+
func.__extrapolation__ = extrapolation
4155+
func.title = kwargs.get('title', None)
4156+
func.__img_dim__ = 1
4157+
func.__cropped_domain__ = (None, None)
4158+
func._source_type = SourceType.ARRAY
4159+
func.__dom_dim__ = len(axes)
4160+
4161+
# Store grid-specific data
4162+
func._grid_axes = axes
4163+
func._grid_data = grid_data
4164+
4165+
# Create RegularGridInterpolator
4166+
# We handle extrapolation manually in __get_value_opt_grid,
4167+
# so we set bounds_error=False and let it extrapolate linearly
4168+
# (which we'll override when needed)
4169+
func._grid_interpolator = RegularGridInterpolator(
4170+
axes,
4171+
grid_data,
4172+
method='linear',
4173+
bounds_error=False,
4174+
fill_value=None # Linear extrapolation (will be overridden by manual handling)
4175+
)
4176+
4177+
# Create placeholder domain and image for compatibility
4178+
# This flattens the grid for any code expecting these attributes
4179+
mesh = np.meshgrid(*axes, indexing='ij')
4180+
domain_points = np.column_stack([m.ravel() for m in mesh])
4181+
func._domain = domain_points
4182+
func._image = grid_data.ravel()
4183+
4184+
# Set basic array attributes for compatibility
4185+
func.x_array = axes[0]
4186+
func.x_initial, func.x_final = axes[0][0], axes[0][-1]
4187+
func.y_array = func._image[:len(axes[0])] # Placeholder
4188+
func.y_initial, func.y_final = func._image[0], func._image[-1]
4189+
if len(axes) > 2:
4190+
func.z_array = axes[2]
4191+
func.z_initial, func.z_final = axes[2][0], axes[2][-1]
4192+
4193+
# Set get_value_opt to use grid interpolation
4194+
func.get_value_opt = func.__get_value_opt_grid
4195+
4196+
# Set interpolation and extrapolation functions
4197+
func.__set_interpolation_func()
4198+
func.__set_extrapolation_func()
4199+
4200+
# Set inputs and outputs properly
4201+
func.set_inputs(inputs)
4202+
func.set_outputs(outputs)
4203+
func.set_title(func.title)
4204+
4205+
return func
4206+
39534207
@classmethod
39544208
def from_dict(cls, func_dict):
39554209
"""Creates a Function instance from a dictionary.

0 commit comments

Comments
 (0)