Skip to content

Commit 4db3fbe

Browse files
pykt-teamLiu-lqq
andauthored
change atdkt and simplekt name (#90)
Co-authored-by: Liu-lqq <liuqiongqiong91@163.com>
1 parent 173f807 commit 4db3fbe

File tree

14 files changed

+60
-266
lines changed

14 files changed

+60
-266
lines changed
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
program: wandb_cdkt_train.py
1+
program: wandb_atdkt_train.py
22
method: bayes
33
metric:
44
goal: maximize
55
name: validauc
66
parameters:
77
model_name:
8-
values: ["cdkt"]
8+
values: ["atdkt"]
99
dataset_name:
1010
values: ["xes"]
1111
emb_type:
1212
values: ["qiddelxembhistranscembpredcurc"]
1313
save_dir:
14-
values: ["models/cdkt_tiaocan"]
14+
values: ["models/atdkt_tiaocan"]
1515
emb_size:
1616
values: [64, 256]
1717
num_attn_heads:
Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
program: ./wandb_bakt_train.py
1+
program: ./wandb_simplekt_train.py
22
method: bayes
33
metric:
44
goal: maximize
55
name: validauc
66
parameters:
77
model_name:
8-
values: ["bakt"]
8+
values: ["simplekt"]
99
dataset_name:
1010
values: ["xes"]
1111
emb_type:
1212
values: ["qid"]
1313
save_dir:
14-
values: ["models/bakt_tiaocan"]
14+
values: ["models/simplekt_tiaocan"]
1515
d_model:
1616
values: [64, 256]
1717
d_ff:
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
parser = argparse.ArgumentParser()
66

77
parser.add_argument("--dataset_name", type=str, default="algebra2005")
8-
parser.add_argument("--model_name", type=str, default="cdkt")
8+
parser.add_argument("--model_name", type=str, default="atdkt")
99
parser.add_argument("--emb_type", type=str, default="qid")
1010
parser.add_argument("--save_dir", type=str, default="saved_model")
1111
parser.add_argument("--seed", type=int, default=3407)

examples/wandb_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def main(params):
2525
trained_params = config["params"]
2626
model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
2727
seq_len = config["train_config"]["seq_len"]
28-
if model_name in ["saint", "sakt", "cdkt"]:
28+
if model_name in ["saint", "sakt", "atdkt"]:
2929
model_config["seq_len"] = seq_len
3030
data_config = config["data_config"]
3131

examples/wandb_predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def main(params):
2929
del model_config[remove_item]
3030
trained_params = config["params"]
3131
model_name, dataset_name, emb_type = trained_params["model_name"], trained_params["dataset_name"], trained_params["emb_type"]
32-
if model_name in ["saint", "sakt", "cdkt"]:
32+
if model_name in ["saint", "sakt", "atdkt"]:
3333
train_config = config["train_config"]
3434
seq_len = train_config["seq_len"]
3535
model_config["seq_len"] = seq_len
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
if __name__ == "__main__":
55
parser = argparse.ArgumentParser()
66
parser.add_argument("--dataset_name", type=str, default="algebra2005")
7-
parser.add_argument("--model_name", type=str, default="bakt")
7+
parser.add_argument("--model_name", type=str, default="simplekt")
88
parser.add_argument("--emb_type", type=str, default="qid")
99
parser.add_argument("--save_dir", type=str, default="saved_model")
1010
# parser.add_argument("--learning_rate", type=float, default=1e-5)

examples/wandb_train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def main(params):
4141
train_config = config["train_config"]
4242
if model_name in ["dkvmn","deep_irt", "sakt", "saint","saint++", "akt", "atkt", "lpkt", "skvmn"]:
4343
train_config["batch_size"] = 64 ## because of OOM
44-
if model_name in ["bakt", "bakt_time"]:
44+
if model_name in ["simplekt", "bakt_time"]:
4545
train_config["batch_size"] = 64 ## because of OOM
4646
if model_name in ["gkt"]:
4747
train_config["batch_size"] = 16
@@ -88,7 +88,7 @@ def main(params):
8888
for remove_item in ['use_wandb','learning_rate','add_uuid','l2']:
8989
if remove_item in model_config:
9090
del model_config[remove_item]
91-
if model_name in ["saint","saint++", "sakt", "cdkt", "bakt", "bakt_time"]:
91+
if model_name in ["saint","saint++", "sakt", "atdkt", "simplekt", "bakt_time"]:
9292
model_config["seq_len"] = seq_len
9393

9494
debug_print(text = "init_model",fuc_name="main")
Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from torch import FloatTensor, LongTensor
1212
import numpy as np
1313

14-
class CDKTDataset(Dataset):
14+
class ATDKTDataset(Dataset):
1515
"""Dataset for KT
1616
can use to init dataset for: (for models except dkt_forget)
1717
train data, valid data
@@ -24,16 +24,16 @@ class CDKTDataset(Dataset):
2424
qtest (bool, optional): is question evaluation or not. Defaults to False.
2525
"""
2626
def __init__(self, file_path, input_type, folds, qtest=False):
27-
super(CDKTDataset, self).__init__()
27+
super(ATDKTDataset, self).__init__()
2828
sequence_path = file_path
2929
self.input_type = input_type
3030
self.qtest = qtest
3131
folds = sorted(list(folds))
3232
folds_str = "_" + "_".join([str(_) for _ in folds])
3333
if self.qtest:
34-
processed_data = file_path + folds_str + "_cdkt_qtest.pkl"
34+
processed_data = file_path + folds_str + "_atdkt_qtest.pkl"
3535
else:
36-
processed_data = file_path + folds_str + "_cdkt.pkl"
36+
processed_data = file_path + folds_str + "_atdkt.pkl"
3737
self.dpath = "/".join(file_path.split("/")[0:-1])
3838

3939
if not os.path.exists(processed_data):
@@ -51,8 +51,6 @@ def __init__(self, file_path, input_type, folds, qtest=False):
5151
self.dori, self.dqtest = pd.read_pickle(processed_data)
5252
else:
5353
self.dori = pd.read_pickle(processed_data)
54-
for key in self.dori:
55-
self.dori[key] = self.dori[key]#[:100]
5654
print(f"file path: {file_path}, qlen: {len(self.dori['qseqs'])}, clen: {len(self.dori['cseqs'])}, rlen: {len(self.dori['rseqs'])}")
5755

5856
def __len__(self):

pykt/datasets/init_dataset.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import numpy as np
66
from .data_loader import KTDataset
77
from .dkt_forget_dataloader import DktForgetDataset
8-
from .cdkt_dataloader import CDKTDataset
8+
from .atdkt_dataloader import ATDKTDataset
99
from .lpkt_dataloader import LPKTDataset
1010
from .lpkt_utils import generate_time2idx
1111
from .que_data_loader import KTQueDataset
@@ -39,12 +39,12 @@ def init_test_datasets(data_config, model_name, batch_size):
3939
concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
4040
test_question_dataset = None
4141
test_question_window_dataset= None
42-
elif model_name in ["cdkt"]:
43-
test_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
44-
test_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
42+
elif model_name in ["atdkt"]:
43+
test_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
44+
test_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
4545
if "test_question_file" in data_config:
46-
test_question_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True)
47-
test_question_window_dataset = CDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True)
46+
test_question_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_file"]), data_config["input_type"], {-1}, True)
47+
test_question_window_dataset = ATDKTDataset(os.path.join(data_config["dpath"], data_config["test_question_window_file"]), data_config["input_type"], {-1}, True)
4848
else:
4949
test_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_file"]), data_config["input_type"], {-1})
5050
test_window_dataset = KTDataset(os.path.join(data_config["dpath"], data_config["test_window_file"]), data_config["input_type"], {-1})
@@ -96,9 +96,9 @@ def init_dataset4train(dataset_name, model_name, data_config, i, batch_size):
9696
curtrain = KTQueDataset(os.path.join(data_config["dpath"], data_config["train_valid_file_quelevel"]),
9797
input_type=data_config["input_type"], folds=all_folds - {i},
9898
concept_num=data_config['num_c'], max_concepts=data_config['max_concepts'])
99-
elif model_name in ["cdkt"]:
100-
curvalid = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
101-
curtrain = CDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
99+
elif model_name in ["atdkt"]:
100+
curvalid = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
101+
curtrain = ATDKTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
102102
else:
103103
curvalid = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], {i})
104104
curtrain = KTDataset(os.path.join(data_config["dpath"], data_config["train_valid_file"]), data_config["input_type"], all_folds - {i})
Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
device = "cpu" if not torch.cuda.is_available() else "cuda"
88

