Skip to content

Commit 356e14a

Browse files
Run ruff format on modified files
Co-authored-by: Gui-FernandesBR <[email protected]>
1 parent 5abed8b commit 356e14a

File tree

5 files changed

+146
-130
lines changed

5 files changed

+146
-130
lines changed

rocketpy/mathutils/function.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -455,15 +455,18 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab
455455
# For grid interpolation, the actual interpolator is stored separately
456456
# This function is a placeholder that should not be called directly
457457
# since __get_value_opt_grid is used instead
458-
if hasattr(self, '_grid_interpolator'):
458+
if hasattr(self, "_grid_interpolator"):
459+
459460
def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument
460461
return self._grid_interpolator(x)
462+
461463
self._interpolation_func = grid_interpolation
462464
else:
463465
# Fallback to shepard if grid interpolator not available
464466
warnings.warn(
465467
"Grid interpolator not found, falling back to shepard interpolation"
466468
)
469+
467470
def shepard_fallback(x, x_min, x_max, x_data, y_data, _):
468471
# pylint: disable=unused-argument
469472
arg_qty, arg_dim = x.shape
@@ -480,6 +483,7 @@ def shepard_fallback(x, x_min, x_max, x_data, y_data, _):
480483
result[valid_indexes] = numerator_sum / denominator_sum
481484
result[~valid_indexes] = y_data[zero_distances[1]]
482485
return result
486+
483487
self._interpolation_func = shepard_fallback
484488

