File tree Expand file tree Collapse file tree 2 files changed +21
-0
lines changed
chebai/preprocessing/datasets Expand file tree Collapse file tree 2 files changed +21
-0
lines changed Original file line number Diff line number Diff line change @@ -401,6 +401,21 @@ def setup(self, **kwargs):
401401 if not ("keep_reader" in kwargs and kwargs ["keep_reader" ]):
402402 self .reader .on_finish ()
403403
404+ self ._add_num_of_labels_to_hparams ()
405+
406+ def _add_num_of_labels_to_hparams (self ):
407+ num_of_labels = len (
408+ torch .load (
409+ os .path .join (
410+ self .processed_dir , self .processed_file_names_dict ["data" ]
411+ ),
412+ weights_only = False ,
413+ )[0 ]["labels" ]
414+ )
415+
416+ print (f"Number of labels for loaded data: { num_of_labels } " )
417+ self .hparams .num_of_labels = num_of_labels
418+
404419 def setup_processed (self ):
405420 """
406421 Setup the processed data.
@@ -541,6 +556,8 @@ def setup(self, **kwargs):
541556 for s in self .subsets :
542557 s .setup (** kwargs )
543558
559+ self ._add_num_of_labels_to_hparams ()
560+
544561 def dataloader (self , kind : str , ** kwargs ) -> DataLoader :
545562 """
546563 Creates a DataLoader for a specific subset.
Original file line number Diff line number Diff line change @@ -129,6 +129,8 @@ def setup(self, **kwargs) -> None:
129129 ):
130130 self .setup_processed ()
131131
132+ self ._add_num_of_labels_to_hparams ()
133+
132134 def _load_data_from_file (self , input_file_path : str ) -> List [Dict ]:
133135 """Loads data from a CSV file.
134136
@@ -311,6 +313,8 @@ def setup(self, **kwargs) -> None:
311313 ):
312314 self .setup_processed ()
313315
316+ self ._add_num_of_labels_to_hparams ()
317+
314318 def _load_dict (self , input_file_path : str ) -> Generator [Dict , None , None ]:
315319 """Loads data from a CSV file as a generator.
316320
You can’t perform that action at this time.
0 commit comments