Skip to content

Commit 3bb1f82

Browse files
committed
2 parents e57f87b + a440203 commit 3bb1f82

File tree

7 files changed

+333
-32
lines changed

7 files changed

+333
-32
lines changed

massbalancemachine/data_processing/Dataset.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class Dataset:
5959
region_id (str): The region ID, for saving the files accordingly and eventually downloading them if needed
6060
data_dir (str): Path to the directory containing the raw data, and save intermediate results
6161
RGIIds (pd.Series): Series of RGI IDs from the data
62+
output_format (str): csv or parquet
6263
months_tail_pad (list of str): Months to pad the start of the hydrological year
6364
months_head_pad (list of str): Months to pad the end of the hydrological year
6465
"""
@@ -70,6 +71,7 @@ def __init__(
7071
region_name: str,
7172
region_id: int,
7273
data_path: str,
74+
output_format: str = "csv",
7375
months_tail_pad=None, #: List[str] = ['aug_', 'sep_'], # before 'oct'
7476
months_head_pad=None, #: List[str] = ['oct_'], # after 'sep'
7577
):
@@ -81,7 +83,8 @@ def __init__(
8183
self.RGIIds = self.data["RGIId"]
8284
if not os.path.isdir(self.data_dir):
8385
os.makedirs(self.data_dir, exist_ok=True)
84-
86+
assert output_format in ["csv", "parquet"], "format must be csv or parquet"
87+
self.output_format = output_format
8588
# Padding to allow for flexible month ranges (customize freely)
8689
assert (months_head_pad is None) == (
8790
months_tail_pad is None
@@ -101,7 +104,9 @@ def get_topo_features(self, vois: list[str], custom_working_dir: str = "") -> No
101104
vois (list[str]): A string containing the topographical variables of interest
102105
custom_working_dir (str, optional): The path to the custom working directory for OGGM data. Default to ''
103106
"""
104-
output_fname = self._get_output_filename("topographical_features")
107+
output_fname = self._get_output_filename(
108+
"topographical_features", self.output_format
109+
)
105110
self.data = get_topographical_features(
106111
self.data, output_fname, vois, self.RGIIds, custom_working_dir, self.cfg
107112
)
@@ -124,7 +129,7 @@ def get_climate_features(
124129
change_units (bool, optional): A boolean indicating whether to change the units of the climate data. Default to False.
125130
smoothing_vois (dict, optional): A dictionary containing the variables of interest for smoothing climate artifacts. Default to None.
126131
"""
127-
output_fname = self._get_output_filename("climate_features")
132+
output_fname = self._get_output_filename("climate_features", self.output_format)
128133

129134
smoothing_vois = smoothing_vois or {} # Safely default to empty dict
130135
vois_climate = smoothing_vois.get("vois_climate")
@@ -207,9 +212,14 @@ def convert_to_monthly(
207212
"""
208213
if meta_data_columns is None:
209214
meta_data_columns = self.cfg.metaData
210-
output_fname = self._get_output_filename("monthly_dataset")
215+
output_fname = self._get_output_filename("monthly_dataset", self.output_format)
211216
self.data = transform_to_monthly(
212-
self.data, meta_data_columns, vois_climate, vois_topographical, output_fname
217+
self.data,
218+
meta_data_columns,
219+
vois_climate,
220+
vois_topographical,
221+
output_fname,
222+
self.output_format,
213223
)
214224

215225
def get_glacier_mask(
@@ -254,17 +264,19 @@ def create_glacier_grid_RGI(self, custom_working_dir: str = "") -> pd.DataFrame:
254264
df_grid = create_glacier_grid_RGI(ds, years, glacier_indices, gdir, rgi_gl)
255265
return df_grid
256266

257-
def _get_output_filename(self, feature_type: str) -> str:
267+
def _get_output_filename(self, feature_type: str, output_format: str) -> str:
258268
"""
259269
Generates the output filename for a given feature type.
260270
261271
Args:
262272
feature_type (str): The type of feature (e.g., "topographical_features", "climate_features", "monthly")
263-
273+
output_format (str): The output file format, either "csv" or "parquet"
264274
Returns:
265275
str: The full path to the output file
266276
"""
267-
return os.path.join(self.data_dir, f"{self.region}_{feature_type}.csv")
277+
return os.path.join(
278+
self.data_dir, f"{self.region}_{feature_type}.{output_format}"
279+
)
268280

269281
def _copy_padded_month_columns(
270282
self, df: pd.DataFrame, prefixes=("pcsr",), overwrite: bool = False
@@ -380,7 +392,6 @@ def __init__(
380392
self.metadata = metadata
381393
self.metadataColumns = metadataColumns or self.cfg.metaData
382394
self.targets = targets
383-
384395
assert len(self.features) > 0, "The features variable is empty."
385396

386397
_, self.month_pos = _rebuild_month_index(months_head_pad, months_tail_pad)
@@ -391,10 +402,8 @@ def __init__(
391402
for i in range(len(self.metadata))
392403
]
393404
)
394-
self.uniqueID = np.unique(self.ID)
395-
self.maxConcatNb = max(
396-
[len(np.argwhere(self.ID == id)[:, 0]) for id in self.uniqueID]
397-
)
405+
self.uniqueID, counts = np.unique(self.ID, return_counts=True)
406+
self.maxConcatNb = counts.max()
398407
self.nbFeatures = self.features.shape[1]
399408
self.nbMetadata = self.metadata.shape[1]
400409
self.norm = Normalizer({k: cfg.bnds[k] for k in cfg.featureColumns})
@@ -416,14 +425,19 @@ def mapSplitsToDataset(
416425
corresponding indices the cross validation should use according to
417426
the input splits variable.
418427
"""
428+
# Precompute the mapping of unique IDs to indices
429+
uniqueID_to_indices = {
430+
uid: np.where(self.uniqueID == uid)[0] for uid in self.uniqueID
431+
}
419432
ret = []
420433
for split in splits:
421434
t = []
422435
for e in split:
423-
uniqueSelectedId = np.unique(self.ID[e])
424-
ind = np.argwhere(self.uniqueID[None, :] == uniqueSelectedId[:, None])[
425-
:, 1
426-
]
436+
uniqueSelectedId = np.unique(self.ID[e]) # Get the unique selected IDs
437+
# Use the precomputed mapping for fast lookups
438+
ind = np.concatenate(
439+
[uniqueID_to_indices[uid] for uid in uniqueSelectedId]
440+
)
427441
assert all(uniqueSelectedId == self.uniqueID[ind])
428442
t.append(ind)
429443
ret.append(tuple(t))
@@ -470,6 +484,9 @@ def indexToId(self, index):
470484
return self.uniqueID[index]
471485

472486
def indexToMetadata(self, index):
487+
"""
488+
Returns the metadata for a given index.
489+
"""
473490
ind = self._getInd(index)
474491
return self.metadata[ind][:, :]
475492

massbalancemachine/data_processing/climate_data_download.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99

1010
def path_climate_data(region):
11+
"""Return path of data for a given region (string or integer)."""
1112
if not isinstance(region, str):
1213
region = f"{region:02d}"
1314
return f".data/{region}/"

massbalancemachine/data_processing/transform_to_monthly.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def transform_to_monthly(
1818
vois_climate: "list[str]",
1919
vois_topographical: "list[str]",
2020
output_fname: str,
21+
output_format: str,
2122
) -> pd.DataFrame:
2223
"""
2324
Converts the DataFrame to a monthly format based on climate-related columns.
@@ -52,8 +53,12 @@ def transform_to_monthly(
5253
# Create the final dataframe with the new exploded climate data
5354
result_df = _create_result_dataframe(df_exploded, column_names, vois_climate)
5455

55-
result_df.to_csv(output_fname, index=False)
56-
56+
if output_format == "csv":
57+
result_df.to_csv(output_fname, index=False)
58+
elif output_format == "parquet":
59+
result_df.to_parquet(output_fname, index=False)
60+
else:
61+
print("output format must be csv or parquet")
5762
return result_df
5863

5964

massbalancemachine/dataloader/DataLoader.py

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,18 @@ def __init__(
6565
self.meta_data_columns = meta_data_columns or cfg.metaData
6666

6767
def set_train_test_split(
68-
self, *, test_size: float = None, type_fold: str = "group-meas-id"
68+
self,
69+
*,
70+
test_size: float = None,
71+
type_fold: str = "group-meas-id",
72+
random_state: bool = False,
6973
) -> Tuple[Iterator[Any], Iterator[Any]]:
7074
"""
7175
Split the dataset into training and testing sets.
7276
7377
Args:
7478
test_size (float): Proportion of the dataset to include in the test split.
75-
type_fold (str): Type of splitting between train and test sets. Options are 'group-rgi', or 'group-meas-id'.
79+
type_fold (str): Type of splitting between train and test sets. Options are 'group-rgi', 'group-c_region', or 'group-meas-id'.
7680
7781
Returns:
7882
Tuple[Iterator[Any], Iterator[Any]]: Iterators for training and testing indices.
@@ -89,15 +93,25 @@ def set_train_test_split(
8993
# I.e, one year of a stake is not split amongst test and train set
9094

9195
# From the data get the features, targets, and glacier IDS
92-
X, y, glacier_ids, stake_meas_id = self._prepare_data_for_cv(
96+
X, y, glacier_ids, stake_meas_id, regions = self._prepare_data_for_cv(
9397
self.data, self.meta_data_columns
9498
)
95-
gss = GroupShuffleSplit(
96-
n_splits=1, test_size=test_size, random_state=self.random_seed
97-
)
98-
groups = {"group-meas-id": stake_meas_id, "group-rgi": glacier_ids}.get(
99-
type_fold
100-
)
99+
if random_state == False:
100+
gss = GroupShuffleSplit(
101+
n_splits=1,
102+
test_size=test_size,
103+
random_state=self.random_seed, # commenting this improve randomness
104+
)
105+
elif random_state == True:
106+
gss = GroupShuffleSplit(
107+
n_splits=1,
108+
test_size=test_size,
109+
)
110+
groups = {
111+
"group-meas-id": stake_meas_id,
112+
"group-rgi": glacier_ids,
113+
"group-c_region": regions,
114+
}.get(type_fold)
101115
train_indices, test_indices = next(gss.split(X, y, groups))
102116

103117
# Check that the intersection train and test ids is empty
@@ -108,9 +122,23 @@ def set_train_test_split(
108122
# Make it iterators and set as an attribute of the class
109123
self.train_indices = train_indices
110124
self.test_indices = test_indices
111-
112125
return iter(self.train_indices), iter(self.test_indices)
113126

127+
def assign_train_test_indices(self, train_indices, test_indices, test_size):
128+
"""
129+
Assign `train_indices`, `test_indices` as well as `test_size` attributes of the object.
130+
131+
Note:
132+
This can be useful when you divide the train and test ensembles based on subregion, since this requires making the sampling N times and then choosing the
133+
train-test division closest to, for example, the 70-30 repartition. At each iteration the Dataloader object is redefined as well as
134+
self.train_indices and self.test_indices, meaning that the information in the Dataloader object is that of the last iteration
135+
and not those of the train-test division chosen after comparing to the 70-30 repartition.
136+
This function aims to correct this by reassigning the indices of the chosen sampling.
137+
"""
138+
self.train_indices = train_indices
139+
self.test_indices = test_indices
140+
self.test_size = test_size
141+
114142
def set_custom_train_test_indices(
115143
self, train_indices: np.array, test_indices: np.array
116144
):
@@ -157,13 +185,13 @@ def get_cv_split(
157185
train_data = self._get_train_data()
158186

159187
# From the training data get the features, targets, and glacier IDS
160-
X, y, glacier_ids, stake_meas_id = self._prepare_data_for_cv(
188+
X, y, glacier_ids, stake_meas_id, regions = self._prepare_data_for_cv(
161189
train_data, self.meta_data_columns
162190
)
163191

164192
# Create the cross validation splits
165193
splits = self._create_group_kfold_splits(
166-
X, y, glacier_ids, stake_meas_id, type_fold
194+
X, y, glacier_ids, stake_meas_id, regions, type_fold
167195
)
168196
self.cv_split = splits
169197

@@ -239,14 +267,20 @@ def _prepare_data_for_cv(
239267
y = train_data["POINT_BALANCE"]
240268
glacier_ids = train_data["RGIId"].values
241269
stake_meas_id = train_data["ID"].values # unique value per stake measurement
242-
return X, y, glacier_ids, stake_meas_id
270+
regions = (
271+
train_data["C_REGION"].values
272+
if "C_REGION" in train_data.columns
273+
else np.array([])
274+
)
275+
return X, y, glacier_ids, stake_meas_id, regions
243276

244277
def _create_group_kfold_splits(
245278
self,
246279
X: pd.DataFrame,
247280
y: pd.Series,
248281
glacier_ids: np.ndarray,
249282
stake_meas_id: np.ndarray,
283+
regions: np.ndarray,
250284
type_fold: str,
251285
) -> List[Tuple[np.ndarray, np.ndarray]]:
252286
"""
@@ -268,6 +302,7 @@ def _create_group_kfold_splits(
268302
fold_types = {
269303
"group-rgi": (GroupKFold, glacier_ids),
270304
"group-meas-id": (GroupKFold, stake_meas_id),
305+
"group-c_region": (GroupKFold, regions),
271306
}
272307

273308
FoldClass, groups = fold_types.get(type_fold, (KFold, None))
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
from plots.perf_plots import predVSTruth, predVSTruthPerGlacier
22
from plots.style import use_mbm_style, COLOR_ANNUAL, COLOR_WINTER
3+
from plots.input_plot import histogram_mb, scatterplot_mb
4+
from plots.train_plots import plot_training_history

0 commit comments

Comments
 (0)