Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions Scripts/helmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def main(args):
# Setup validation if validation folder exists
validation = Validation()
validation_path = Path(forecast_zonedata_path) / 'validations'
validation_path = Path(__file__).parent / 'validations' # TODO: Delete this line
if validation_path.exists():
# Load event listeners from 'forecast/validation' folder
event_handler.load_listeners(validation_path)
Expand Down Expand Up @@ -170,10 +169,13 @@ def main(args):
log.info("Not able to remove file {}.".format(f))
log.info("Removed strategy files in {}".format(dbase_path))
if validation is not None:
validation.to_html(
Path(results_path) / args.scenario_name / 'validation.html')
validation.save_to_file(
Path(results_path) / args.scenario_name / 'validation.pklz')
try:
validation.to_html(
Path(results_path) / args.scenario_name / 'validation.html')
validation.save_to_file(
Path(results_path) / args.scenario_name / 'validation.pklz')
except Exception as e:
log.error("Error saving validation data: {}".format(e))
log.info("Simulation ended.", extra=log_extra)


Expand Down
1 change: 1 addition & 0 deletions Scripts/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
openpyxl==2.6.4;python_version<"3.8"
openpyxl==3.1.4;python_version>="3.8"
plotly
174 changes: 66 additions & 108 deletions Scripts/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def __getstate__(self):
state['_aggregations'] = []
state['_error_terms'] = {}
state['_visualizations'] = {}
# Remove columns corresponding to error terms from the items DataFrame
state['_items'] = state['_items'].drop(columns=self._error_terms.keys(), errors='ignore')
return state

def __setstate__(self, state):
Expand Down Expand Up @@ -260,65 +262,7 @@ def to_html(self, file_path: Path = None) -> str:
Returns:
str: The generated HTML content as a string.
"""
html_content = """
<html>
<head>
<title>Validation Results</title>
<style>
body {
font-family: Arial, sans-serif;
}
.collapsible {
background-color: #4CAF50;
color: white;
cursor: pointer;
padding: 10px;
width: 100%;
border: none;
text-align: left;
outline: none;
font-size: 15px;
margin-bottom: 5px;
}
.active, .collapsible:hover {
background-color: #45a049;
}
.content {
padding: 0 18px;
display: none;
overflow: hidden;
background-color: #f9f9f9;
margin-bottom: 10px;
}
table {
border-collapse: collapse;
width: 100%;
margin-bottom: 20px;
}
th, td {
text-align: left;
padding: 8px;
border: 1px solid #ddd;
}
th {
background-color: #f2f2f2;
cursor: pointer;
}
th.sortable:hover {
background-color: #ddd;
}
h2 {
color: #333;
}
h3 {
color: #555;
}
</style>
</head>
<body>
<h1>Validation Results</h1>
"""

html_content = ""
for group_name, group in self.groups.items():
items = group.get_items()
html_content += f"""
Expand Down Expand Up @@ -360,46 +304,11 @@ def to_html(self, file_path: Path = None) -> str:
html_content += f"<tr><td>{metric_name}</td><td>{value}</td></tr>"
html_content += "</table></div>"

html_content += """
<script>
var coll = document.getElementsByClassName("collapsible");
for (var i = 0; i < coll.length; i++) {
coll[i].addEventListener("click", function() {
this.classList.toggle("active");
var content = this.nextElementSibling;
if (content.style.display === "block") {
content.style.display = "none";
} else {
content.style.display = "block";
}
});
}

function sortTable(header, colIndex) {
var table = header.closest('table');
var rows = Array.from(table.querySelectorAll('tbody > tr'));
var isAsc = header.classList.toggle('asc');
rows.sort((rowA, rowB) => {
var cellA = rowA.children[colIndex].textContent.trim();
var cellB = rowB.children[colIndex].textContent.trim();

var numA = parseFloat(cellA);
var numB = parseFloat(cellB);

if (!isNaN(numA) && !isNaN(numB)) {
// Both values are numbers
return isAsc ? numA - numB : numB - numA;
} else {
// At least one value is not a number, compare as strings
return isAsc ? cellA.localeCompare(cellB, undefined, {numeric: true}) : cellB.localeCompare(cellA, undefined, {numeric: true});
}
});
rows.forEach(row => table.querySelector('tbody').appendChild(row));
}
</script>
</body>
</html>
"""
template_path = Path(__file__).parent / "validation_template.html"
with open(template_path, 'r', encoding='utf8') as template_file:
template_content = template_file.read()
html_content = template_content.replace("{{content}}", html_content)

