
Commit a5aa7d0

Unify model loading and REST APIs and improve CLI (#241)
* unify model loading
* rework args
* rework cli
* consolidate rest apis
* update tests
* update examples
* remove unused kwargs
* typo
* fix import
* fix inconsistencies in readme and CLI docs
* log hyperlink to interactive docs on REST startup
* temporarily disable tlinkx check in temporal test
* allow disjoint labels in generated testing datasets
* fix issue loading config model type
* convert encoder_layer arg when loading legacy models
* re-enable test
* update readme
* fix typo in example notebook
* save tokenizer with model and fall back to encoder tokenizer when loading for REST apps
1 parent 5c93d38 commit a5aa7d0


73 files changed: +4055 −4426 lines

README.md

Lines changed: 33 additions & 85 deletions
@@ -103,7 +103,7 @@ To use the library for fine-tuning, you'll need to take the following steps:
 
 Instance labels should be formatted the same way as in the csv/tsv example above; see specifically the formats for tagging and relations. The 'metadata' field can either be included in the train/dev/test files or as a separate metadata.json file.
 
-2. Run train_system.py with a ```--task_name``` from your data files and the ```--data-dir``` argument from Step 1. If no ```--task_name``` is provided, all tasks will be trained.
+2. Run train_system.py with a `--model-type` (one of `cnn`, `lstm`, `hier`, or `proj`) and a `--data-dir` (the path to the folder you created in step 1). Optionally specify one or more `--task` names to train on; by default, all tasks will be trained.
 
 ### Step-by-step finetuning examples
 
@@ -115,113 +115,61 @@ We provide the following step-by-step examples of how to fine-tune for clinical NLP tasks:
 
 ### Fine-tuning options
 
-Run `cnlpt train -h` to see all the available options. In addition to inherited Huggingface Transformers options, there are options to do the following:
+Run `cnlpt train --help` to see all the available options. In addition to inherited Huggingface Transformers options, there are options to do the following:
 
 * Select different models: `--model hier` uses a hierarchical transformer layer on top of a specified encoder model. We recommend using a very small encoder, e.g. `--encoder microsoft/xtremedistil-l6-h256-uncased`, so that the full model fits into memory.
-* Run simple baselines (use ``--model cnn|lstm --tokenizer_name roberta-base`` -- since there is no HF model then you must specify the tokenizer explicitly)
+* Run simple baselines (use `--model cnn|lstm --tokenizer roberta-base` -- since there is no underlying HF model, you must specify the tokenizer explicitly)
 * Use a different layer's CLS token for the classification (e.g., `--layer 10`)
 * Probabilistically freeze weights of the encoder, leaving all classifier weights unfrozen (`--freeze` alone freezes all encoder weights; `--freeze <float>`, given a value between 0 and 1, freezes that fraction of encoder weights)
 * Classify based on a token embedding instead of the CLS embedding (`--token` -- applies to the event/entity classification setting only, and requires the input to have xml-style tags (`<e>`, `</e>`) around the tokens representing the event/entity)
 * Use a class-weighted loss function (`--class_weights`)
 
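To make the probabilistic `--freeze <float>` option above concrete, here is a minimal sketch of the general idea, assuming PyTorch-style parameters with a `requires_grad` flag. This is illustrative only, not cnlp-transformers' actual implementation, and the function name is hypothetical:

```python
import random

def probabilistic_freeze(named_params, freeze_prob: float, seed: int = 0) -> int:
    """Freeze roughly `freeze_prob` of the given parameters; leave the rest trainable.

    `named_params` is an iterable of (name, param) pairs where each param has a
    `requires_grad` attribute (as in PyTorch). Returns the number frozen.
    Sketch of the idea behind `--freeze <float>`, not the library's actual code.
    """
    rng = random.Random(seed)  # seeded for reproducible freezing
    frozen = 0
    for _name, param in named_params:
        if rng.random() < freeze_prob:
            param.requires_grad = False  # excluded from gradient updates
            frozen += 1
    return frozen
```

In the real CLI, the classifier head stays trainable and only encoder parameters are candidates for freezing.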
 
 ## Running REST APIs
 
-There are existing REST APIs in the `src/cnlpt/api` folder for a few important clinical NLP tasks:
+This library supports serving a REST API for your model with a single `/process` endpoint to process text and generate predictions, via the `cnlpt rest` command.
 
-1. Negation detection
-2. Time expression tagging (spans + time classes)
-3. Event detection (spans + document creation time relation)
-4. End-to-end temporal relation extraction (event spans+DTR+timex spans+time classes+narrative container [CONTAINS] relation extraction)
+Run `cnlpt rest --help` to see available options. The only required option is `--model`, which must be either a HuggingFace repository or a local directory containing your model. By default, the model will be served at [http://localhost:8000](http://localhost:8000).
 
-### Negation API
+For example, to run our negation detection model from HuggingFace:
 
-To demo the negation API:
-
-1. Install the `cnlp-transformers` package.
-2. Run `cnlpt rest --model-type negation [-p PORT]`.
-3. Open a python console and run the following commands:
-
-#### Setup variables for negation
-
-```ipython
->>> import requests
->>> process_url = 'http://hostname:8000/negation/process' ## Replace hostname with your host name
-```
-
-#### Prepare the document
-
-```ipython
->>> sent = 'The patient has a sore knee and headache but denies nausea and has no anosmia.'
->>> ents = [[18, 27], [32, 40], [52, 58], [70, 77]]
->>> doc = {'doc_text':sent, 'entities':ents}
+```bash
+cnlpt rest --model mlml-chip/negation_pubmedbert_sharpseed
 ```
 
-#### Process the document
-
-```ipython
->>> r = requests.post(process_url, json=doc)
->>> r.json()
-```
-
-Output: `{'statuses': [-1, -1, 1, 1]}`
-
-The model correctly classifies both nausea and anosmia as negated.
-
-### Temporal API (End-to-end temporal information extraction)
-
-To demo the temporal API:
-
-1. Install the `cnlp-transformers` package.
-2. Run `cnlpt rest --model-type temporal [-p PORT]`
-3. Open a python console and run the following commands to test:
-
-#### Setup variables for temporal
+Once the application is running, you can either interact with it via the web interface at [http://localhost:8000/docs](http://localhost:8000/docs), or manually send requests to the `/process` endpoint:
 
 ```ipython
 >>> import requests
 >>> from pprint import pprint
->>> process_url = 'http://hostname:8000/temporal/process_sentence' ## Replace hostname with your host name
+>>> sent = "The patient has a sore knee and headache but denies nausea and has no anosmia."
+>>> ents = [(18, 27), (32, 40), (52, 58), (70, 77)]
+>>> doc = {"text": sent, "entity_spans": ents}
+>>> resp = requests.post("http://localhost:8000/process", json=doc)
+>>> pprint(resp.json())
+[{'Negation': {'prediction': '-1',
+               'probs': {'-1': 0.9997619986534119, '1': 0.0002379878715146333}},
+  'text': 'The patient has a <e>sore knee</e> and headache but denies nausea '
+          'and has no anosmia.'},
+ {'Negation': {'prediction': '-1',
+               'probs': {'-1': 0.9995606541633606, '1': 0.0004393413255456835}},
+  'text': 'The patient has a sore knee and <e>headache</e> but denies nausea '
+          'and has no anosmia.'},
+ {'Negation': {'prediction': '1',
+               'probs': {'-1': 0.007858583703637123, '1': 0.9921413660049438}},
+  'text': 'The patient has a sore knee and headache but denies <e>nausea</e> '
+          'and has no anosmia.'},
+ {'Negation': {'prediction': '1',
+               'probs': {'-1': 0.0071166763082146645, '1': 0.9928833246231079}},
+  'text': 'The patient has a sore knee and headache but denies nausea and has '
+          'no <e>anosmia</e>.'}]
 ```
 
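The `entity_spans` in the request above are character offsets into the text. If you are building requests programmatically, one simple way to derive spans from known mention strings is a left-to-right search (a sketch; in practice the spans can come from any upstream entity recognizer):

```python
def char_spans(text: str, mentions: list[str]) -> list[tuple[int, int]]:
    """Find (start, end) character offsets of each mention in `text`.

    Searches left to right, so repeated mentions map to successive offsets.
    Raises ValueError if a mention is not found after the previous one.
    """
    spans = []
    cursor = 0
    for mention in mentions:
        start = text.index(mention, cursor)  # search begins after the last match
        end = start + len(mention)
        spans.append((start, end))
        cursor = end
    return spans

sent = "The patient has a sore knee and headache but denies nausea and has no anosmia."
spans = char_spans(sent, ["sore knee", "headache", "nausea", "anosmia"])
# spans == [(18, 27), (32, 40), (52, 58), (70, 77)]
```

These are exactly the offsets used in the example request above.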
-#### Prepare and process the document
+You can also serve multiple models at once by providing a router prefix for each model, e.g.:
 
-```ipython
->>> sent = 'The patient was diagnosed with adenocarcinoma March 3, 2010 and will be returning for chemotherapy next week.'
->>> r = requests.post(process_url, json={'sentence':sent})
->>> pprint(r.json())
+```bash
+cnlpt rest --model /negation=mlml-chip/negation_pubmedbert_sharpseed --model /temporal=mlml-chip/thyme2_colon_e2e
 ```
 
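When serving multiple models this way, each prefix presumably gets its own copy of the endpoint (e.g. `/negation/process`). A tiny helper for building the request URLs might look like the following; note that the prefixed URL layout is an assumption based on the CLI example above, and `process_url` is a hypothetical name:

```python
def process_url(base: str, prefix: str = "") -> str:
    """Build the /process URL for a model served under an optional router prefix.

    With no prefix (single-model serving) this is just `<base>/process`;
    with a prefix like "/negation" it becomes `<base>/negation/process`.
    The prefixed layout is an assumption based on the CLI example above.
    """
    return f"{base}{prefix}/process"

# process_url("http://localhost:8000") == "http://localhost:8000/process"
# process_url("http://localhost:8000", "/negation") == "http://localhost:8000/negation/process"
```

You would then post documents with e.g. `requests.post(process_url(base, "/negation"), json=doc)`.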
-should return:
-
-```json
-{
-  "events": [
-    [
-      {"begin": 3, "dtr": "BEFORE", "end": 3},
-      {"begin": 5, "dtr": "BEFORE", "end": 5},
-      {"begin": 13, "dtr": "AFTER", "end": 13},
-      {"begin": 15, "dtr": "AFTER", "end": 15}
-    ]
-  ],
-  "relations": [
-    [
-      {"arg1": "TIMEX-0", "arg2": "EVENT-0", "category": "CONTAINS"},
-      {"arg1": "EVENT-2", "arg2": "EVENT-3", "category": "CONTAINS"},
-      {"arg1": "TIMEX-1", "arg2": "EVENT-2", "category": "CONTAINS"},
-      {"arg1": "TIMEX-1", "arg2": "EVENT-3", "category": "CONTAINS"}
-    ]
-  ],
-  "timexes": [
-    [
-      {"begin": 6, "end": 9, "timeClass": "DATE"},
-      {"begin": 16, "end": 17, "timeClass": "DATE"}
-    ]
-  ]
-}
-```
-
-This output indicates the token spans of events and timexes, and relations between events and timexes, where the suffixes are indices into the respective arrays (e.g., TIMEX-0 in a relation refers to the 0th time expression found, which begins at token 6 and ends at token 9 -- ["March 3, 2010"])
-
 ## Citing cnlp_transformers
 
 Please use the following bibtex to cite cnlp_transformers if you use it in a publication:

docker/model_download.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from transformers.models.auto.tokenization_auto import AutoTokenizer
 
 from cnlpt.legacy.train_system import is_hub_model
-from cnlpt.models import CnlpModelForClassification, HierarchicalModel
+from cnlpt.modeling import CnlpModelForClassification, HierarchicalModel
 
 
 def pre_initialize_cnlpt_model(model_name, cuda=True, batch_size=8):

examples/.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+*/dataset/
+*/train_output/

examples/chemprot/README.md

Lines changed: 11 additions & 13 deletions
@@ -1,24 +1,22 @@
 # Fine-tuning for tagging: End-to-end example
 
-1. Preprocess the data with `uv run examples/chemprot/prepare_chemprot_dataset.py data/chemprot`
+1. Preprocess the data with `uv run examples/chemprot/prepare_chemprot_dataset.py`
 
-2. Fine-tune with something like:
+2. Fine-tune for NER with something like:
 
 ```bash
-cnlpt train \
-  --task_name chemical_ner gene_ner \
-  --data_dir data/chemprot \
-  --encoder_name allenai/scibert_scivocab_uncased \
-  --do_train \
-  --do_eval \
-  --cache_dir cache/ \
-  --output_dir temp/ \
+uv run cnlpt train \
+  --model_type proj \
+  --encoder allenai/scibert_scivocab_uncased \
+  --data_dir ./dataset \
+  --task chemical_ner --task gene_ner \
+  --output_dir ./train_output \
   --overwrite_output_dir \
-  --num_train_epochs 50 \
+  --do_train --do_eval \
+  --num_train_epochs 3 \
   --learning_rate 2e-5 \
   --lr_scheduler_type constant \
-  --report_to none \
-  --save_strategy no \
+  --save_strategy best \
   --gradient_accumulation_steps 1 \
   --eval_accumulation_steps 10 \
   --weight_decay 0.2
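For tagging tasks like `chemical_ner` and `gene_ner`, the preprocessed TSV columns hold token-level BIO labels. As a generic sketch of how character-level entity spans become BIO tags (illustrative only, not the example's actual preprocessing code):

```python
def bio_tags(tokens: list[str], token_starts: list[int],
             spans: list[tuple[int, int]], label: str) -> list[str]:
    """Assign B-/I-/O tags to tokens given entity character spans.

    `token_starts` gives each token's starting character offset in the
    original text. A token is tagged if it lies entirely within a span;
    the first such token gets B-, the rest get I-. Illustrative sketch
    of the BIO scheme, not the example's actual code.
    """
    tags = ["O"] * len(tokens)
    for start, end in spans:
        inside = False
        for i, tok_start in enumerate(token_starts):
            tok_end = tok_start + len(tokens[i])
            if tok_start >= start and tok_end <= end:
                tags[i] = f"I-{label}" if inside else f"B-{label}"
                inside = True
    return tags

text = "Aspirin inhibits prostaglandin synthesis"
tokens = text.split()
starts = [0, 8, 17, 31]  # each token's starting character offset
# "prostaglandin synthesis" spans characters 17-40:
# bio_tags(tokens, starts, [(17, 40)], "chemical") == ["O", "O", "B-chemical", "I-chemical"]
```

The real dataset stores one space-separated tag sequence per instance, one column per task.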

examples/chemprot/preprocess_chemprot.py

Lines changed: 13 additions & 15 deletions
@@ -1,19 +1,24 @@
 import bisect
 import itertools
-import os
 import re
 from dataclasses import dataclass
-from sys import argv
-from typing import Any, Union
+from pathlib import Path
+from typing import Any
 
 import polars as pl
 from datasets import load_dataset
 from datasets.dataset_dict import Dataset, DatasetDict
+from datasets.utils import disable_progress_bars, enable_progress_bars
 from rich.console import Console
 
 
 def load_chemprot_dataset(cache_dir="./cache") -> DatasetDict:
-    return load_dataset("bigbio/chemprot", "chemprot_full_source", cache_dir=cache_dir)
+    disable_progress_bars()
+    dataset = load_dataset(
+        "bigbio/chemprot", "chemprot_full_source", cache_dir=cache_dir
+    )
+    enable_progress_bars()
+    return dataset
 
 
 def clean_text(text: str):
@@ -156,25 +161,18 @@ def preprocess_data(split: Dataset):
     )
 
 
-def main(out_dir: Union[str, os.PathLike]):
+if __name__ == "__main__":
     console = Console()
-
-    if not os.path.isdir(out_dir):
-        os.mkdir(out_dir)
+    out_dir = Path(__file__).parent / "dataset"
+    out_dir.mkdir(exist_ok=True)
 
     with console.status("Loading dataset...") as st:
         dataset = load_chemprot_dataset()
         for split in ("train", "test", "validation"):
             st.update(f"Preprocessing {split} data...")
             preprocessed = preprocess_data(dataset[split])
-            preprocessed.write_csv(
-                os.path.join(out_dir, f"{split}.tsv"), separator="\t"
-            )
+            preprocessed.write_csv(out_dir / f"{split}.tsv", separator="\t")
 
     console.print(
        f"[green i]Preprocessed chemprot data saved to [repr.filename]{out_dir}[/]."
    )
-
-
-if __name__ == "__main__":
-    main(argv[1])
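The script above imports `bisect`, the standard tool for mapping entity character offsets onto token indices during this kind of preprocessing. A generic sketch of that technique (not the script's actual code; the names here are illustrative):

```python
import bisect

def char_to_token(token_starts: list[int], char_offset: int) -> int:
    """Return the index of the token containing `char_offset`.

    `token_starts` must be the sorted list of each token's starting
    character offset. bisect_right finds the first start strictly greater
    than the offset; the containing token is the one just before it.
    """
    return bisect.bisect_right(token_starts, char_offset) - 1

text = "Aspirin inhibits COX-1 irreversibly"
tokens = text.split()
starts = []
pos = 0
for tok in tokens:
    pos = text.index(tok, pos)  # locate each token in the original text
    starts.append(pos)
    pos += len(tok)

# "COX-1" starts at character 17, which falls in token index 2:
# char_to_token(starts, 17) == 2
```

This turns an O(n) scan per entity into an O(log n) lookup, which matters when a corpus has many entity annotations per document.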

examples/uci_drug/README.md

Lines changed: 55 additions & 63 deletions
@@ -1,74 +1,66 @@
-### Fine-tuning for classification: End-to-end example
+# Drug Review Sentiment Classification
 
-1. Download data from [Drug Reviews (Druglib.com) Data Set](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com) to `data` folder and extract. Pay attention to their terms:
-    1. only use the data for research purposes
-    2. don't use the data for any commerical purposes
-    3. don't distribute the data to anyone else
-    4. cite us
+## Jupyter notebook example
 
-2. Run ```python examples/uci_drug/transform_uci_drug.py <raw dir> <processed dir>``` to preprocess the data from the extract directory into a new directory. This will create {train,dev,test}.tsv in the processed directory specified, where the sentiment ratings have been collapsed into 3 categories.
+See the [example notebook](./uci_drug.ipynb) for a step-by-step walkthrough of
+how to use CNLPT to train a model for sentiment classification of drug reviews.
 
-3. Fine-tune with something like:
+## CLI example
 
-```bash
-cnlpt train \
-  --data_dir <processed dir> \
-  --task_name sentiment \
-  --encoder_name roberta-base \
-  --do_train \
-  --do_eval \
-  --cache_dir cache/ \
-  --output_dir temp/ \
-  --overwrite_output_dir \
-  --evals_per_epoch 5 \
-  --num_train_epochs 1 \
-  --learning_rate 1e-5 \
-  --report_to none \
-  --metric_for_best_model eval_sentiment.avg_micro_f1 \
-  --load_best_model_at_end \
-  --save_strategy best
-```
-
-On our hardware, that command results in eval performance like the following:
-```sentiment = {'acc': 0.7041800643086816, 'f1': [0.7916666666666666, 0.7228915662650603, 0.19444444444444442], 'acc_and_f1': [0.7479233654876741, 0.7135358152868709, 0.449312254376563], 'recall': [0.8216216216216217, 0.8695652173913043, 0.12280701754385964], 'precision': [0.7638190954773869, 0.6185567010309279, 0.4666666666666667]}```
-
-#### Error Analysis for Classification
-
-If you run the above command with the `--error_analysis` flag, you can obtain the `dev` instances for which the model made an erroneous
-prediction, organized by their original index in `dev` split, in the `eval_predictions...tsv` file in the `--output_dir` argument.
-For us the first line of this file (after the header) is:
-
-```
-text	sentiment
-2	Benefits: <cr> helped aleviate whip lash symptoms <cr> Side effects: <cr> none that i noticed <cr> Overall comments: <cr> i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December	Ground: Medium Predicted: High
-
-```
-
-The number at the beginning of the line, 2, is the index of the instance in the `dev` split. The `text` column contains the text of the erroneous instances and the following columns are the tasks provided to the model, in this case, just `sentiment`. `Ground: Medium Predicted: High` indicates that the provided ground truth label for the instance sentiment is `Medium` but the model predicted `High`.
-
-#### Human Readable Predictions for Classification
+If you prefer, you can instead use the CLI to train the model:
 
-Similarly if you run the above command with `--do_predict` you can obtain human readable predictions for the `test` split, in the `test_predictions...tsv` file. For us the first line of this file (after the header) is:
-
-```
-0	Benefits: <cr> The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. <cr> Side effects: <cr> Some back pain, some nauseau. <cr> Overall comments: <cr> Took the antibiotics for 14 days. Sinus infection was gone after the 6th day.	Low
-
-```
-
-##### Prediction Probability Outputs for Classification
-
-(Currently only supported for classification tasks), if you run the above command with the `--output_prob` flag, you can see the model's softmax-obtained probability for the predicted classification label. The first error analysis sample from `dev` would now looks like:
-
-```
-text	sentiment
-2	Benefits: <cr> helped aleviate whip lash symptoms <cr> Side effects: <cr> none that i noticed <cr> Overall comments: <cr> i took the medications for the prescribed time and symptoms improved, however, I still have some symptoms which are being treated through physical therapy since the accident was only in December	Ground: Medium Predicted: High , Probability 0.613825
+### Download and preprocess the data
 
+Use the [`prepare_data.py`](./prepare_data.py) script to download the data and convert it to CNLPT's data format:
 
+```bash
+uv run prepare_data.py
 ```
 
-And the first prediction sample from `test` now looks like:
+> [!TIP] About the dataset:
+> This script downloads the
+> [*Drug Reviews (Druglib.com)* dataset](https://archive.ics.uci.edu/dataset/461/drug+review+dataset+druglib+com).
+> Please be aware of the terms of use:
+>
+> > Important Notes:
+> >
+> > When using this dataset, you agree that you
+> >
+> > 1) only use the data for research purposes
+> > 2) don't use the data for any commerical purposes
+> > 3) don't distribute the data to anyone else
+> > 4) cite UCI data lab and the source
+>
+> Here is the dataset's BibTeX citation:
+>
+> ```bibtex
+> @misc{drug_reviews_(druglib.com)_461,
+>   author = {Kallumadi, Surya and Gräßer, Felix},
+>   title = {{Drug Reviews (Druglib.com)}},
+>   year = {2018},
+>   howpublished = {UCI Machine Learning Repository},
+>   note = {{DOI}: https://doi.org/10.24432/C55G6J}
+> }
+> ```
+
+### Train a model
+
+The following example fine-tunes
+[the RoBERTa base model](https://huggingface.co/FacebookAI/roberta-base)
+with an added projection layer for classification:
 
-```
-text	sentiment
-0	Benefits: <cr> The antibiotic may have destroyed bacteria causing my sinus infection. But it may also have been caused by a virus, so its hard to say. <cr> Side effects: <cr> Some back pain, some nauseau. <cr> Overall comments: <cr> Took the antibiotics for 14 days. Sinus infection was gone after the 6th day.	Low , Probability 0.370522
+```bash
+uv run cnlpt train \
+  --model_type proj \
+  --encoder roberta-base \
+  --data_dir ./dataset \
+  --task sentiment \
+  --output_dir ./train_output \
+  --overwrite_output_dir \
+  --do_train --do_eval --do_predict \
+  --evals_per_epoch 2 \
+  --learning_rate 1e-5 \
+  --metric_for_best_model 'sentiment.macro_f1' \
+  --load_best_model_at_end \
+  --save_strategy best
 ```
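The old README noted that the dataset's 10-point ratings are collapsed into three sentiment classes (Low/Medium/High). The exact cutoffs used by the example's preprocessing are not shown here, so the thresholds below are illustrative assumptions only:

```python
def collapse_rating(rating: int) -> str:
    """Map a 1-10 drug-review rating to a 3-way sentiment label.

    The 1-4 / 5-7 / 8-10 cutoffs are illustrative assumptions; the
    example's own preprocessing script defines the actual binning.
    """
    if not 1 <= rating <= 10:
        raise ValueError(f"rating must be in 1..10, got {rating}")
    if rating <= 4:
        return "Low"
    if rating <= 7:
        return "Medium"
    return "High"
```

Whatever binning is used, the resulting label becomes the value of the `sentiment` column in the generated {train,dev,test}.tsv files.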