485489
else:
@@ -683,22 +687,22 @@ def __get_value_opt_grid(self, *args):
683687
Value of the Function at the specified points.
684688
"""
685689
# Check if we have the grid interpolator
686-
if not hasattr(self, '_grid_interpolator'):
690+
if not hasattr(self, "_grid_interpolator"):
687691
raise RuntimeError(
688692
"Grid interpolator not initialized. Use from_grid() to create "
689693
"a Function with grid interpolation."
690694
)
691-
695+
692696
# Convert args to appropriate format for RegularGridInterpolator
693697
# RegularGridInterpolator expects points as (N, ndim) array
694698
if len(args) != self.__dom_dim__:
695699
raise ValueError(
696700
f"Expected {self.__dom_dim__} arguments but got {len(args)}"
697701
)
698-
702+
699703
# Handle single point evaluation
700704
point = np.array(args).reshape(1, -1)
701-
705+
702706
# Handle extrapolation based on the extrapolation setting
703707
if self.__extrapolation__ == "constant":
704708
# Clamp point to grid boundaries for constant extrapolation
@@ -719,11 +723,11 @@ def __get_value_opt_grid(self, *args):
719723
else:
720724
# Natural or other extrapolation - use interpolator directly
721725
result = self._grid_interpolator(point)
722-
726+
723727
# Return scalar for single evaluation
724728
if result.size == 1:
725729
return float(result[0])
726-
730+
727731
return result
728732

729733
def __determine_1d_domain_bounds(self, lower, upper):
@@ -4042,10 +4046,18 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
40424046
}
40434047

40444048
@classmethod
4045-
def from_grid(cls, grid_data, axes, inputs=None, outputs=None,
4046-
interpolation="linear_grid", extrapolation="constant", **kwargs):
4049+
def from_grid(
4050+
cls,
4051+
grid_data,
4052+
axes,
4053+
inputs=None,
4054+
outputs=None,
4055+
interpolation="linear_grid",
4056+
extrapolation="constant",
4057+
**kwargs,
4058+
):
40474059
"""Creates a Function from N-dimensional grid data.
4048-
4060+
40494061
This method is designed for structured grid data, such as CFD simulation
40504062
results where values are computed on a regular grid. It uses
40514063
scipy.interpolate.RegularGridInterpolator for efficient interpolation.
@@ -4092,14 +4104,14 @@ def from_grid(cls, grid_data, axes, inputs=None, outputs=None,
40924104
>>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A
40934105
>>> # Create Function object
40944106
>>> cd_func = Function.from_grid(
4095-
... cd_data,
4107+
... cd_data,
40964108
... [mach, reynolds, alpha],
40974109
... inputs=['Mach', 'Reynolds', 'Alpha'],
40984110
... outputs='Cd'
40994111
... )
41004112
>>> # Evaluate at a point
41014113
>>> cd_func(1.2, 3e5, 3.0)
4102-
4114+
41034115
Notes
41044116
-----
41054117
- Grid data must be on a regular (structured) grid.
@@ -4112,98 +4124,100 @@ def from_grid(cls, grid_data, axes, inputs=None, outputs=None,
41124124
# Validate inputs
41134125
if not isinstance(grid_data, np.ndarray):
41144126
grid_data = np.array(grid_data)
4115-
4127+
41164128
if not isinstance(axes, (list, tuple)):
41174129
raise ValueError("axes must be a list or tuple of 1D arrays")
4118-
4130+
41194131
# Ensure all axes are numpy arrays
4120-
axes = [np.array(axis) if not isinstance(axis, np.ndarray) else axis
4121-
for axis in axes]
4122-
4132+
axes = [
4133+
np.array(axis) if not isinstance(axis, np.ndarray) else axis
4134+
for axis in axes
4135+
]
4136+
41234137
# Check dimensions match
41244138
if len(axes) != grid_data.ndim:
41254139
raise ValueError(
41264140
f"Number of axes ({len(axes)}) must match grid_data dimensions "
41274141
f"({grid_data.ndim})"
41284142
)
4129-
4143+
41304144
# Check each axis matches corresponding grid dimension
41314145
for i, axis in enumerate(axes):
41324146
if len(axis) != grid_data.shape[i]:
41334147
raise ValueError(
41344148
f"Axis {i} has {len(axis)} points but grid dimension {i} "
41354149
f"has {grid_data.shape[i]} points"
41364150
)
4137-
4151+
41384152
# Set default inputs if not provided
41394153
if inputs is None:
41404154
inputs = [f"x{i}" for i in range(len(axes))]
41414155
elif len(inputs) != len(axes):
41424156
raise ValueError(
41434157
f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})"
41444158
)
4145-
4159+
41464160
# Create a new Function instance
41474161
func = cls.__new__(cls)
4148-
4162+
41494163
# Store grid-specific data first
41504164
func._grid_axes = axes
41514165
func._grid_data = grid_data
4152-
4166+
41534167
# Create RegularGridInterpolator
41544168
# We handle extrapolation manually in __get_value_opt_grid,
41554169
# so we set bounds_error=False and let it extrapolate linearly
41564170
# (which we'll override when needed)
41574171
func._grid_interpolator = RegularGridInterpolator(
4158-
axes,
4159-
grid_data,
4160-
method='linear',
4172+
axes,
4173+
grid_data,
4174+
method="linear",
41614175
bounds_error=False,
4162-
fill_value=None # Linear extrapolation (will be overridden by manual handling)
4176+
fill_value=None, # Linear extrapolation (will be overridden by manual handling)
41634177
)
4164-
4178+
41654179
# Create placeholder domain and image for compatibility
41664180
# This flattens the grid for any code expecting these attributes
4167-
mesh = np.meshgrid(*axes, indexing='ij')
4181+
mesh = np.meshgrid(*axes, indexing="ij")
41684182
domain_points = np.column_stack([m.ravel() for m in mesh])
41694183
func._domain = domain_points
41704184
func._image = grid_data.ravel()
4171-
4185+
41724186
# Set source as flattened data array (for compatibility with serialization, etc.)
41734187
func.source = np.column_stack([domain_points, func._image])
4174-
4188+
41754189
# Initialize basic attributes
41764190
func.__inputs__ = inputs
41774191
func.__outputs__ = outputs if outputs is not None else "f"
41784192
func.__interpolation__ = interpolation
41794193
func.__extrapolation__ = extrapolation
4180-
func.title = kwargs.get('title', None)
4194+
func.title = kwargs.get("title", None)
41814195
func.__img_dim__ = 1
41824196
func.__cropped_domain__ = (None, None)
41834197
func._source_type = SourceType.ARRAY
41844198
func.__dom_dim__ = len(axes)
4185-
4199+
41864200
# Set basic array attributes for compatibility
41874201
func.x_array = axes[0]
41884202
func.x_initial, func.x_final = axes[0][0], axes[0][-1]
4189-
func.y_array = func._image[:len(axes[0])] # Placeholder
4203+
func.y_array = func._image[: len(axes[0])] # Placeholder
41904204
func.y_initial, func.y_final = func._image[0], func._image[-1]
41914205
if len(axes) > 2:
4192-
func.z_array = axes[2]
4206+
func.z_array = axes[2]
41934207
func.z_initial, func.z_final = axes[2][0], axes[2][-1]
4194-
4208+
41954209
# Set get_value_opt to use grid interpolation
41964210
func.get_value_opt = func.__get_value_opt_grid
4197-
4211+
41984212
# Set interpolation and extrapolation functions
41994213
func.__set_interpolation_func()
42004214
func.__set_extrapolation_func()
4201-
4215+
42024216
# Set inputs and outputs properly
42034217
func.set_inputs(inputs)
42044218
func.set_outputs(outputs)
42054219
func.set_title(func.title)
4206-
4220+
42074221
return func
42084222