if file_path:
with open(file_path, 'w', encoding='utf8') as file:
file.write(html_content)
Expand Down Expand Up @@ -446,17 +355,26 @@ def weighted_mean(source: str, weight: str = 'weight'):
# Visualizations
def scatter_plot(x: str = 'expected',
y: str = 'prediction',
color:str=None,
colormap: str='viridis',
show_diagonal: bool=True) -> Callable[[pd.DataFrame], str]:
color: str = None,
colormap: str = 'viridis',
show_diagonal: bool = True,
discrete_colors: bool = False) -> Callable[[pd.DataFrame], str]:
def _scatter_plot(df: pd.DataFrame) -> str:
try:
import plotly.express as px
import plotly.graph_objects as go
except ImportError:
return "Plotly is not installed. Please install it using 'pip install plotly'"

fig = px.scatter(df, x=x, y=y, color=color, color_continuous_scale=colormap)
if color is not None and discrete_colors:
# Use color_discrete_map='identity' for categorical colors
fig = px.scatter(df, x=x, y=y, color=color,
color_discrete_sequence=px.colors.qualitative.Plotly)
else:
# Original behavior for continuous colors
fig = px.scatter(df, x=x, y=y, color=color,
color_continuous_scale=colormap)

if show_diagonal:
min_val = min(df[x].min(), df[y].min())
max_val = max(df[x].max(), df[y].max())
Expand All @@ -469,7 +387,24 @@ def _scatter_plot(df: pd.DataFrame) -> str:
return fig.to_html()
return _scatter_plot

