|
22 | 22 | LinearNDInterpolator, |
23 | 23 | NearestNDInterpolator, |
24 | 24 | RBFInterpolator, |
| 25 | + RegularGridInterpolator, |
25 | 26 | ) |
26 | 27 |
|
27 | 28 | from rocketpy.plots.plot_helpers import show_or_save_plot |
|
43 | 44 | "spline": 3, |
44 | 45 | "shepard": 4, |
45 | 46 | "rbf": 5, |
| 47 | + "linear_grid": 6, |
46 | 48 | } |
47 | 49 | EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2} |
48 | 50 |
|
@@ -449,6 +451,37 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab |
449 | 451 |
|
450 | 452 | self._interpolation_func = rbf_interpolation |
451 | 453 |
|
| 454 | + elif interpolation == 6: # linear_grid (RegularGridInterpolator) |
| 455 | + # For grid interpolation, the actual interpolator is stored separately |
| 456 | + # This function is a placeholder that should not be called directly |
| 457 | + # since __get_value_opt_grid is used instead |
| 458 | + if hasattr(self, '_grid_interpolator'): |
| 459 | + def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument |
| 460 | + return self._grid_interpolator(x) |
| 461 | + self._interpolation_func = grid_interpolation |
| 462 | + else: |
| 463 | + # Fallback to shepard if grid interpolator not available |
| 464 | + warnings.warn( |
| 465 | + "Grid interpolator not found, falling back to shepard interpolation" |
| 466 | + ) |
| 467 | + def shepard_fallback(x, x_min, x_max, x_data, y_data, _): |
| 468 | + # pylint: disable=unused-argument |
| 469 | + arg_qty, arg_dim = x.shape |
| 470 | + result = np.empty(arg_qty) |
| 471 | + x = x.reshape((arg_qty, 1, arg_dim)) |
| 472 | + sub_matrix = x_data - x |
| 473 | + distances_squared = np.sum(sub_matrix**2, axis=2) |
| 474 | + zero_distances = np.where(distances_squared == 0) |
| 475 | + valid_indexes = np.ones(arg_qty, dtype=bool) |
| 476 | + valid_indexes[zero_distances[0]] = False |
| 477 | + weights = distances_squared[valid_indexes] ** (-1.5) |
| 478 | + numerator_sum = np.sum(y_data * weights, axis=1) |
| 479 | + denominator_sum = np.sum(weights, axis=1) |
| 480 | + result[valid_indexes] = numerator_sum / denominator_sum |
| 481 | + result[~valid_indexes] = y_data[zero_distances[1]] |
| 482 | + return result |
| 483 | + self._interpolation_func = shepard_fallback |
| 484 | + |
452 | 485 | else: |
453 | 486 | raise ValueError(f"Interpolation {interpolation} method not recognized.") |
454 | 487 |
|
@@ -635,6 +668,64 @@ def __get_value_opt_nd(self, *args): |
635 | 668 |
|
636 | 669 | return result |
637 | 670 |
|
| 671 | + def __get_value_opt_grid(self, *args): |
| 672 | + """Evaluate the Function using RegularGridInterpolator for structured grids. |
| 673 | +
|
| 674 | + Parameters |
| 675 | + ---------- |
| 676 | + args : tuple |
| 677 | + Values where the Function is to be evaluated. Must match the number |
| 678 | + of dimensions of the grid. |
| 679 | +
|
| 680 | + Returns |
| 681 | + ------- |
| 682 | + result : scalar or ndarray |
| 683 | + Value of the Function at the specified points. |
| 684 | + """ |
| 685 | + # Check if we have the grid interpolator |
| 686 | + if not hasattr(self, '_grid_interpolator'): |
| 687 | + raise RuntimeError( |
| 688 | + "Grid interpolator not initialized. Use from_grid() to create " |
| 689 | + "a Function with grid interpolation." |
| 690 | + ) |
| 691 | + |
| 692 | + # Convert args to appropriate format for RegularGridInterpolator |
| 693 | + # RegularGridInterpolator expects points as (N, ndim) array |
| 694 | + if len(args) != self.__dom_dim__: |
| 695 | + raise ValueError( |
| 696 | + f"Expected {self.__dom_dim__} arguments but got {len(args)}" |
| 697 | + ) |
| 698 | + |
| 699 | + # Handle single point evaluation |
| 700 | + point = np.array(args).reshape(1, -1) |
| 701 | + |
| 702 | + # Handle extrapolation based on the extrapolation setting |
| 703 | + if self.__extrapolation__ == "constant": |
| 704 | + # Clamp point to grid boundaries for constant extrapolation |
| 705 | + for i, axis in enumerate(self._grid_axes): |
| 706 | + point[0, i] = np.clip(point[0, i], axis[0], axis[-1]) |
| 707 | + result = self._grid_interpolator(point) |
| 708 | + elif self.__extrapolation__ == "zero": |
| 709 | + # Check if point is outside bounds |
| 710 | + outside_bounds = False |
| 711 | + for i, axis in enumerate(self._grid_axes): |
| 712 | + if point[0, i] < axis[0] or point[0, i] > axis[-1]: |
| 713 | + outside_bounds = True |
| 714 | + break |
| 715 | + if outside_bounds: |
| 716 | + result = np.array([0.0]) |
| 717 | + else: |
| 718 | + result = self._grid_interpolator(point) |
| 719 | + else: |
| 720 | + # Natural or other extrapolation - use interpolator directly |
| 721 | + result = self._grid_interpolator(point) |
| 722 | + |
| 723 | + # Return scalar for single evaluation |
| 724 | + if result.size == 1: |
| 725 | + return float(result[0]) |
| 726 | + |
| 727 | + return result |
| 728 | + |
638 | 729 | def __determine_1d_domain_bounds(self, lower, upper): |
639 | 730 | """Determine domain bounds for 1-D function discretization. |
640 | 731 |
|
@@ -3891,11 +3982,11 @@ def __validate_interpolation(self, interpolation): |
3891 | 3982 | elif self.__dom_dim__ > 1: |
3892 | 3983 | if interpolation is None: |
3893 | 3984 | interpolation = "shepard" |
3894 | | - if interpolation.lower() not in ["shepard", "linear", "rbf"]: |
| 3985 | + if interpolation.lower() not in ["shepard", "linear", "rbf", "linear_grid"]: |
3895 | 3986 | warnings.warn( |
3896 | 3987 | ( |
3897 | 3988 | "Interpolation method set to 'shepard'. The methods " |
3898 | | - "'linear', 'shepard' and 'rbf' are supported for " |
| 3989 | + "'linear', 'shepard', 'rbf' and 'linear_grid' are supported for " |
3899 | 3990 | "multiple dimensions." |
3900 | 3991 | ), |
3901 | 3992 | ) |
@@ -3950,6 +4041,169 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument |
3950 | 4041 | "extrapolation": self.__extrapolation__, |
3951 | 4042 | } |
3952 | 4043 |
|
| 4044 | + @classmethod |
| 4045 | + def from_grid(cls, grid_data, axes, inputs=None, outputs=None, |
| 4046 | + interpolation="linear_grid", extrapolation="constant", **kwargs): |
| 4047 | + """Creates a Function from N-dimensional grid data. |
| 4048 | + |
| 4049 | + This method is designed for structured grid data, such as CFD simulation |
| 4050 | + results where values are computed on a regular grid. It uses |
| 4051 | + scipy.interpolate.RegularGridInterpolator for efficient interpolation. |
| 4052 | +
|
| 4053 | + Parameters |
| 4054 | + ---------- |
| 4055 | + grid_data : ndarray |
| 4056 | + N-dimensional array containing the function values on the grid. |
| 4057 | + For example, for a 3D function Cd(M, Re, α), this would be a 3D array |
| 4058 | + where grid_data[i, j, k] = Cd(M[i], Re[j], α[k]). |
| 4059 | + axes : list of ndarray |
| 4060 | + List of 1D arrays defining the grid points along each axis. |
| 4061 | + Each array should be sorted in ascending order. |
| 4062 | + For example: [M_axis, Re_axis, alpha_axis]. |
| 4063 | + inputs : list of str, optional |
| 4064 | + Names of the input variables. If None, generic names will be used. |
| 4065 | + For example: ['Mach', 'Reynolds', 'Alpha']. |
| 4066 | + outputs : str, optional |
| 4067 | + Name of the output variable. For example: 'Cd'. |
| 4068 | + interpolation : str, optional |
| 4069 | + Interpolation method. Default is 'linear_grid'. |
| 4070 | + Currently only 'linear_grid' is supported for grid data. |
| 4071 | + extrapolation : str, optional |
| 4072 | + Extrapolation behavior. Default is 'constant', which clamps to edge values. |
| 4073 | + 'constant': Use nearest edge value for out-of-bounds points. |
| 4074 | + 'zero': Return zero for out-of-bounds points. |
| 4075 | + **kwargs : dict, optional |
| 4076 | + Additional arguments passed to the Function constructor. |
| 4077 | +
|
| 4078 | + Returns |
| 4079 | + ------- |
| 4080 | + Function |
| 4081 | + A Function object using RegularGridInterpolator for evaluation. |
| 4082 | +
|
| 4083 | + Examples |
| 4084 | + -------- |
| 4085 | + >>> import numpy as np |
| 4086 | + >>> # Create 3D drag coefficient data |
| 4087 | + >>> mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0]) |
| 4088 | + >>> reynolds = np.array([1e5, 5e5, 1e6]) |
| 4089 | + >>> alpha = np.array([0.0, 2.0, 4.0, 6.0]) |
| 4090 | + >>> # Create a simple drag coefficient function |
| 4091 | + >>> M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing='ij') |
| 4092 | + >>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A |
| 4093 | + >>> # Create Function object |
| 4094 | + >>> cd_func = Function.from_grid( |
| 4095 | + ... cd_data, |
| 4096 | + ... [mach, reynolds, alpha], |
| 4097 | + ... inputs=['Mach', 'Reynolds', 'Alpha'], |
| 4098 | + ... outputs='Cd' |
| 4099 | + ... ) |
| 4100 | + >>> # Evaluate at a point |
| 4101 | + >>> cd_func(1.2, 3e5, 3.0) |
| 4102 | + |
| 4103 | + Notes |
| 4104 | + ----- |
| 4105 | + - Grid data must be on a regular (structured) grid. |
| 4106 | + - For unstructured data, use the regular Function constructor with |
| 4107 | + scattered points. |
| 4108 | + - Extrapolation with 'constant' mode uses the nearest edge values, |
| 4109 | + which is appropriate for aerodynamic coefficients where extrapolation |
| 4110 | + beyond the data range should be avoided. |
| 4111 | + """ |
| 4112 | + # Validate inputs |
| 4113 | + if not isinstance(grid_data, np.ndarray): |
| 4114 | + grid_data = np.array(grid_data) |
| 4115 | + |
| 4116 | + if not isinstance(axes, (list, tuple)): |
| 4117 | + raise ValueError("axes must be a list or tuple of 1D arrays") |
| 4118 | + |
| 4119 | + # Ensure all axes are numpy arrays |
| 4120 | + axes = [np.array(axis) if not isinstance(axis, np.ndarray) else axis |
| 4121 | + for axis in axes] |
| 4122 | + |
| 4123 | + # Check dimensions match |
| 4124 | + if len(axes) != grid_data.ndim: |
| 4125 | + raise ValueError( |
| 4126 | + f"Number of axes ({len(axes)}) must match grid_data dimensions " |
| 4127 | + f"({grid_data.ndim})" |
| 4128 | + ) |
| 4129 | + |
| 4130 | + # Check each axis matches corresponding grid dimension |
| 4131 | + for i, axis in enumerate(axes): |
| 4132 | + if len(axis) != grid_data.shape[i]: |
| 4133 | + raise ValueError( |
| 4134 | + f"Axis {i} has {len(axis)} points but grid dimension {i} " |
| 4135 | + f"has {grid_data.shape[i]} points" |
| 4136 | + ) |
| 4137 | + |
| 4138 | + # Set default inputs if not provided |
| 4139 | + if inputs is None: |
| 4140 | + inputs = [f"x{i}" for i in range(len(axes))] |
| 4141 | + elif len(inputs) != len(axes): |
| 4142 | + raise ValueError( |
| 4143 | + f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})" |
| 4144 | + ) |
| 4145 | + |
| 4146 | + # Create a new Function instance |
| 4147 | + func = cls.__new__(cls) |
| 4148 | + |
| 4149 | + # Initialize basic attributes |
| 4150 | + func.source = None # Will be set to indicate grid source |
| 4151 | + func.__inputs__ = inputs |
| 4152 | + func.__outputs__ = outputs if outputs is not None else "f" |
| 4153 | + func.__interpolation__ = interpolation |
| 4154 | + func.__extrapolation__ = extrapolation |
| 4155 | + func.title = kwargs.get('title', None) |
| 4156 | + func.__img_dim__ = 1 |
| 4157 | + func.__cropped_domain__ = (None, None) |
| 4158 | + func._source_type = SourceType.ARRAY |
| 4159 | + func.__dom_dim__ = len(axes) |
| 4160 | + |
| 4161 | + # Store grid-specific data |
| 4162 | + func._grid_axes = axes |
| 4163 | + func._grid_data = grid_data |
| 4164 | + |
| 4165 | + # Create RegularGridInterpolator |
| 4166 | + # We handle extrapolation manually in __get_value_opt_grid, |
| 4167 | + # so we set bounds_error=False and let it extrapolate linearly |
| 4168 | + # (which we'll override when needed) |
| 4169 | + func._grid_interpolator = RegularGridInterpolator( |
| 4170 | + axes, |
| 4171 | + grid_data, |
| 4172 | + method='linear', |
| 4173 | + bounds_error=False, |
| 4174 | + fill_value=None # Linear extrapolation (will be overridden by manual handling) |
| 4175 | + ) |
| 4176 | + |
| 4177 | + # Create placeholder domain and image for compatibility |
| 4178 | + # This flattens the grid for any code expecting these attributes |
| 4179 | + mesh = np.meshgrid(*axes, indexing='ij') |
| 4180 | + domain_points = np.column_stack([m.ravel() for m in mesh]) |
| 4181 | + func._domain = domain_points |
| 4182 | + func._image = grid_data.ravel() |
| 4183 | + |
| 4184 | + # Set basic array attributes for compatibility |
| 4185 | + func.x_array = axes[0] |
| 4186 | + func.x_initial, func.x_final = axes[0][0], axes[0][-1] |
| 4187 | + func.y_array = func._image[:len(axes[0])] # Placeholder |
| 4188 | + func.y_initial, func.y_final = func._image[0], func._image[-1] |
| 4189 | + if len(axes) > 2: |
| 4190 | + func.z_array = axes[2] |
| 4191 | + func.z_initial, func.z_final = axes[2][0], axes[2][-1] |
| 4192 | + |
| 4193 | + # Set get_value_opt to use grid interpolation |
| 4194 | + func.get_value_opt = func.__get_value_opt_grid |
| 4195 | + |
| 4196 | + # Set interpolation and extrapolation functions |
| 4197 | + func.__set_interpolation_func() |
| 4198 | + func.__set_extrapolation_func() |
| 4199 | + |
| 4200 | + # Set inputs and outputs properly |
| 4201 | + func.set_inputs(inputs) |
| 4202 | + func.set_outputs(outputs) |
| 4203 | + func.set_title(func.title) |
| 4204 | + |
| 4205 | + return func |
| 4206 | + |
3953 | 4207 | @classmethod |
3954 | 4208 | def from_dict(cls, func_dict): |
3955 | 4209 | """Creates a Function instance from a dictionary. |
|
0 commit comments