Skip to content

Commit 2d6763e

Browse files
authored
Merge branch 'main' into dev
2 parents ed15391 + d1bd046 commit 2d6763e

File tree

3 files changed

+9
-4
lines changed

3 files changed

+9
-4
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ pip install -U pykt-toolkit -i https://pypi.python.org/simple
6565
We now have a [paper](https://arxiv.org/abs/2206.11460?context=cs.CY) you can cite for the our pyKT library:
6666

6767
```bibtex
68-
@article{liu2022pykt,
68+
@inproceedings{liupykt2022,
6969
title={pyKT: A Python Library to Benchmark Deep Learning based Knowledge Tracing Models},
7070
author={Liu, Zitao and Liu, Qiongqiong and Chen, Jiahao and Huang, Shuyan and Tang, Jiliang and Luo, Weiqi},
71-
journal={arXiv preprint arXiv:2206.11460},
71+
booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
7272
year={2022}
7373
}
74-
```
74+
```

examples/wandb_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def main(params):
6767
print(dataset_name, model_name, data_config, fold, batch_size)
6868

6969
debug_print(text="init_dataset",fuc_name="main")
70-
train_loader, valid_loader = init_dataset4train(dataset_name, model_name, data_config, fold, batch_size)
70+
train_loader, valid_loader, *_ = init_dataset4train(dataset_name, model_name, data_config, fold, batch_size)
7171

7272
params_str = "_".join([str(v) for k,v in params.items() if not k in ['other_config']])
7373

pykt/datasets/data_loader.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@
1111
from torch import FloatTensor, LongTensor
1212
import numpy as np
1313

14+
if torch.cuda.is_available():
15+
from torch.cuda import FloatTensor, LongTensor
16+
else:
17+
from torch import FloatTensor, LongTensor
18+
1419
class KTDataset(Dataset):
1520
"""Dataset for KT
1621
can use to init dataset for: (for models except dkt_forget)

0 commit comments

Comments
 (0)