Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 274 additions & 2 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LinearNDInterpolator,
NearestNDInterpolator,
RBFInterpolator,
RegularGridInterpolator,
)

from rocketpy.plots.plot_helpers import show_or_save_plot
Expand All @@ -43,6 +44,7 @@
"spline": 3,
"shepard": 4,
"rbf": 5,
"linear_grid": 6,
}
EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2}

Expand Down Expand Up @@ -449,6 +451,41 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab

self._interpolation_func = rbf_interpolation

elif interpolation == 6: # linear_grid (RegularGridInterpolator)
# For grid interpolation, the actual interpolator is stored separately
# This function is a placeholder that should not be called directly
# since __get_value_opt_grid is used instead
if hasattr(self, "_grid_interpolator"):

def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument
return self._grid_interpolator(x)

self._interpolation_func = grid_interpolation
else:
# Fallback to shepard if grid interpolator not available
warnings.warn(
"Grid interpolator not found, falling back to shepard interpolation"
)

def shepard_fallback(x, x_min, x_max, x_data, y_data, _):
# pylint: disable=unused-argument
arg_qty, arg_dim = x.shape
result = np.empty(arg_qty)
x = x.reshape((arg_qty, 1, arg_dim))
sub_matrix = x_data - x
distances_squared = np.sum(sub_matrix**2, axis=2)
zero_distances = np.where(distances_squared == 0)
valid_indexes = np.ones(arg_qty, dtype=bool)
valid_indexes[zero_distances[0]] = False
weights = distances_squared[valid_indexes] ** (-1.5)
numerator_sum = np.sum(y_data * weights, axis=1)
denominator_sum = np.sum(weights, axis=1)
result[valid_indexes] = numerator_sum / denominator_sum
result[~valid_indexes] = y_data[zero_distances[1]]
return result

self._interpolation_func = shepard_fallback

else:
raise ValueError(f"Interpolation {interpolation} method not recognized.")

Expand Down Expand Up @@ -635,6 +672,66 @@ def __get_value_opt_nd(self, *args):

return result

def __get_value_opt_grid(self, *args):
"""Evaluate the Function using RegularGridInterpolator for structured grids.

This method is dynamically assigned in from_grid() class method.

Parameters
----------
args : tuple
Values where the Function is to be evaluated. Must match the number
of dimensions of the grid.

Returns
-------
result : scalar or ndarray
Value of the Function at the specified points.
"""
# Check if we have the grid interpolator
if not hasattr(self, "_grid_interpolator"):
raise RuntimeError(
"Grid interpolator not initialized. Use from_grid() to create "
"a Function with grid interpolation."
)

# Convert args to appropriate format for RegularGridInterpolator
# RegularGridInterpolator expects points as (N, ndim) array
if len(args) != self.__dom_dim__:
raise ValueError(
f"Expected {self.__dom_dim__} arguments but got {len(args)}"
)

# Handle single point evaluation
point = np.array(args).reshape(1, -1)

# Handle extrapolation based on the extrapolation setting
if self.__extrapolation__ == "constant":
# Clamp point to grid boundaries for constant extrapolation
for i, axis in enumerate(self._grid_axes):
point[0, i] = np.clip(point[0, i], axis[0], axis[-1])
result = self._grid_interpolator(point)
elif self.__extrapolation__ == "zero":
# Check if point is outside bounds
outside_bounds = False
for i, axis in enumerate(self._grid_axes):
if point[0, i] < axis[0] or point[0, i] > axis[-1]:
outside_bounds = True
break
if outside_bounds:
result = np.array([0.0])
else:
result = self._grid_interpolator(point)
else:
# Natural or other extrapolation - use interpolator directly
result = self._grid_interpolator(point)

# Return scalar for single evaluation
if result.size == 1:
return float(result[0])

return result

def __determine_1d_domain_bounds(self, lower, upper):
"""Determine domain bounds for 1-D function discretization.

Expand Down Expand Up @@ -3891,11 +3988,11 @@ def __validate_interpolation(self, interpolation):
elif self.__dom_dim__ > 1:
if interpolation is None:
interpolation = "shepard"
if interpolation.lower() not in ["shepard", "linear", "rbf"]:
if interpolation.lower() not in ["shepard", "linear", "rbf", "linear_grid"]:
warnings.warn(
(
"Interpolation method set to 'shepard'. The methods "
"'linear', 'shepard' and 'rbf' are supported for "
"'linear', 'shepard', 'rbf' and 'linear_grid' are supported for "
"multiple dimensions."
),
)
Expand Down Expand Up @@ -3950,6 +4047,181 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
"extrapolation": self.__extrapolation__,
}

@classmethod
def from_grid(
cls,
grid_data,
axes,
inputs=None,
outputs=None,
interpolation="linear_grid",
extrapolation="constant",
**kwargs,
):
"""Creates a Function from N-dimensional grid data.

This method is designed for structured grid data, such as CFD simulation
results where values are computed on a regular grid. It uses
scipy.interpolate.RegularGridInterpolator for efficient interpolation.

Parameters
----------
grid_data : ndarray
N-dimensional array containing the function values on the grid.
For example, for a 3D function Cd(M, Re, α), this would be a 3D array
where grid_data[i, j, k] = Cd(M[i], Re[j], α[k]).
axes : list of ndarray
List of 1D arrays defining the grid points along each axis.
Each array should be sorted in ascending order.
For example: [M_axis, Re_axis, alpha_axis].
inputs : list of str, optional
Names of the input variables. If None, generic names will be used.
For example: ['Mach', 'Reynolds', 'Alpha'].
outputs : str, optional
Name of the output variable. For example: 'Cd'.
interpolation : str, optional
Interpolation method. Default is 'linear_grid'.
Currently only 'linear_grid' is supported for grid data.
extrapolation : str, optional
Extrapolation behavior. Default is 'constant', which clamps to edge values.
'constant': Use nearest edge value for out-of-bounds points.
'zero': Return zero for out-of-bounds points.
**kwargs : dict, optional
Additional arguments passed to the Function constructor.

