diff --git a/tests/test_coverage.py b/tests/test_coverage.py index 4dccf48..0ee5c26 100644 --- a/tests/test_coverage.py +++ b/tests/test_coverage.py @@ -102,3 +102,9 @@ def test_words(): assert len(WORDS) > 0 assert WORDS[0] == 'aa' assert WORDS[-1] == 'zzz' + +def test_keep_case(): + assert segment('\tMaintainCasing \nwith variableSpaCING', True) \ + == ['Maintain', 'Casing', 'with', 'variable', 'SpaCING'] + assert segment('\tMaintainCasing \nwith variableSpaCING', False) \ + == ['maintain', 'casing', 'with', 'variable', 'spacing'] \ No newline at end of file diff --git a/wordsegment/__init__.py b/wordsegment/__init__.py index 1db0776..e1c99f0 100644 --- a/wordsegment/__init__.py +++ b/wordsegment/__init__.py @@ -30,6 +30,7 @@ import math import os.path as op import sys +import re class Segmenter(object): @@ -162,11 +163,26 @@ def candidates(): yield word - def segment(self, text): + def segment(self, text, keep_case=False): "Return list of words that is the best segmenation of `text`." + if keep_case: + return self.maintain_case(text, list(self.isegment(text))) return list(self.isegment(text)) - + def maintain_case(self, orig_text, seg_text): + "maintain the characters casing by referring back to `orig_text`." + cased_text = [] + og_char_i = 0 + for tok_i in range(len(seg_text)): + cased_token = list(seg_text[tok_i]) + for char_i in range(len(cased_token)): + while re.match('[\s]',orig_text[og_char_i]): + og_char_i += 1 + cased_token[char_i] = orig_text[og_char_i] + og_char_i += 1 + cased_text.append(''.join(cased_token)) + return cased_text + def divide(self, text): "Yield `(prefix, suffix)` pairs from `text`." for pos in range(1, min(len(text), self.limit) + 1): @@ -207,12 +223,14 @@ def main(arguments=()): default=sys.stdin) parser.add_argument('outfile', nargs='?', type=argparse.FileType('w'), default=sys.stdout) + parser.add_argument('--keep_case', action='store_true', default=False, + help='maintain original case of input text') streams = parser.parse_args(arguments) load() for line in iter(streams.infile.readline, ''): - streams.outfile.write(' '.join(segment(line.strip()))) + streams.outfile.write(' '.join(segment(line.strip(), streams.keep_case))) streams.outfile.write(os.linesep)