Skip to content

Commit 0829fab

Browse files
authored
Merge pull request sunlabuiuc#1 from zzachw/main
Reorganize file structure
2 parents 2568793 + 7e8c239 commit 0829fab

File tree

11 files changed

+166
-21
lines changed

11 files changed

+166
-21
lines changed

.gitignore

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
.idea
2+
examples/data/raw/*
3+
examples/data/processed/*
4+
5+
# Byte-compiled / optimized / DLL files
6+
__pycache__/
7+
*.py[cod]
8+
*$py.class
9+
.pytest_cache
10+
11+
# C extensions
12+
*.so
13+
14+
# Distribution / packaging
15+
.Python
16+
build/
17+
develop-eggs/
18+
dist/
19+
downloads/
20+
eggs/
21+
.eggs/
22+
lib/
23+
lib64/
24+
parts/
25+
sdist/
26+
var/
27+
wheels/
28+
*.egg-info/
29+
.installed.cfg
30+
*.egg
31+
MANIFEST
32+
33+
# PyInstaller
34+
# Usually these files are written by a python script from a template
35+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
36+
*.manifest
37+
*.spec
38+
39+
# Installer logs
40+
pip-log.txt
41+
pip-delete-this-directory.txt
42+
43+
# Unit test / coverage reports
44+
htmlcov/
45+
.tox/
46+
.coverage
47+
.coverage.*
48+
.cache
49+
nosetests.xml
50+
coverage.xml
51+
*.cover
52+
.hypothesis/
53+
.pytest_cache/
54+
55+
# Translations
56+
*.mo
57+
*.pot
58+
59+
# Django stuff:
60+
*.log
61+
local_settings.py
62+
db.sqlite3
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
target/
76+
77+
# Jupyter Notebook
78+
.ipynb_checkpoints
79+
80+
# pyenv
81+
.python-version
82+
83+
# celery beat schedule file
84+
celerybeat-schedule
85+
86+
# SageMath parsed files
87+
*.sage.py
88+
89+
# Environments
90+
.env
91+
.venv
92+
env/
93+
venv/
94+
ENV/
95+
env.bak/
96+
venv.bak/
97+
98+
# Spyder project settings
99+
.spyderproject
100+
.spyproject
101+
102+
# Rope project settings
103+
.ropeproject
104+
105+
# mkdocs documentation
106+
/site
107+
108+
# mypy
109+
.mypy_cache/

README.md

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,32 @@
1-
# PyHealth OMOP Development PLAN
1+
# PyHealth
2+
3+
## Environment
4+
- pytorch: 1.12.0
5+
- pytorch-lightning: 1.6.4
6+
7+
## Dataset
8+
- MIMIC-III
9+
- MIMIC-IV
10+
- eICU
11+
- OMOP CDM
12+
13+
## Input
14+
- Condition code
15+
- Drug code
16+
- Procedure code
17+
18+
## Output
19+
- Mortality prediction
20+
- Length-of-stay estimation
21+
- Drug recommendation
22+
- Phenotyping
23+
24+
## Model
25+
26+
27+
28+
29+
230

331
### datasets.py
432
- provides data processing for MIMIC-III, eICU and MIMIC-IV

Example.ipynb renamed to examples/Example.ipynb

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010
},
1111
{
1212
"cell_type": "code",
13-
"execution_count": 1,
13+
"execution_count": 4,
1414
"id": "e424d935",
1515
"metadata": {},
1616
"outputs": [
1717
{
18-
"name": "stdout",
19-
"output_type": "stream",
20-
"text": [
21-
"\n"
18+
"ename": "ModuleNotFoundError",
19+
"evalue": "No module named 'pandas'",
20+
"output_type": "error",
21+
"traceback": [
22+
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
23+
"\u001B[0;31mModuleNotFoundError\u001B[0m Traceback (most recent call last)",
24+
"\u001B[0;32m<ipython-input-4-65ab904da1ab>\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mnumpy\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mnp\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0;32mimport\u001B[0m \u001B[0mpandas\u001B[0m \u001B[0;32mas\u001B[0m \u001B[0mpd\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mpytorch_lightning\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mLightningModule\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mTrainer\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mpytorch_lightning\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mcallbacks\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mModelCheckpoint\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mMedCode\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mCodeMapping\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
25+
"\u001B[0;31mModuleNotFoundError\u001B[0m: No module named 'pandas'"
2226
]
2327
}
2428
],
@@ -29,15 +33,16 @@
2933
"from pytorch_lightning.callbacks import ModelCheckpoint\n",
3034
"from MedCode import CodeMapping\n",
3135
"\n",
32-
"import datasets\n",
33-
"import models\n",
34-
"import utils\n",
36+
"import pyhealth.datasets as datasets\n",
37+
"import pyhealth.models as models\n",
38+
"import pyhealth.utils as utils\n",
3539
"from importlib import reload\n",
3640
"reload(datasets)\n",
3741
"reload(models)\n",
3842
"reload(utils)\n",
3943
"\n",
40-
"print ()\n"
44+
"import torch\n",
45+
"print(torch.cuda.is_available())\n"
4146
]
4247
},
4348
{
@@ -519,4 +524,4 @@
519524
},
520525
"nbformat": 4,
521526
"nbformat_minor": 5
522-
}
527+
}

examples/data/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Data
2+
3+
Put raw data under `./raw` (i.e. `examples/data/raw/`, which is ignored by Git).
File renamed without changes.

pyhealth/__init__.py

Whitespace-only changes.

pyhealth/datasets/__init__.py

Whitespace-only changes.

datasets.py renamed to pyhealth/datasets/datasets.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from collections import defaultdict
1111

1212
def collate_func_RETAIN(cur_patient, voc_size):
13-
""" data is a list of sample from the dataset """
13+
""" data is a list of samples from the dataset """
1414
max_len = max([len(visit[0]) + len(visit[1]) for visit in cur_patient])
1515
X = []
1616
y = torch.zeros((len(cur_patient), voc_size[2]))
@@ -86,7 +86,7 @@ def encodes(self, code_list):
8686

8787
class MIMIC_III:
8888
"""
89-
MIMIC-III data object
89+
MIMIC-III dataset object
9090
- the original MIMIC-III medication is encoded by RxNorm (the column name uses "NDC")
9191
- when initialize, input the target_code and the according code_map
9292
For example,
@@ -211,12 +211,12 @@ def get_atc3(x):
211211

212212
def get_dataloader(self, MODEL):
213213
"""
214-
get the dataloaders for MODEL, since different models has different data loader (input formats are different)
215-
- data <list>: each element is a patient record
216-
- data[0] <list>: each element is a visit
217-
- data[0][0] <list>: diag encoded list for this visit
218-
- data[0][1] <list>: prod encoded list for this visit
219-
- data[0][2] <list>: med encoded list for this visit
214+
get the dataloaders for MODEL, since different models have different data loaders (input formats are different)
215+
- data <list>: each element is a patient record
216+
- data[0] <list>: each element is a visit
217+
- data[0][0] <list>: diag encoded list for this visit
218+
- data[0][1] <list>: prod encoded list for this visit
219+
- data[0][2] <list>: med encoded list for this visit
220220
"""
221221
data = []
222222
for _, visit_ls in self.pat_to_visit.items():
@@ -233,7 +233,7 @@ def get_dataloader(self, MODEL):
233233
if len(cur_pat) <= 1: continue
234234
data.append(cur_pat)
235235

236-
# data split
236+
# data split
237237
split_point = int(len(data) * 2 / 3)
238238
data_train = data[:split_point]
239239
eval_len = int(len(data[split_point:]) / 2)

pyhealth/models/__init__.py

Whitespace-only changes.

models.py renamed to pyhealth/models/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn as nn
44
import torch.nn.functional as F
55
import numpy as np
6-
from utils import multi_label_metric
6+
from pyhealth.utils import multi_label_metric
77

88
class RETAIN(LightningModule):
99
def __init__(self, voc_size, emb_size=64):

0 commit comments

Comments
 (0)