|
2 | 2 | import io |
3 | 3 | import logging |
4 | 4 | import json |
5 | | -import copy |
6 | 5 | from datetime import datetime |
7 | | - |
8 | 6 | from typing import NoReturn, Tuple |
9 | 7 |
|
10 | | -from rocketpy import Function |
| 8 | +import numpy as np |
| 9 | +from scipy.interpolate import interp1d |
| 10 | + |
| 11 | +from rocketpy import Function, Flight |
11 | 12 | from rocketpy._encoders import RocketPyEncoder |
| 13 | + |
12 | 14 | from starlette.datastructures import Headers, MutableHeaders |
13 | 15 | from starlette.types import ASGIApp, Message, Receive, Scope, Send |
14 | 16 |
|
| 17 | +from src.views.environment import EnvironmentSimulation |
| 18 | +from src.views.flight import FlightSimulation |
| 19 | +from src.views.motor import MotorSimulation |
| 20 | +from src.views.rocket import RocketSimulation |
| 21 | + |
15 | 22 | logger = logging.getLogger(__name__) |
16 | 23 |
|
17 | 24 |
|
@@ -46,78 +53,154 @@ def for_flight(cls) -> 'DiscretizeConfig': |
46 | 53 | return cls(bounds=(0, 30), samples=200) |
47 | 54 |
|
48 | 55 |
|
49 | | -def rocketpy_encoder(obj, config: DiscretizeConfig = DiscretizeConfig()): |
50 | | - """ |
51 | | - Encode a RocketPy object using official RocketPy encoders. |
| 56 | +class InfinityEncoder(RocketPyEncoder): |
| 57 | + def __init__(self, *args, **kwargs): |
| 58 | + super().__init__(*args, **kwargs) |
52 | 59 |
|
53 | | - This function creates a copy of the object, discretizes callable Function |
54 | | - attributes on the copy, and then uses RocketPy's official RocketPyEncoder for |
55 | | - complete object serialization. The original object remains unchanged. |
| 60 | + def default(self, obj): |
| 61 | + if ( |
| 62 | + isinstance(obj, Function) |
| 63 | + and not callable(obj.source) |
| 64 | + and obj.__dom_dim__ == 1 |
| 65 | + ): |
| 66 | + size = len(obj._domain) |
| 67 | + reduction_factor = 1 |
| 68 | + if size > 25: |
| 69 | + reduction_factor = size // 25 |
| 70 | + if reduction_factor > 1: |
| 71 | + obj = obj.set_discrete( |
| 72 | + obj.x_array[0], |
| 73 | + obj.x_array[-1], |
| 74 | + size // reduction_factor, |
| 75 | + mutate_self=False, |
| 76 | + ) |
| 77 | + if isinstance(obj, Flight): |
| 78 | + obj._Flight__evaluate_post_process |
| 79 | + solution = np.array(obj.solution) |
| 80 | + size = len(solution) |
| 81 | + if size > 25: |
| 82 | + reduction_factor = size // 25 |
| 83 | + reduced_solution = np.zeros( |
| 84 | + (size // reduction_factor, solution.shape[1]) |
| 85 | + ) |
| 86 | + reduced_scale = np.linspace( |
| 87 | + solution[0, 0], solution[-1, 0], size // reduction_factor |
| 88 | + ) |
| 89 | + for i, col in enumerate(solution.T): |
| 90 | + reduced_solution[:, i] = interp1d( |
| 91 | + solution[:, 0], col, assume_sorted=True |
| 92 | + )(reduced_scale) |
| 93 | + obj.solution = reduced_solution.tolist() |
56 | 94 |
|
57 | | - Args: |
58 | | - obj: RocketPy object (Environment, Motor, Rocket, Flight) |
59 | | - config: DiscretizeConfig object with discretization parameters (optional) |
| 95 | + obj.flight_phases = None |
| 96 | + obj.function_evaluations = None |
60 | 97 |
|
61 | | - Returns: |
62 | | - Dictionary of encoded attributes |
63 | | - """ |
| 98 | + return super().default(obj) |
64 | 99 |
|
65 | | - if config is None: |
66 | | - config = DiscretizeConfig() |
67 | | - try: |
68 | | - # Create a copy to avoid mutating the original object |
69 | | - obj_copy = copy.deepcopy(obj) |
70 | | - except Exception: |
71 | | - # Fall back to a shallow copy if deep copy is not supported |
72 | | - obj_copy = copy.copy(obj) |
73 | | - |
74 | | - for attr_name in dir(obj_copy): |
75 | | - if attr_name.startswith('_'): |
76 | | - continue |
77 | 100 |
|
78 | | - try: |
79 | | - attr_value = getattr(obj_copy, attr_name) |
80 | | - except Exception: |
81 | | - continue |
| 101 | +def rocketpy_encoder(obj): |
| 102 | + """ |
| 103 | + Encode a RocketPy object using official RocketPy encoders. |
82 | 104 |
|
83 | | - if callable(attr_value) and isinstance(attr_value, Function): |
| 105 | + Uses InfinityEncoder for serialization and reduction. |
| 106 | + """ |
| 107 | + json_str = json.dumps( |
| 108 | + obj, |
| 109 | + cls=InfinityEncoder, |
| 110 | + include_outputs=True, |
| 111 | + include_function_data=True, |
| 112 | + discretize=True, |
| 113 | + allow_pickle=False, |
| 114 | + ) |
| 115 | + encoded_result = json.loads(json_str) |
| 116 | + return _fix_datetime_fields(encoded_result) |
| 117 | + |
| 118 | + |
| 119 | +def collect_attributes(obj, attribute_classes=None): |
| 120 | + """ |
| 121 | + Collect attributes from various simulation classes and populate them from the flight object. |
| 122 | + """ |
| 123 | + if attribute_classes is None: |
| 124 | + attribute_classes = [] |
| 125 | + |
| 126 | + attributes = rocketpy_encoder(obj) |
| 127 | + |
| 128 | + for attribute_class in attribute_classes: |
| 129 | + if issubclass(attribute_class, FlightSimulation): |
| 130 | + flight_attributes_list = [ |
| 131 | + attr |
| 132 | + for attr in attribute_class.__annotations__.keys() |
| 133 | + if attr not in ["message", "rocket", "env"] |
| 134 | + ] |
84 | 135 | try: |
85 | | - discretized_func = Function(attr_value.source) |
86 | | - discretized_func.set_discrete( |
87 | | - lower=config.bounds[0], |
88 | | - upper=config.bounds[1], |
89 | | - samples=config.samples, |
90 | | - mutate_self=True, |
91 | | - ) |
92 | | - |
93 | | - setattr(obj_copy, attr_name, discretized_func) |
94 | | - |
95 | | - except Exception as e: |
96 | | - logger.warning(f"Failed to discretize {attr_name}: {e}") |
| 136 | + for key in flight_attributes_list: |
| 137 | + if key not in attributes: |
| 138 | + try: |
| 139 | + value = getattr(obj, key) |
| 140 | + attributes[key] = value |
| 141 | + except Exception: |
| 142 | + pass |
| 143 | + except Exception: |
| 144 | + pass |
| 145 | + |
| 146 | + elif issubclass(attribute_class, RocketSimulation): |
| 147 | + rocket_attributes_list = [ |
| 148 | + attr |
| 149 | + for attr in attribute_class.__annotations__.keys() |
| 150 | + if attr not in ["message", "motor"] |
| 151 | + ] |
| 152 | + try: |
| 153 | + for key in rocket_attributes_list: |
| 154 | + if key not in attributes.get("rocket", {}): |
| 155 | + try: |
| 156 | + value = getattr(obj.rocket, key) |
| 157 | + attributes.setdefault("rocket", {})[key] = value |
| 158 | + except Exception: |
| 159 | + pass |
| 160 | + except Exception: |
| 161 | + pass |
| 162 | + |
| 163 | + elif issubclass(attribute_class, MotorSimulation): |
| 164 | + motor_attributes_list = [ |
| 165 | + attr |
| 166 | + for attr in attribute_class.__annotations__.keys() |
| 167 | + if attr not in ["message"] |
| 168 | + ] |
| 169 | + try: |
| 170 | + for key in motor_attributes_list: |
| 171 | + if key not in attributes.get("rocket", {}).get( |
| 172 | + "motor", {} |
| 173 | + ): |
| 174 | + try: |
| 175 | + value = getattr(obj.rocket.motor, key) |
| 176 | + attributes.setdefault("rocket", {}).setdefault( |
| 177 | + "motor", {} |
| 178 | + )[key] = value |
| 179 | + except Exception: |
| 180 | + pass |
| 181 | + except Exception: |
| 182 | + pass |
| 183 | + |
| 184 | + elif issubclass(attribute_class, EnvironmentSimulation): |
| 185 | + environment_attributes_list = [ |
| 186 | + attr |
| 187 | + for attr in attribute_class.__annotations__.keys() |
| 188 | + if attr not in ["message"] |
| 189 | + ] |
| 190 | + try: |
| 191 | + for key in environment_attributes_list: |
| 192 | + if key not in attributes.get("env", {}): |
| 193 | + try: |
| 194 | + value = getattr(obj.env, key) |
| 195 | + attributes.setdefault("env", {})[key] = value |
| 196 | + except Exception: |
| 197 | + pass |
| 198 | + except Exception: |
| 199 | + pass |
| 200 | + else: |
| 201 | + continue |
97 | 202 |
|
98 | | - try: |
99 | | - json_str = json.dumps( |
100 | | - obj_copy, |
101 | | - cls=RocketPyEncoder, |
102 | | - include_outputs=True, |
103 | | - include_function_data=True, |
104 | | - ) |
105 | | - encoded_result = json.loads(json_str) |
106 | | - |
107 | | - # Post-process to fix datetime fields that got converted to lists |
108 | | - return _fix_datetime_fields(encoded_result) |
109 | | - except Exception as e: |
110 | | - logger.warning(f"Failed to encode with RocketPyEncoder: {e}") |
111 | | - attributes = {} |
112 | | - for attr_name in dir(obj_copy): |
113 | | - if not attr_name.startswith('_'): |
114 | | - try: |
115 | | - attr_value = getattr(obj_copy, attr_name) |
116 | | - if not callable(attr_value): |
117 | | - attributes[attr_name] = str(attr_value) |
118 | | - except Exception: |
119 | | - continue |
120 | | - return attributes |
| 203 | + return rocketpy_encoder(attributes) |
121 | 204 |
|
122 | 205 |
|
123 | 206 | def _fix_datetime_fields(data): |
|
0 commit comments