Skip to content

Commit 5f63807

Browse files
Added wrapper capabilities
1 parent 97541a4 commit 5f63807

28 files changed

+369
-129
lines changed

acsl_pychrono/config/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# acsl_pychrono/config/__init__.py
2+
3+
from .config import MissionConfig, VehicleConfig, EnvironmentConfig, WrapperParams
4+
from .config import SimulationConfig

acsl_pychrono/config/config.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1+
import warnings
12
from dataclasses import dataclass
23

34
@dataclass
4-
class SimulationConfig:
5+
class MissionConfig:
56
# Total simulation duration in seconds
6-
simulation_duration_seconds: float = 21.5
7+
simulation_duration_seconds: float = 21.5 # 21.5
78
# Run the simulator in Wrapper mode (more simulations automatically run sequentially)
89
wrapper_flag: bool = False
910
# If True, perform real-time rendering of the simulation with Irrlicht
@@ -22,7 +23,7 @@ class SimulationConfig:
2223
# Controller types:
2324
# "PID",
2425
# "MRAC",
25-
controller_type: str = "PID"
26+
controller_type: str = "MRAC"
2627

2728
# User-defined trajectory types:
2829
# "circular_trajectory",
@@ -44,6 +45,17 @@ class SimulationConfig:
4445
# "many_steel_balls_in_random_position"
4546
payload_type: str = "two_steel_balls"
4647

48+
# Unique wrapper batch folder passed to the function used for running many parallel wrapper simulations
49+
wrapper_batch_dir: str = "" # LEAVE BLANK!!!
50+
51+
# Number of parallel simulations (one per CPU) to be run in "wrapper" mode
52+
wrapper_max_parallel: int = 20
53+
54+
def __post_init__(self):
55+
if self.wrapper_flag and self.visualization_flag:
56+
warnings.warn("Visualization is disabled because wrapper mode is enabled.")
57+
self.visualization_flag = False
58+
4759
@dataclass
4860
class VehicleConfig:
4961
# Path relative to 'current_working_directory/assets/vehicles'
@@ -55,4 +67,15 @@ class EnvironmentConfig:
5567
include: bool = False
5668
# Path relative to 'current_working_directory/assets/environments'
5769
model_relative_path: str = "environmentA/environmentA.py"
70+
71+
@dataclass
72+
class WrapperParams: # Add here the params to be sweeped by the wrapper with their default values
73+
my_ball_density: float = 7850
74+
75+
@dataclass
76+
class SimulationConfig:
77+
mission_config: MissionConfig = MissionConfig()
78+
vehicle_config: VehicleConfig = VehicleConfig()
79+
environment_config: EnvironmentConfig = EnvironmentConfig()
80+
wrapper_params: WrapperParams = WrapperParams()
5881

acsl_pychrono/control/MRAC/mrac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33
from acsl_pychrono.control.outerloop_safetymech import OuterLoopSafetyMechanism
44
from acsl_pychrono.control.MRAC.mrac_gains import MRACGains
5-
from acsl_pychrono.ode_input import OdeInput
6-
from acsl_pychrono.flight_params import FlightParams
5+
from acsl_pychrono.simulation.ode_input import OdeInput
6+
from acsl_pychrono.simulation.flight_params import FlightParams
77
from acsl_pychrono.control.control import Control
88
from acsl_pychrono.control.base_mrac import BaseMRAC
99
from acsl_pychrono.control.MRAC.m_mrac import M_MRAC

acsl_pychrono/control/MRAC/mrac_gains.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
from numpy import linalg as LA
44
from scipy import linalg
5-
from acsl_pychrono.flight_params import FlightParams
5+
from acsl_pychrono.simulation.flight_params import FlightParams
66

77
class MRACGains:
88
def __init__(self, flight_params: FlightParams):

acsl_pychrono/control/PID/pid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33
from acsl_pychrono.control.outerloop_safetymech import OuterLoopSafetyMechanism
44
from acsl_pychrono.control.PID.pid_gains import PIDGains
5-
from acsl_pychrono.ode_input import OdeInput
6-
from acsl_pychrono.flight_params import FlightParams
5+
from acsl_pychrono.simulation.ode_input import OdeInput
6+
from acsl_pychrono.simulation.flight_params import FlightParams
77
from acsl_pychrono.control.control import Control
88

99
class PID(Control):

acsl_pychrono/control/PID/pid_gains.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from numpy import linalg as LA
44
import scipy
55
from scipy import linalg
6-
from acsl_pychrono.flight_params import FlightParams
6+
from acsl_pychrono.simulation.flight_params import FlightParams
77

