Skip to content

Commit fd4440d

Browse files
authored
Minor CLI improvements (#242)
* Improve CLI docstrings
* Warn on incompatible arg for model type
1 parent 68f2b08 commit fd4440d

File tree

2 files changed

+89
-8
lines changed

2 files changed

+89
-8
lines changed

src/cnlpt/_cli/rest.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ def rest(
3535
typer.Option(
3636
"--model",
3737
callback=parse_models,
38-
help="Model definition as [ROUTER_PREFIX=]PATH_TO_MODEL. Route prefix must start with '/'. This option can be specified multiple times to serve multiple models simultaneously. Route prefixes are required when serving more than one model.",
38+
help=(
39+
"Model to serve, specified as PATH_TO_MODEL or ROUTE_PREFIX=PATH_TO_MODEL. "
40+
"PATH_TO_MODEL can be a local directory or a HuggingFace model repository (e.g. 'mlml-chip/negation_pubmedbert_sharpseed'). "
41+
"ROUTE_PREFIX must start with '/' and is required when serving more than one model. "
42+
"This option can be repeated to serve multiple models simultaneously "
43+
"(e.g. --model /negation=mlml-chip/negation_pubmedbert_sharpseed --model /temporal=mlml-chip/thyme2_colon_e2e)."
44+
),
3945
),
4046
],
4147
host: Annotated[
@@ -45,7 +51,34 @@ def rest(
4551
int, typer.Option("-p", "--port", help="Port to serve the REST app.")
4652
] = 8000,
4753
):
48-
"""Start a REST application from a model."""
54+
"""Start a REST API server for one or more cnlpt models.
55+
56+
Serves a FastAPI application with a /process endpoint that accepts text
57+
(and optionally entity spans) and returns model predictions. Interactive
58+
API documentation is available at /docs once the server is running.
59+
60+
\b
61+
Examples:
62+
63+
Serve a single model from HuggingFace:
64+
65+
cnlpt rest --model mlml-chip/negation_pubmedbert_sharpseed
66+
67+
Serve a single model from a local directory:
68+
69+
cnlpt rest --model ./my_model --host 0.0.0.0 --port 9000
70+
71+
Serve multiple models simultaneously, each under its own route prefix:
72+
73+
cnlpt rest \\
74+
--model /negation=mlml-chip/negation_pubmedbert_sharpseed \\
75+
--model /temporal=mlml-chip/thyme2_colon_e2e
76+
77+
When serving multiple models, each model's /process endpoint is available
78+
at ROUTE_PREFIX/process (e.g. /negation/process).
79+
80+
Interactive API documentation for all models is available at HOST:PORT/docs.
81+
"""
4982
import asyncio
5083
import logging
5184

src/cnlpt/_cli/train.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def train(
333333
# MODEL ARGS #
334334
# ------------------ #
335335
model_type: ModelTypeArg = ...,
336-
encoder_name: EncoderArg = DEFAULT_ENCODER,
336+
encoder: EncoderArg = DEFAULT_ENCODER,
337337
use_prior_tasks: UsePriorTasksArg = False,
338338
encoder_layer: EncoderLayerArg = -1,
339339
classification_mode: ClassificationModeArg = "cls",
@@ -381,15 +381,63 @@ def train(
381381
logging_first_step: LoggingFirstStepArg = True,
382382
cache_dir: CacheDirArg = None,
383383
):
384-
# TODO(ian): it's probably worth making this docstring pretty descriptive
385-
"""Run the cnlp_transformers training system."""
384+
"""Train a model on one or more NLP tasks using the cnlp_transformers training system.
385+
386+
Requires a data directory containing CNLPT-formatted data
387+
(https://github.com/Machine-Learning-for-Medical-Language/cnlp_transformers#workflow)
388+
and a model type. The model will be evaluated on the dev split
389+
after each epoch, and predictions on the test split will be written to the output
390+
directory if --do_predict is set.
391+
392+
\b
393+
MODEL TYPES
394+
proj (Projection) Transformer encoder + task-specific projection heads.
395+
The recommended choice for most NLP tasks. Supports
396+
sequence classification, NER/tagging, and relation
397+
extraction.
398+
hier (Hierarchical) Two-stage model for long documents: a transformer
399+
encoder processes chunks, then a secondary transformer
400+
aggregates across chunks. Use when inputs exceed the
401+
encoder's max sequence length.
402+
cnn (CNN) Lightweight convolutional model trained from scratch
403+
(no pretrained encoder). Fast and low-resource, but
404+
typically lower accuracy than transformer-based models.
405+
lstm (LSTM) Lightweight recurrent model trained from scratch
406+
(no pretrained encoder). Similar trade-offs to CNN.
407+
408+
\b
409+
MULTI-TASK TRAINING
410+
Multiple tasks can be trained jointly using --task/-t to select which tasks to include.
411+
Omitting --task will train on all tasks found in the dataset.
412+
413+
\b
414+
ADDITIONAL HUGGINGFACE TRAINING ARGUMENTS
415+
This command accepts all arguments supported by the HuggingFace Trainer
416+
(e.g. --learning_rate, --num_train_epochs, --output_dir, --do_predict).
417+
These are passed through directly and are not listed in the help below.
418+
See: https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments
419+
"""
420+
421+
# Warn if any explicitly-set args are incompatible with the selected model type.
422+
compat_map = ctx.meta.get(_ARG_COMPAT_METADATA_KEY, {})
423+
for param_name, compatible_types in compat_map.items():
424+
if model_type.value not in compatible_types and (
425+
ctx.get_parameter_source(param_name) == ParameterSource.COMMANDLINE
426+
):
427+
cli_name = f"--{param_name.replace('_', '-')}"
428+
compatible_str = "/".join(compatible_types)
429+
typer.echo(
430+
f"Warning: {cli_name} is only used for {compatible_str} models "
431+
f"and will be ignored for model type '{model_type.value}'.",
432+
err=True,
433+
)
386434

387435
# If the tokenizer wasn't explicitly specified and this is a model
388436
# that accepts an encoder, use the encoder's tokenizer.
389437
if ctx.get_parameter_source("tokenizer") != ParameterSource.COMMANDLINE and (
390438
model_type in (ModelType.HIER, ModelType.PROJ)
391439
):
392-
tokenizer = encoder_name
440+
tokenizer = encoder
393441

394442
dataset = CnlpDataset(
395443
data_dir=data_dir,
@@ -463,7 +511,7 @@ def train(
463511
config = HierarchicalModelConfig(
464512
tasks=list(dataset.tasks),
465513
vocab_size=len(dataset.tokenizer),
466-
encoder_name=encoder_name,
514+
encoder_name=encoder,
467515
layer=hier_use_layer,
468516
n_layers=hier_layers,
469517
d_inner=hier_hidden_dim,
@@ -476,7 +524,7 @@ def train(
476524
config = ProjectionModelConfig(
477525
tasks=list(dataset.tasks),
478526
vocab_size=len(dataset.tokenizer),
479-
encoder_name=encoder_name,
527+
encoder_name=encoder,
480528
encoder_layer=encoder_layer,
481529
use_prior_tasks=use_prior_tasks,
482530
classification_mode=classification_mode,

0 commit comments

Comments (0)