9-
class CDKT(Module):
10-
def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid', num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
9+
class ATDKT(Module):
10+
def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
11+
num_layers=1, num_attn_heads=5, l1=0.5, l2=0.5, l3=0.5, start=50, emb_path="", pretrain_dim=768):
1112
super().__init__()
12-
self.model_name = "cdkt"
13+
self.model_name = "atdkt"
1314
print(f"qnum: {num_q}, cnum: {num_c}")
1415
print(f"emb_type: {emb_type}")
1516
self.num_q = num_q
@@ -34,7 +35,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
3435
if self.emb_type.find("qemb") != -1:
3536
self.question_emb = Embedding(self.num_q, self.emb_size)
3637

37-
# 加一个预测历史准确率的任务
3838
self.start = start
3939
self.hisclasifier = nn.Sequential(
4040
nn.Linear(self.hidden_size, self.hidden_size//2), nn.ReLU(), nn.Dropout(dropout),
@@ -62,7 +62,6 @@ def __init__(self, num_q, num_c, seq_len, emb_size, dropout=0.1, emb_type='qid',
6262
self.concept_emb = Embedding(self.num_c, self.emb_size) # add concept emb
6363

6464
self.closs = CrossEntropyLoss()
65-
# 加一个预测历史准确率的任务
6665
if self.emb_type.find("his") != -1:
6766
self.start = start
6867
self.hisclasifier = nn.Sequential(
@@ -122,7 +121,7 @@ def predcurc(self, dcur, q, c, r, xemb, train):
122121
h = self.dropout_layer(h)
123122
y = self.out_layer(h)
124123
y = torch.sigmoid(y)
125-
return y, y2, y3, rpreds, qh
124+
return y, y2, y3
126125

127126
def forward(self, dcur, train=False): ## F * xemb
128127
# print(f"keys: {dcur.keys()}")
@@ -162,10 +161,10 @@ def forward(self, dcur, train=False): ## F * xemb
162161
y = self.out_layer(h)
163162
y = torch.sigmoid(y)
164163
elif emb_type.endswith("predcurc"): # predict current question' current concept
165-
y, y2, y3, rpreds, qh = self.predcurc(dcur, q, c, r, xemb, train)
164+
y, y2, y3 = self.predcurc(dcur, q, c, r, xemb, train)
166165

167166
if train:
168167
return y, y2, y3
169168
else:
170-
return y, rpreds, qh
169+
return y
171170

0 commit comments

Comments
 (0)