Returns
-------
Function
A Function object using RegularGridInterpolator for evaluation.

Examples
--------
>>> import numpy as np
>>> # Create 3D drag coefficient data
>>> mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
>>> reynolds = np.array([1e5, 5e5, 1e6])
>>> alpha = np.array([0.0, 2.0, 4.0, 6.0])
>>> # Create a simple drag coefficient function
>>> M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing='ij')
>>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A
>>> # Create Function object
>>> cd_func = Function.from_grid(
... cd_data,
... [mach, reynolds, alpha],
... inputs=['Mach', 'Reynolds', 'Alpha'],
... outputs='Cd'
... )
>>> # Evaluate at a point
>>> cd_func(1.2, 3e5, 3.0)

Notes
-----
- Grid data must be on a regular (structured) grid.
- For unstructured data, use the regular Function constructor with
scattered points.
- Extrapolation with 'constant' mode uses the nearest edge values,
which is appropriate for aerodynamic coefficients where extrapolation
beyond the data range should be avoided.
"""
# Validate inputs
if not isinstance(grid_data, np.ndarray):
grid_data = np.array(grid_data)

if not isinstance(axes, (list, tuple)):
raise ValueError("axes must be a list or tuple of 1D arrays")

# Ensure all axes are numpy arrays
axes = [
np.array(axis) if not isinstance(axis, np.ndarray) else axis
for axis in axes
]

# Check dimensions match
if len(axes) != grid_data.ndim:
raise ValueError(
f"Number of axes ({len(axes)}) must match grid_data dimensions "
f"({grid_data.ndim})"
)

# Check each axis matches corresponding grid dimension
for i, axis in enumerate(axes):
if len(axis) != grid_data.shape[i]:
raise ValueError(
f"Axis {i} has {len(axis)} points but grid dimension {i} "
f"has {grid_data.shape[i]} points"
)

# Set default inputs if not provided
if inputs is None:
inputs = [f"x{i}" for i in range(len(axes))]
elif len(inputs) != len(axes):
raise ValueError(
f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})"
)

# Create a new Function instance
func = cls.__new__(cls)

# Store grid-specific data first
func._grid_axes = axes
func._grid_data = grid_data

# Create RegularGridInterpolator
# We handle extrapolation manually in __get_value_opt_grid,
# so we set bounds_error=False and let it extrapolate linearly
# (which we'll override when needed)
func._grid_interpolator = RegularGridInterpolator(
axes,
grid_data,
method="linear",
bounds_error=False,
fill_value=None, # Linear extrapolation (will be overridden by manual handling)
)

# Create placeholder domain and image for compatibility
# This flattens the grid for any code expecting these attributes
mesh = np.meshgrid(*axes, indexing="ij")
domain_points = np.column_stack([m.ravel() for m in mesh])
func._domain = domain_points
func._image = grid_data.ravel()

# Set source as flattened data array (for compatibility with serialization, etc.)
func.source = np.column_stack([domain_points, func._image])

# Initialize basic attributes
func.__inputs__ = inputs
func.__outputs__ = outputs if outputs is not None else "f"
func.__interpolation__ = interpolation
func.__extrapolation__ = extrapolation
func.title = kwargs.get("title", None)
func.__img_dim__ = 1
func.__cropped_domain__ = (None, None)
func._source_type = SourceType.ARRAY
func.__dom_dim__ = len(axes)

# Set basic array attributes for compatibility
func.x_array = axes[0]
func.x_initial, func.x_final = axes[0][0], axes[0][-1]
func.y_array = func._image[: len(axes[0])] # Placeholder
func.y_initial, func.y_final = func._image[0], func._image[-1]
if len(axes) > 2:
func.z_array = axes[2]
func.z_initial, func.z_final = axes[2][0], axes[2][-1]

# Set get_value_opt to use grid interpolation
func.get_value_opt = func.__get_value_opt_grid

# Set interpolation and extrapolation functions
func.__set_interpolation_func()
func.__set_extrapolation_func()

# Set inputs and outputs properly
func.set_inputs(inputs)
func.set_outputs(outputs)
func.set_title(func.title)

return func

@classmethod
def from_dict(cls, func_dict):
"""Creates a Function instance from a dictionary.
Expand Down
36 changes: 22 additions & 14 deletions rocketpy/rocket/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,28 @@ def __init__( # pylint: disable=too-many-statements
)

# Define aerodynamic drag coefficients
self.power_off_drag = Function(
power_off_drag,
"Mach Number",
"Drag Coefficient with Power Off",
"linear",
"constant",
)
self.power_on_drag = Function(
power_on_drag,
"Mach Number",
"Drag Coefficient with Power On",
"linear",
"constant",
)
# If already a Function, use it directly (preserves multi-dimensional drag)
if isinstance(power_off_drag, Function):
self.power_off_drag = power_off_drag
else:
self.power_off_drag = Function(
power_off_drag,
"Mach Number",
"Drag Coefficient with Power Off",
"linear",
"constant",
)

if isinstance(power_on_drag, Function):
self.power_on_drag = power_on_drag
else:
self.power_on_drag = Function(
power_on_drag,
"Mach Number",
"Drag Coefficient with Power On",
"linear",
"constant",
)

# Create a, possibly, temporary empty motor
# self.motors = Components() # currently unused, only 1 motor is supported
Expand Down
Loading
Loading