88
class PIDGains:
99
def __init__(self, flight_params: FlightParams):

acsl_pychrono/control/PID/pid_logger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33
from acsl_pychrono.control.PID.pid_gains import PIDGains
44
from acsl_pychrono.control.PID.pid import PID
5-
from acsl_pychrono.ode_input import OdeInput
6-
from acsl_pychrono.flight_params import FlightParams
5+
from acsl_pychrono.simulation.ode_input import OdeInput
6+
from acsl_pychrono.simulation.flight_params import FlightParams
77

88
class PIDLogger:
99
def __init__(self, gains: PIDGains) -> None:

acsl_pychrono/control/base_mrac.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
import numpy as np
33
from numpy import linalg as LA
44

5-
from acsl_pychrono.ode_input import OdeInput
6-
from acsl_pychrono.flight_params import FlightParams
5+
from acsl_pychrono.simulation.ode_input import OdeInput
6+
from acsl_pychrono.simulation.flight_params import FlightParams
77
from acsl_pychrono.control.control import Control
88

99
class BaseMRAC():

acsl_pychrono/control/control.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from numpy.typing import NDArray
44
from abc import ABC, abstractmethod
55

6-
from acsl_pychrono.functions import rk4singlestep
7-
from acsl_pychrono.ode_input import OdeInput
8-
from acsl_pychrono.flight_params import FlightParams
6+
from acsl_pychrono.simulation.functions import rk4singlestep
7+
from acsl_pychrono.simulation.ode_input import OdeInput
8+
from acsl_pychrono.simulation.flight_params import FlightParams
99

1010
class Control:
1111
def __init__(self, odein: OdeInput) -> None:

acsl_pychrono/control/logging.py

Lines changed: 123 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,50 @@
11
import os
22
import datetime
3+
import subprocess
4+
import warnings
35
import numpy as np
46
from scipy.io import savemat
7+
from dataclasses import is_dataclass
8+
import acsl_pychrono.config.config as Cfg
59

610
class Logging:
711
@staticmethod
8-
def saveMatlabWorkspaceLog(log_dict, gains, controller_type):
9-
"""
10-
Save the log and gains data into a .mat file in a structured directory.
11-
12-
Args:
13-
log_dict (dict): Dictionary containing the log data.
14-
gains (object): Object containing gain parameters as attributes.
15-
controller_type (str): The name/type of the controller.
16-
"""
12+
def getOutputDir(sim_cfg: Cfg.SimulationConfig) -> str:
13+
controller_type = Cfg.MissionConfig.controller_type
14+
wrapper_flag = Cfg.MissionConfig.wrapper_flag
1715

1816
# Get current time
1917
now = datetime.datetime.now()
20-
timestamp = now.strftime("%Y%m%d_%H%M%S")
2118
year = now.strftime("%Y")
2219
month = now.strftime("%m")
2320
full_date = now.strftime("%Y%m%d")
2421

2522
# Construct the directory path
26-
dir_path = os.path.join("logs", year, month, full_date, controller_type, "workspaces")
23+
if wrapper_flag:
24+
dir_path = os.path.join(sim_cfg.mission_config.wrapper_batch_dir)
25+
else:
26+
dir_path = os.path.join("logs", year, month, full_date, controller_type, "workspaces")
2727
os.makedirs(dir_path, exist_ok=True) # Create all directories if not present
2828

