Skip to content

Commit da0bdb0

Browse files
aasitvora99GabrielBarberiniCopilotphmbressan
authored
ENH: class encoding refactor for large sims (#58)
Co-authored-by: Gabriel Barberini <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Pedro Bressan <[email protected]>
1 parent 5504668 commit da0bdb0

File tree

10 files changed

+237
-103
lines changed

10 files changed

+237
-103
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,6 @@ cython_debug/
158158
# and can be added to the global gitignore or merged into this file. For a more nuclear
159159
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
160160
#.idea/
161+
162+
# VSCode config
163+
.vscode/

src/models/rocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class RocketModel(ApiBaseModel):
1919
radius: float
2020
mass: float
2121
motor_position: float
22-
center_of_mass_without_motor: int
22+
center_of_mass_without_motor: float
2323
inertia: Union[
2424
Tuple[float, float, float],
2525
Tuple[float, float, float, float, float, float],

src/models/sub/aerosurfaces.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ class Fins(BaseModel):
3434
root_chord: float
3535
span: float
3636
position: float
37-
3837
# Optional parameters
3938
tip_chord: Optional[float] = None
4039
cant_angle: Optional[float] = None
4140
rocket_radius: Optional[float] = None
4241
airfoil: Optional[
4342
Tuple[List[Tuple[float, float]], Literal['radians', 'degrees']]
4443
] = None
44+
sweep_length: Optional[float] = None
45+
sweep_angle: Optional[float] = None
4546

4647
def get_additional_parameters(self):
4748
return {

src/services/environment.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from rocketpy.environment.environment import Environment as RocketPyEnvironment
66
from src.models.environment import EnvironmentModel
77
from src.views.environment import EnvironmentSimulation
8-
from src.utils import rocketpy_encoder, DiscretizeConfig
8+
from src.utils import collect_attributes
99

1010

1111
class EnvironmentService:
@@ -54,10 +54,11 @@ def get_environment_simulation(self) -> EnvironmentSimulation:
5454
EnvironmentSimulation
5555
"""
5656

57-
attributes = rocketpy_encoder(
58-
self.environment, DiscretizeConfig.for_environment()
57+
encoded_attributes = collect_attributes(
58+
self.environment,
59+
[EnvironmentSimulation],
5960
)
60-
env_simulation = EnvironmentSimulation(**attributes)
61+
env_simulation = EnvironmentSimulation(**encoded_attributes)
6162
return env_simulation
6263

6364
def get_environment_binary(self) -> bytes:

src/services/flight.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
from src.services.rocket import RocketService
99
from src.models.flight import FlightModel
1010
from src.views.flight import FlightSimulation
11-
from src.utils import rocketpy_encoder, DiscretizeConfig
11+
from src.views.rocket import RocketSimulation
12+
from src.views.motor import MotorSimulation
13+
from src.views.environment import EnvironmentSimulation
14+
from src.utils import collect_attributes
1215

1316

1417
class FlightService:
@@ -55,10 +58,16 @@ def get_flight_simulation(self) -> FlightSimulation:
5558
Returns:
5659
FlightSimulation
5760
"""
58-
attributes = rocketpy_encoder(
59-
self.flight, DiscretizeConfig.for_flight()
61+
encoded_attributes = collect_attributes(
62+
self.flight,
63+
[
64+
FlightSimulation,
65+
RocketSimulation,
66+
MotorSimulation,
67+
EnvironmentSimulation,
68+
],
6069
)
61-
flight_simulation = FlightSimulation(**attributes)
70+
flight_simulation = FlightSimulation(**encoded_attributes)
6271
return flight_simulation
6372

6473
def get_flight_binary(self) -> bytes:

src/services/motor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from src.models.sub.tanks import TankKinds
1818
from src.models.motor import MotorKinds, MotorModel
1919
from src.views.motor import MotorSimulation
20-
from src.utils import rocketpy_encoder, DiscretizeConfig
20+
from src.utils import collect_attributes
2121

2222

2323
class MotorService:
@@ -140,8 +140,11 @@ def get_motor_simulation(self) -> MotorSimulation:
140140
Returns:
141141
MotorSimulation
142142
"""
143-
attributes = rocketpy_encoder(self.motor, DiscretizeConfig.for_motor())
144-
motor_simulation = MotorSimulation(**attributes)
143+
encoded_attributes = collect_attributes(
144+
self.motor,
145+
[MotorSimulation],
146+
)
147+
motor_simulation = MotorSimulation(**encoded_attributes)
145148
return motor_simulation
146149

147150
def get_motor_binary(self) -> bytes:

src/services/rocket.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from src.models.sub.aerosurfaces import NoseCone, Tail, Fins
1818
from src.services.motor import MotorService
1919
from src.views.rocket import RocketSimulation
20-
from src.utils import rocketpy_encoder, DiscretizeConfig
20+
from src.views.motor import MotorSimulation
21+
from src.utils import collect_attributes
2122

2223

2324
class RocketService:
@@ -107,10 +108,10 @@ def get_rocket_simulation(self) -> RocketSimulation:
107108
Returns:
108109
RocketSimulation
109110
"""
110-
attributes = rocketpy_encoder(
111-
self.rocket, DiscretizeConfig.for_rocket()
111+
encoded_attributes = collect_attributes(
112+
self.rocket, [RocketSimulation, MotorSimulation]
112113
)
113-
rocket_simulation = RocketSimulation(**attributes)
114+
rocket_simulation = RocketSimulation(**encoded_attributes)
114115
return rocket_simulation
115116

116117
def get_rocket_binary(self) -> bytes:

src/utils.py

Lines changed: 150 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,23 @@
22
import io
33
import logging
44
import json
5-
import copy
65
from datetime import datetime
7-
86
from typing import NoReturn, Tuple
97

10-
from rocketpy import Function
8+
import numpy as np
9+
from scipy.interpolate import interp1d
10+
11+
from rocketpy import Function, Flight
1112
from rocketpy._encoders import RocketPyEncoder
13+
1214
from starlette.datastructures import Headers, MutableHeaders
1315
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1416

17+
from src.views.environment import EnvironmentSimulation
18+
from src.views.flight import FlightSimulation
19+
from src.views.motor import MotorSimulation
20+
from src.views.rocket import RocketSimulation
21+
1522
logger = logging.getLogger(__name__)
1623

1724

@@ -46,78 +53,154 @@ def for_flight(cls) -> 'DiscretizeConfig':
4653
return cls(bounds=(0, 30), samples=200)
4754

4855

49-
def rocketpy_encoder(obj, config: DiscretizeConfig = DiscretizeConfig()):
50-
"""
51-
Encode a RocketPy object using official RocketPy encoders.
56+
class InfinityEncoder(RocketPyEncoder):
57+
def __init__(self, *args, **kwargs):
58+
super().__init__(*args, **kwargs)
5259

53-
This function creates a copy of the object, discretizes callable Function
54-
attributes on the copy, and then uses RocketPy's official RocketPyEncoder for
55-
complete object serialization. The original object remains unchanged.
60+
def default(self, obj):
61+
if (
62+
isinstance(obj, Function)
63+
and not callable(obj.source)
64+
and obj.__dom_dim__ == 1
65+
):
66+
size = len(obj._domain)
67+
reduction_factor = 1
68+
if size > 25:
69+
reduction_factor = size // 25
70+
if reduction_factor > 1:
71+
obj = obj.set_discrete(
72+
obj.x_array[0],
73+
obj.x_array[-1],
74+
size // reduction_factor,
75+
mutate_self=False,
76+
)
77+
if isinstance(obj, Flight):
78+
obj._Flight__evaluate_post_process
79+
solution = np.array(obj.solution)
80+
size = len(solution)
81+
if size > 25:
82+
reduction_factor = size // 25
83+
reduced_solution = np.zeros(
84+
(size // reduction_factor, solution.shape[1])
85+
)
86+
reduced_scale = np.linspace(
87+
solution[0, 0], solution[-1, 0], size // reduction_factor
88+
)
89+
for i, col in enumerate(solution.T):
90+
reduced_solution[:, i] = interp1d(
91+
solution[:, 0], col, assume_sorted=True
92+
)(reduced_scale)
93+
obj.solution = reduced_solution.tolist()
5694

57-
Args:
58-
obj: RocketPy object (Environment, Motor, Rocket, Flight)
59-
config: DiscretizeConfig object with discretization parameters (optional)
95+
obj.flight_phases = None
96+
obj.function_evaluations = None
6097

61-
Returns:
62-
Dictionary of encoded attributes
63-
"""
98+
return super().default(obj)
6499

65-
if config is None:
66-
config = DiscretizeConfig()
67-
try:
68-
# Create a copy to avoid mutating the original object
69-
obj_copy = copy.deepcopy(obj)
70-
except Exception:
71-
# Fall back to a shallow copy if deep copy is not supported
72-
obj_copy = copy.copy(obj)
73-
74-
for attr_name in dir(obj_copy):
75-
if attr_name.startswith('_'):
76-
continue
77100

78-
try:
79-
attr_value = getattr(obj_copy, attr_name)
80-
except Exception:
81-
continue
101+
def rocketpy_encoder(obj):
102+
"""
103+
Encode a RocketPy object using official RocketPy encoders.
82104
83-
if callable(attr_value) and isinstance(attr_value, Function):
105+
Uses InfinityEncoder for serialization and reduction.
106+
"""
107+
json_str = json.dumps(
108+
obj,
109+
cls=InfinityEncoder,
110+
include_outputs=True,
111+
include_function_data=True,
112+
discretize=True,
113+
allow_pickle=False,
114+
)
115+
encoded_result = json.loads(json_str)
116+
return _fix_datetime_fields(encoded_result)
117+
118+
119+
def collect_attributes(obj, attribute_classes=None):
120+
"""
121+
Collect attributes from various simulation classes and populate them from the flight object.
122+
"""
123+
if attribute_classes is None:
124+
attribute_classes = []
125+
126+
attributes = rocketpy_encoder(obj)
127+
128+
for attribute_class in attribute_classes:
129+
if issubclass(attribute_class, FlightSimulation):
130+
flight_attributes_list = [
131+
attr
132+
for attr in attribute_class.__annotations__.keys()
133+
if attr not in ["message", "rocket", "env"]
134+
]
84135
try:
85-
discretized_func = Function(attr_value.source)
86-
discretized_func.set_discrete(
87-
lower=config.bounds[0],
88-
upper=config.bounds[1],
89-
samples=config.samples,
90-
mutate_self=True,
91-
)
92-
93-
setattr(obj_copy, attr_name, discretized_func)
94-
95-
except Exception as e:
96-
logger.warning(f"Failed to discretize {attr_name}: {e}")
136+
for key in flight_attributes_list:
137+
if key not in attributes:
138+
try:
139+
value = getattr(obj, key)
140+
attributes[key] = value
141+
except Exception:
142+
pass
143+
except Exception:
144+
pass
145+
146+
elif issubclass(attribute_class, RocketSimulation):
147+
rocket_attributes_list = [
148+
attr
149+
for attr in attribute_class.__annotations__.keys()
150+
if attr not in ["message", "motor"]
151+
]
152+
try:
153+
for key in rocket_attributes_list:
154+
if key not in attributes.get("rocket", {}):
155+
try:
156+
value = getattr(obj.rocket, key)
157+
attributes.setdefault("rocket", {})[key] = value
158+
except Exception:
159+
pass
160+
except Exception:
161+
pass
162+
163+
elif issubclass(attribute_class, MotorSimulation):
164+
motor_attributes_list = [
165+
attr
166+
for attr in attribute_class.__annotations__.keys()
167+
if attr not in ["message"]
168+
]
169+
try:
170+
for key in motor_attributes_list:
171+
if key not in attributes.get("rocket", {}).get(
172+
"motor", {}
173+
):
174+
try:
175+
value = getattr(obj.rocket.motor, key)
176+
attributes.setdefault("rocket", {}).setdefault(
177+
"motor", {}
178+
)[key] = value
179+
except Exception:
180+
pass
181+
except Exception:
182+
pass
183+
184+
elif issubclass(attribute_class, EnvironmentSimulation):
185+
environment_attributes_list = [
186+
attr
187+
for attr in attribute_class.__annotations__.keys()
188+
if attr not in ["message"]
189+
]
190+
try:
191+
for key in environment_attributes_list:
192+
if key not in attributes.get("env", {}):
193+
try:
194+
value = getattr(obj.env, key)
195+
attributes.setdefault("env", {})[key] = value
196+
except Exception:
197+
pass
198+
except Exception:
199+
pass
200+
else:
201+
continue
97202

98-
try:
99-
json_str = json.dumps(
100-
obj_copy,
101-
cls=RocketPyEncoder,
102-
include_outputs=True,
103-
include_function_data=True,
104-
)
105-
encoded_result = json.loads(json_str)
106-
107-
# Post-process to fix datetime fields that got converted to lists
108-
return _fix_datetime_fields(encoded_result)
109-
except Exception as e:
110-
logger.warning(f"Failed to encode with RocketPyEncoder: {e}")
111-
attributes = {}
112-
for attr_name in dir(obj_copy):
113-
if not attr_name.startswith('_'):
114-
try:
115-
attr_value = getattr(obj_copy, attr_name)
116-
if not callable(attr_value):
117-
attributes[attr_name] = str(attr_value)
118-
except Exception:
119-
continue
120-
return attributes
203+
return rocketpy_encoder(attributes)
121204

122205

123206
def _fix_datetime_fields(data):

0 commit comments

Comments
 (0)