diff --git a/DocumentUnderstanding/GeoLayoutLM/evaluate.py b/DocumentUnderstanding/GeoLayoutLM/evaluate.py index ecd5355..0a768a8 100644 --- a/DocumentUnderstanding/GeoLayoutLM/evaluate.py +++ b/DocumentUnderstanding/GeoLayoutLM/evaluate.py @@ -56,11 +56,10 @@ def main(): ) dataset = VIEDataset( - cfg.dataset, + cfg.dataset_root_path, cfg.task, backbone_type, cfg.model.head, - cfg.dataset_root_path, net.tokenizer, mode=mode, ) diff --git a/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_data_module.py b/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_data_module.py index 0ad1fa1..91e1f64 100644 --- a/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_data_module.py +++ b/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_data_module.py @@ -47,11 +47,10 @@ def _get_train_loader(self): start_time = time.time() dataset = VIEDataset( - self.cfg.dataset, + self.cfg.dataset_root_path, self.cfg.task, self.backbone_type, self.cfg.model.head, - self.cfg.dataset_root_path, self.tokenizer, self.cfg.train.max_seq_length, self.cfg.train.max_block_num, @@ -74,11 +73,10 @@ def _get_train_loader(self): def _get_val_test_loaders(self, mode): dataset = VIEDataset( - self.cfg.dataset, + self.cfg.dataset_root_path, self.cfg.task, self.backbone_type, self.cfg.model.head, - self.cfg.dataset_root_path, self.tokenizer, self.cfg.train.max_seq_length, self.cfg.train.max_block_num, diff --git a/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_dataset.py b/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_dataset.py index 0555c23..07221f9 100644 --- a/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_dataset.py +++ b/DocumentUnderstanding/GeoLayoutLM/lightning_modules/data_modules/vie_dataset.py @@ -13,11 +13,10 @@ class VIEDataset(Dataset): def __init__( self, - dataset, + dataset_root_path, task, backbone_type, model_head, - dataset_root_path, tokenizer, max_seq_length=512, max_block_num=256, @@ -25,12 +24,11 @@ def __init__( img_w=768, mode=None, ): - self.dataset = dataset + self.dataset_root_path = dataset_root_path self.task = task self.backbone_type = backbone_type self.model_head = model_head - self.dataset_root_path = dataset_root_path self.tokenizer = tokenizer self.max_seq_length = max_seq_length self.max_block_num = max_block_num