Skip to content

Commit 3118e36

Browse files
authored
feat: better variables filtering (#24)
1 parent 7f6e78b commit 3118e36

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

src/comment_components.py

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -158,15 +158,23 @@ class _Variables(CommentData):
158158

159159
VARIABLES_FILE = "KN2045_Bal_v4/ariadne/exported_variables_full.xlsx"
160160
NRMSE_NORMALIZATION_METHOD = "combined-min-max"
161-
NRMSE_THRESHOLD = 0.3
161+
NRMSE_THRESHOLD = 0.1
162+
NRMSE_MINIMUM_THRESHOLD = 1e-3
163+
MAX_PLOTS = 20
162164

163165
_variables_deviation_df = None
164166

165-
@staticmethod
166167
def get_deviation_df(
167-
df1: pd.DataFrame, df2: pd.DataFrame, nrmse_normalization_method: str
168+
self, df1: pd.DataFrame, df2: pd.DataFrame, nrmse_normalization_method: str
168169
) -> pd.DataFrame:
169170
"""Calculate deviation dataframe between two dataframes."""
171+
# Remove all variables smaller than minimum threshold
172+
minimum_mask = ((df1.abs() < self.NRMSE_MINIMUM_THRESHOLD).all(axis=1)) & (
173+
(df2.abs() < self.NRMSE_MINIMUM_THRESHOLD).all(axis=1)
174+
)
175+
df1 = df1.loc[~minimum_mask]
176+
df2 = df2.loc[~minimum_mask]
177+
170178
nrmse_series = df1.apply(
171179
lambda row: normalized_root_mean_square_error(
172180
row.values,
@@ -182,12 +190,12 @@ def get_deviation_df(
182190

183191
if df1.empty:
184192
return pd.DataFrame(columns=["NRMSE", "Pearson"])
185-
else:
186-
deviation_df = pd.DataFrame(
187-
{"NRMSE": nrmse_series, "Pearson": pearson_series}
188-
).sort_values(by="NRMSE", ascending=False)
189193

190-
return deviation_df
194+
deviation_df = pd.DataFrame(
195+
{"NRMSE": nrmse_series, "Pearson": pearson_series}
196+
).sort_values(by="NRMSE", ascending=False)
197+
198+
return deviation_df
191199

192200
@property
193201
def variables_deviation_df(self) -> pd.DataFrame:
@@ -216,11 +224,12 @@ def variables_deviation_df(self) -> pd.DataFrame:
216224
return self._variables_deviation_df
217225

218226
def variables_plot_strings(self) -> list:
219-
"""Return list of variable plot strings."""
227+
"""Return list of variable plot strings. Maximum defined by MAX_PLOTS."""
220228
plots = (
221229
self.variables_deviation_df.index.to_series()
222230
.apply(lambda x: re.sub(r"[ |/]", "-", x))
223231
.apply(lambda x: "ariadne_comparison/" + x + ".png")
232+
.iloc[: self.MAX_PLOTS]
224233
.to_list()
225234
)
226235
return plots
@@ -272,7 +281,16 @@ def changed_variables_plots(self) -> str:
272281
rows,
273282
columns=pd.Index(["Main branch", "Feature branch"]),
274283
)
275-
return df.to_html(escape=False, index=False) + "\n"
284+
285+
if len(df) == self.MAX_PLOTS:
286+
annotation = (
287+
f":warning: Note: Only the first {self.MAX_PLOTS} variables are shown, "
288+
"but more are above the threshold. Find all of them in the artifacts."
289+
)
290+
else:
291+
annotation = ""
292+
293+
return df.to_html(escape=False, index=False) + "\n" + annotation + "\n"
276294

277295
@property
278296
def body(self) -> str:

0 commit comments

Comments
 (0)