Skip to content

Commit 1c23a83

Browse files
authored
Add files via upload
1 parent d57b3b5 commit 1c23a83

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

src/cnlpt/api/annotate_rest.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import logging
19+
from contextlib import asynccontextmanager
20+
from time import time
21+
22+
import numpy as np
23+
from fastapi import FastAPI
24+
from pydantic import BaseModel
25+
26+
27+
from nltk.tokenize import wordpunct_tokenize as tokenize # from timex rest
28+
from seqeval.metrics.sequence_labeling import get_entities # from timex rest
29+
30+
from .utils import (
31+
EntityDocument,
32+
create_dataset,
33+
create_instance_string,
34+
initialize_cnlpt_model,
35+
)
36+
37+
38+
logger = logging.getLogger("DoNotAnnotate_REST_Processor")
39+
logger.setLevel(logging.DEBUG)
40+
41+
42+
MODEL_PATH = "/my_trained_model" # path to saved model with model.safetensors
43+
TASK = "DoNotAnnotate"
44+
LABELS = [0, 1] # Do not annotate = 0, Annotate = 1
45+
46+
MAX_LENGTH = 128
47+
48+
class AnnotateResults(BaseModel):
49+
"""statuses: list of classifier outputs for every input"""
50+
labels: list[int]
51+
52+
tokenizer: PreTrainedTokenizer
53+
trainer: Trainer
54+
55+
@asynccontextmanager
56+
async def lifespan(app: FastAPI):
57+
global tokenizer, trainer
58+
tokenizer, trainer = initialize_cnlpt_model(MODEL_NAME)
59+
yield
60+
61+
app = FastAPI(lifespan=lifespan)
62+
63+
@app.post("/annotate/process")
64+
async def process(doc: UnannotatedDocument):
65+
lines = [line.strip() for line in doc_text.split("\n") if line.strip()]
66+
logger.warning(
67+
f"Received document of {len(doc_text)} to process with {len(lines)} non-empty lines"
68+
)
69+
start_time = time()
70+
71+
if not lines:
72+
return AnnotateResults(labels=[])
73+
74+
dataset = create_dataset(lines, tokenizer, MAX_LENGTH)
75+
preproc_end = time()
76+
77+
output = trainer.predict(test_dataset=dataset)
78+
predictions = np.argmax(output.predictions, axis=1)
79+
80+
pred_end = time()
81+
82+
results = AnnotateResults(labels=predictions.tolist())
83+
84+
results = []
85+
for ind in range(len(dataset)):
86+
results.append(LABELS[predictions[ind]])
87+
88+
output = AnnotateResults(statuses=results)
89+
90+
postproc_end = time()
91+
92+
preproc_time = preproc_end - start_time
93+
pred_time = pred_end - preproc_end
94+
postproc_time = postproc_end - pred_end
95+
96+
logging.info(
97+
f"Pre-processing time: {preproc_time:f}, processing time: {pred_time:f}, post-processing time {postproc_time:f}"
98+
)
99+
100+
return output
101+

0 commit comments

Comments
 (0)