29+
return dir_path
30+
31+
@staticmethod
32+
def generateUniqueFilename(base_name: str, extension: str, dir_path: str, use_suffix: bool) -> str:
33+
if not use_suffix:
34+
# No suffix, allow overwrite
35+
return os.path.join(dir_path, f"{base_name}.{extension}")
36+
37+
# Suffix to avoid overwrite
38+
run_id = 1
39+
while True:
40+
filename = f"{base_name}-{run_id}.{extension}"
41+
full_path = os.path.join(dir_path, filename)
42+
if not os.path.exists(full_path):
43+
return full_path
44+
run_id += 1
45+
46+
@staticmethod
47+
def extractGainsDict(gains) -> dict:
2948
# Create a dictionary from instance variables
3049
gains_dict = {
3150
key: value for key, value in gains.__dict__.items()
@@ -40,16 +59,102 @@ def saveMatlabWorkspaceLog(log_dict, gains, controller_type):
4059
for key in gains_dict_shortened:
4160
if isinstance(gains_dict_shortened[key], np.matrix):
4261
gains_dict_shortened[key] = np.array(gains_dict_shortened[key])
62+
63+
return gains_dict_shortened
64+
65+
@staticmethod
66+
def dataclassToDict(obj, truncate_keys: bool = True):
67+
"""
68+
Recursively convert a dataclass to a nested dictionary,
69+
handling nested dataclasses, NumPy arrays/matrices,
70+
truncating field names to 31 characters if specified,
71+
and filtering out unsupported types.
72+
"""
73+
if is_dataclass(obj):
74+
result = {}
75+
for key, value in obj.__dict__.items():
76+
# Filter to only include serializable fields
77+
if isinstance(value, (int, float, str, bool, np.ndarray, np.matrix)) or is_dataclass(value):
78+
converted_value = Logging.dataclassToDict(value, truncate_keys=truncate_keys)
79+
key = key[:31] if truncate_keys else key
80+
if isinstance(converted_value, np.matrix):
81+
converted_value = np.array(converted_value)
82+
result[key] = converted_value
83+
return result
84+
elif isinstance(obj, dict):
85+
return {k[:31] if truncate_keys else k: Logging.dataclassToDict(v, truncate_keys=truncate_keys) for k, v in obj.items()}
86+
elif isinstance(obj, (list, tuple)):
87+
return [Logging.dataclassToDict(v, truncate_keys=truncate_keys) for v in obj]
88+
elif isinstance(obj, np.matrix):
89+
return np.array(obj)
90+
else:
91+
return obj
92+
93+
@staticmethod
94+
def getGitRepoInfo() -> dict:
95+
def run_git_cmd(args):
96+
try:
97+
return subprocess.check_output(['git'] + args, cwd=repo_dir, stderr=subprocess.DEVNULL).decode().strip()
98+
except subprocess.CalledProcessError:
99+
return ""
100+
except FileNotFoundError:
101+
return ""
102+
103+
def get_github_url(repo_url: str, commit_hash: str) -> str:
104+
if repo_url.endswith(".git"):
105+
repo_url = repo_url[:-4]
106+
return f"{repo_url}/tree/{commit_hash}"
107+
108+
repo_dir = os.getcwd()
109+
110+
if not os.path.exists(os.path.join(repo_dir, '.git')):
111+
warnings.warn("⚠️ Git repository info cannot be tracked. No .git directory found.")
112+
return {
113+
"repo_path": "",
114+
"remote_url": "",
115+
"commit_hash": "",
116+
"commit_tag": "",
117+
"branch": "",
118+
"dirty": 0,
119+
"commit_url": ""
120+
}
121+
122+
remote_url = run_git_cmd(['remote', 'get-url', 'origin']) or ""
123+
commit_hash = run_git_cmd(['rev-parse', 'HEAD']) or ""
124+
125+
git_info = {
126+
"repo_path": os.path.abspath(repo_dir),
127+
"remote_url": remote_url,
128+
"commit_hash": commit_hash,
129+
"commit_tag": run_git_cmd(['describe', '--tags', '--exact-match']) or "",
130+
"branch": run_git_cmd(['rev-parse', '--abbrev-ref', 'HEAD']) or "",
131+
"dirty": 1 if run_git_cmd(['status', '--porcelain']) else 0,
132+
"commit_url": get_github_url(remote_url, commit_hash) or ""
133+
}
134+
135+
return git_info
136+
137+
@staticmethod
138+
def saveMatlabWorkspaceLog(log_dict, gains, sim_cfg: Cfg.SimulationConfig, git_info: dict | None = None):
139+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
140+
dir_path = Logging.getOutputDir(sim_cfg)
141+
os.makedirs(dir_path, exist_ok=True)
142+
143+
base_filename = f"workspace_log_{timestamp}"
144+
full_path_log = Logging.generateUniqueFilename(
145+
base_filename,
146+
"mat",
147+
dir_path,
148+
Cfg.MissionConfig.wrapper_flag
149+
)
43150

44-
# Nest gains inside the log structure
45151
mat_dict = {
46-
"log": log_dict, # Log data
47-
"gains": gains_dict_shortened # Nested gains struct
152+
"log": log_dict,
153+
"gains": Logging.extractGainsDict(gains),
154+
"sim_cfg": Logging.dataclassToDict(sim_cfg)
48155
}
49156

50-
# Construct the full file path for the log file
51-
filename_log = f"workspace_log_{timestamp}.mat"
52-
full_path_log = os.path.join(dir_path, filename_log)
157+
if git_info is not None:
158+
mat_dict["git_info"] = git_info
53159

54-
# Save the .mat file with the nested gains structure
55160
savemat(full_path_log, mat_dict)

0 commit comments

Comments
 (0)