def bar_plot(x: str = 'id', y: Union[str, List[str]] = None) -> Callable[[pd.DataFrame], str]:
def bar_plot(x: str = 'id',
y: Union[str, List[str]] = None,
stacked: bool = False,
color_palette: str = 'Plotly') -> Callable[[pd.DataFrame], str]:
"""
Creates a bar plot visualization function.

Args:
x (str): The column name to use for the x-axis. Defaults to 'id'.
y (Union[str, List[str]]): The column name(s) to plot on the y-axis.
If None, uses ['prediction', 'expected']. Defaults to None.
stacked (bool): Whether to stack the bars or group them. Defaults to False.
color_palette (str): The color palette to use. Options: 'Plotly', 'G10',
'Dark24', 'Light24', 'Pastel', etc. Defaults to 'Plotly'.

Returns:
Callable[[pd.DataFrame], str]: A function that takes a DataFrame and returns HTML.
"""
if y is None:
y = ['prediction', 'expected']
if isinstance(y, str):
Expand All @@ -478,16 +413,36 @@ def bar_plot(x: str = 'id', y: Union[str, List[str]] = None) -> Callable[[pd.Dat
def _bar_plot(df: pd.DataFrame) -> str:
try:
import plotly.graph_objects as go
import plotly.express as px
except ImportError:
return "Plotly is not installed. Please install it using 'pip install plotly'"

# Get the appropriate color palette
try:
color_seq = getattr(px.colors.qualitative, color_palette)
except AttributeError:
# Fall back to default Plotly colors if the palette doesn't exist
color_seq = px.colors.qualitative.Plotly

fig = go.Figure()
colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'cyan', 'magenta']

for i, y_col in enumerate(y):
fig.add_trace(go.Bar(name=y_col, x=df[x], y=df[y_col], marker_color=colors[i % len(colors)]))
fig.add_trace(go.Bar(
name=y_col,
x=df[x],
y=df[y_col],
marker_color=color_seq[i % len(color_seq)]
))

# Set barmode based on stacked parameter
barmode = 'stack' if stacked else 'group'
fig.update_layout(
barmode=barmode,
xaxis_title=x,
yaxis_title='Value',
legend_title_text='Metrics'
)

fig.update_layout(barmode='group', xaxis_title=x, yaxis_title='Value')
return fig.to_html()

return _bar_plot
Expand Down Expand Up @@ -535,4 +490,7 @@ def _bar_plot(df: pd.DataFrame) -> str:
# test_valid.save_to_file('test_validation.pklz')
# # Load the validation object from a file
# loaded_valid = Validation.load_from_file('test_validation.pklz')
# loaded_valid.to_html('test_validation_loaded.html')
# loaded_valid.to_html('test_validation_loaded.html')
# # Open the generated HTML file in the default web browser
# import webbrowser
# webbrowser.open('test_validation.html')
94 changes: 94 additions & 0 deletions Scripts/utils/validation_template.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
<html>
<head>
<title>Validation Results</title>
<style>
body {
font-family: Arial, sans-serif;
}
.collapsible {
background-color: #4CAF50;
color: white;
cursor: pointer;
padding: 10px;
width: 100%;
border: none;
text-align: left;
outline: none;
font-size: 15px;
margin-bottom: 5px;
}
.active, .collapsible:hover {
background-color: #45a049;
}
.content {
padding: 0 18px;
display: none;
overflow: hidden;
background-color: #f9f9f9;
margin-bottom: 10px;
}
table {
border-collapse: collapse;
width: 100%;
margin-bottom: 20px;
}
th, td {
text-align: left;
padding: 8px;
border: 1px solid #ddd;
}
th {
background-color: #f2f2f2;
cursor: pointer;
}
th.sortable:hover {
background-color: #ddd;
}
h2 {
color: #333;
}
h3 {
color: #555;
}
</style>
</head>
<body>
<script>
var coll = document.getElementsByClassName("collapsible");
for (var i = 0; i < coll.length; i++) {
coll[i].addEventListener("click", function() {
this.classList.toggle("active");
var content = this.nextElementSibling;
if (content.style.display === "block") {
content.style.display = "none";
} else {
content.style.display = "block";
}
});
}

function sortTable(header, colIndex) {
var table = header.closest('table');
var rows = Array.from(table.querySelectorAll('tbody > tr'));
var isAsc = header.classList.toggle('asc');
rows.sort((rowA, rowB) => {
var cellA = rowA.children[colIndex].textContent.trim();
var cellB = rowB.children[colIndex].textContent.trim();

var numA = parseFloat(cellA);
var numB = parseFloat(cellB);

if (!isNaN(numA) && !isNaN(numB)) {
// Both values are numbers
return isAsc ? numA - numB : numB - numA;
} else {
// At least one value is not a number, compare as strings
return isAsc ? cellA.localeCompare(cellB, undefined, {numeric: true}) : cellB.localeCompare(cellA, undefined, {numeric: true});
}
});
rows.forEach(row => table.querySelector('tbody').appendChild(row));
}
</script>
{{content}}
</body>
</html>
8 changes: 4 additions & 4 deletions Scripts/validations/car_volumes_and_speeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def is_disabled(self) -> bool:

def create_vol_group(self, name: str) -> ValidationGroup:
vol_group = self.validation.create_group(name)
vol_group.add_visualization('Volumes vs survey', scatter_plot(x='expected', y='prediction'))
vol_group.add_visualization('Volumes vs Helmet4', scatter_plot(x='helmet4', y='prediction'))
vol_group.add_visualization('Volumes vs survey', scatter_plot(x='expected', y='prediction', color='kuntaryhma', discrete_colors=True))
vol_group.add_visualization('Volumes vs Helmet4', scatter_plot(x='helmet4', y='prediction', color='kuntaryhma', discrete_colors=True))
vol_group.add_aggregation('mean absolute error', mae, group_by='kuntaryhma')
vol_group.add_aggregation('mean relative error', mean('relative_error'), group_by='kuntaryhma')
return vol_group

def create_speed_group(self, name: str) -> ValidationGroup:
speed_group = self.validation.create_group(name)
speed_group.add_visualization('Speed vs survey', scatter_plot(x='expected', y='prediction'))
speed_group.add_visualization('Speed vs survey', scatter_plot(x='expected', y='prediction', color='kuntaryhma', discrete_colors=True))
speed_group.add_aggregation('mean relative error', mean('relative_error'), group_by='kuntaryhma')
return speed_group

Expand Down Expand Up @@ -128,7 +128,7 @@ def _add_to_validation_group(group: ValidationGroup,
kuntaryhma=row['kuntaryhma'])

def _sum_volumes(link, suffix: str) -> float:
attributes = ['@car_work', '@car_leisure', '@truck', '@trailer_truck', '@van', '@bus']
attributes = ['@car_work', '@car_leisure', '@van']
return sum(link[f'{attr}_{suffix}'] for attr in attributes)

def _get_link(network: 'Network', id: Tuple[int, int]) -> 'Link':
Expand Down