|
5 | 5 | import numpy as np |
6 | 6 | from .data_loader import KTDataset |
7 | 7 | from .dkt_forget_dataloader import DktForgetDataset |
8 | | -from .cdkt_dataloader import CDKTDataset |
| 8 | +from .atdkt_dataloader import ATDKTDataset |
9 | 9 | from .lpkt_dataloader import LPKTDataset |
10 | 10 | from .lpkt_utils import generate_time2idx |
11 | 11 | from .que_data_loader import KTQueDataset |
@@ -39,12 +39,12 @@ def init_test_datasets(data_config, model_name, batch_size): |
39 | 39 | concept_num=data_config['num_c'], max_concepts=data_config['max_concepts']) |
40 | 40 | test_question_dataset = None |
41 | 41 | test_question_window_dataset= None |
42 | | - elif model_name in ["cdkt"]: |
43 | | - test_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) |
44 | | - test_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) |
| 42 | + elif model_name in ["atdkt"]: |
| 43 | + test_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) |
| 44 | + test_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) |
45 | 45 | if "test_question_file" in data_config: |
46 | | - test_question_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True) |
47 | | - test_question_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True) |
| 46 | + test_question_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True) |
| 47 | + test_question_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True) |
48 | 48 | else: |
49 | 49 | test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1}) |
50 | 50 | test_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1}) |
@@ -96,9 +96,9 @@ def init_dataset4train(dataset_name, model_name, data_config, i, batch_size): |
96 | 96 | curtrain = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]), |
97 | 97 | input_type=data_config["input_type"], folds=all_folds - {i}, |
98 | 98 | concept_num=data_config['num_c'], max_concepts=data_config['max_concepts']) |
99 | | - elif model_name in ["cdkt"]: |
100 | | - curvalid = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i}) |
101 | | - curtrain = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i}) |
| 99 | + elif model_name in ["atdkt"]: |
| 100 | + curvalid = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i}) |
| 101 | + curtrain = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i}) |
102 | 102 | else: |
103 | 103 | curvalid = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i}) |
104 | 104 | curtrain = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i}) |
|
0 commit comments