42094223
@classmethod

rocketpy/rocket/rocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__( # pylint: disable=too-many-statements
352352
"linear",
353353
"constant",
354354
)
355-
355+
356356
if isinstance(power_on_drag, Function):
357357
self.power_on_drag = power_on_drag
358358
else:

rocketpy/simulation/flight.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,7 +1374,7 @@ def __get_drag_coefficient(
13741374
# Check if drag function is multi-dimensional by examining its inputs
13751375
if hasattr(drag_function, "__inputs__") and len(drag_function.__inputs__) > 1:
13761376
# Multi-dimensional drag function - calculate additional parameters
1377-
1377+
13781378
# Calculate Reynolds number
13791379
# Re = rho * V * L / mu
13801380
# where L is characteristic length (rocket diameter)
@@ -1383,15 +1383,15 @@ def __get_drag_coefficient(
13831383
freestream_speed = np.linalg.norm(freestream_velocity_body)
13841384
characteristic_length = 2 * self.rocket.radius # Diameter
13851385
reynolds = rho * freestream_speed * characteristic_length / mu
1386-
1386+
13871387
# Calculate angle of attack
13881388
# Angle between freestream velocity and rocket axis (z-axis in body frame)
13891389
# The z component of freestream velocity in body frame
13901390
if hasattr(freestream_velocity_body, "z"):
13911391
stream_vz_b = -freestream_velocity_body.z
13921392
else:
13931393
stream_vz_b = -freestream_velocity_body[2]
1394-
1394+
13951395
# Normalize and calculate angle
13961396
if freestream_speed > 1e-6:
13971397
cos_alpha = stream_vz_b / freestream_speed
@@ -1401,11 +1401,11 @@ def __get_drag_coefficient(
14011401
alpha_deg = np.rad2deg(alpha_rad)
14021402
else:
14031403
alpha_deg = 0.0
1404-
1404+
14051405
# Determine which parameters to pass based on input names
14061406
input_names = [name.lower() for name in drag_function.__inputs__]
14071407
args = []
1408-
1408+
14091409
for name in input_names:
14101410
if "mach" in name or name == "m":
14111411
args.append(mach)
@@ -1416,7 +1416,7 @@ def __get_drag_coefficient(
14161416
else:
14171417
# Unknown parameter, default to mach
14181418
args.append(mach)
1419-
1419+
14201420
return drag_function.get_value_opt(*args)
14211421
else:
14221422
# 1D drag function - only mach number
@@ -1458,7 +1458,7 @@ def udot_rail1(self, t, u, post_processing=False):
14581458
+ (vz) ** 2
14591459
) ** 0.5
14601460
free_stream_mach = free_stream_speed / self.env.speed_of_sound.get_value_opt(z)
1461-
1461+
14621462
# For rail motion, rocket is constrained - velocity mostly along z-axis in body frame
14631463
# Calculate velocity in body frame (simplified for rail)
14641464
a11 = 1 - 2 * (e2**2 + e3**2)
@@ -1470,18 +1470,18 @@ def udot_rail1(self, t, u, post_processing=False):
14701470
a31 = 2 * (e1 * e3 - e0 * e2)
14711471
a32 = 2 * (e2 * e3 + e0 * e1)
14721472
a33 = 1 - 2 * (e1**2 + e2**2)
1473-
1473+
14741474
vx_b = a11 * vx + a21 * vy + a31 * vz
14751475
vy_b = a12 * vx + a22 * vy + a32 * vz
14761476
vz_b = a13 * vx + a23 * vy + a33 * vz
1477-
1477+
14781478
# Freestream velocity in body frame
14791479
wind_vx = self.env.wind_velocity_x.get_value_opt(z)
14801480
wind_vy = self.env.wind_velocity_y.get_value_opt(z)
14811481
stream_vx_b = a11 * (wind_vx - vx) + a21 * (wind_vy - vy) + a31 * (-vz)
14821482
stream_vy_b = a12 * (wind_vx - vx) + a22 * (wind_vy - vy) + a32 * (-vz)
14831483
stream_vz_b = a13 * (wind_vx - vx) + a23 * (wind_vy - vy) + a33 * (-vz)
1484-
1484+
14851485
drag_coeff = self.__get_drag_coefficient(
14861486
self.rocket.power_on_drag,
14871487
free_stream_mach,
@@ -1660,12 +1660,18 @@ def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals
16601660
vx_b = a11 * vx + a21 * vy + a31 * vz
16611661
vy_b = a12 * vx + a22 * vy + a32 * vz
16621662
vz_b = a13 * vx + a23 * vy + a33 * vz
1663-
1663+
16641664
# Calculate freestream velocity in body frame
1665-
stream_vx_b = a11 * (wind_velocity_x - vx) + a21 * (wind_velocity_y - vy) + a31 * (-vz)
1666-
stream_vy_b = a12 * (wind_velocity_x - vx) + a22 * (wind_velocity_y - vy) + a32 * (-vz)
1667-
stream_vz_b = a13 * (wind_velocity_x - vx) + a23 * (wind_velocity_y - vy) + a33 * (-vz)
1668-
1665+
stream_vx_b = (
1666+
a11 * (wind_velocity_x - vx) + a21 * (wind_velocity_y - vy) + a31 * (-vz)
1667+
)
1668+
stream_vy_b = (
1669+
a12 * (wind_velocity_x - vx) + a22 * (wind_velocity_y - vy) + a32 * (-vz)
1670+
)
1671+
stream_vz_b = (
1672+
a13 * (wind_velocity_x - vx) + a23 * (wind_velocity_y - vy) + a33 * (-vz)
1673+
)
1674+
16691675
# Determine aerodynamics forces
16701676
# Determine Drag Force
16711677
if t < self.rocket.motor.burn_out_time:
@@ -1958,7 +1964,7 @@ def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too
19581964
# Calculate freestream velocity in body frame
19591965
freestream_velocity = wind_velocity - v
19601966
freestream_velocity_body = Kt @ freestream_velocity
1961-
1967+
19621968
if self.rocket.motor.burn_start_time < t < self.rocket.motor.burn_out_time:
19631969
pressure = self.env.pressure.get_value_opt(z)
19641970
net_thrust = max(

0 commit comments

Comments
 (0)