Skip to content

Commit b3b1269

Browse files
committed
2 parents 9809e5a + 41adb9b commit b3b1269

File tree

4 files changed

+14
-10
lines changed

4 files changed

+14
-10
lines changed

src/reynir_correct/classifier.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,13 @@
4242
4343
"""
4444

45-
from typing import List, overload
45+
from typing import List, Union, overload
4646

4747

4848
try:
4949
from datasets import load_dataset
5050
from transformers import pipeline # type: ignore
5151
except:
52-
import sys
5352
import warnings
5453

5554
warningtext = (
@@ -82,19 +81,18 @@ def classify(self, text: str) -> bool:
8281
def classify(self, text: List[str]) -> List[bool]:
8382
...
8483

85-
def classify(self, text):
84+
def classify(self, text: Union[str, List[str]]) -> Union[List[bool], bool]:
8685
"""Classify a sentence or sentences.
8786
For each sentence, return true iff the sentence probably contains an error."""
8887
if isinstance(text, str):
8988
text = [text]
9089

91-
result = self.pipe([self._domain_prefix + t for t in text])
92-
result = [r["generated_text"] == self._true_label for r in result]
90+
pipe_result = self.pipe([self._domain_prefix + t for t in text])
91+
result: List[bool] = [
92+
r["generated_text"] == self._true_label for r in pipe_result
93+
]
9394

94-
if len(result) == 1:
95-
result = result[0]
96-
97-
return result
95+
return result[0] if len(result) == 1 else result
9896

9997

10098
def _main() -> None:

src/reynir_correct/config/GreynirCorrect.conf

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@
204204
"áratugs", "áratugar"
205205
"áratugsins", "áratugarins"
206206
"árð", "árið"
207+
"arfleið", "arfleifð"
208+
"arfleiðar", "arfleifðar"
207209
"Argentísk", "Argentínsk"
208210
"argentíska", "argentínska"
209211
"argentíski", "argentínski"

src/reynir_correct/pattern.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def wrong_preposition_grin_af(self, match: SimpleTree) -> None:
369369
pp = match.first_match('P > { "af" }')
370370
if pp is None:
371371
pp = match.first_match('ADVP > { "af" }')
372-
if np is None or pp is None:
372+
if vp is None or np is None or pp is None:
373373
return
374374
pp_af = pp.first_match('"af"')
375375
if pp_af is None:

src/reynir_correct/wrappers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,12 @@
6464
Any,
6565
Union,
6666
cast,
67+
TYPE_CHECKING,
6768
)
6869

70+
if TYPE_CHECKING:
71+
from .classifier import SentenceClassifier
72+
6973
import sys
7074
import argparse
7175
import json

0 commit comments

Comments
 (0)