33import pytest
44from PIL import Image
55
6- from unstructured_inference .models import largemodel
6+ from unstructured_inference .models import chipper
77
88
99def test_initialize ():
1010 with mock .patch .object (
11- largemodel .AutoTokenizer ,
11+ chipper .AutoTokenizer ,
1212 "from_pretrained" ,
1313 ) as mock_tokenizer , mock .patch .object (
14- largemodel ,
14+ chipper ,
1515 "DonutProcessor" ,
1616 ) as mock_donut_processor , mock .patch .object (
17- largemodel ,
17+ chipper ,
1818 "DonutImageProcessor" ,
1919 ) as mock_donut_image_processor , mock .patch .object (
20- largemodel .VisionEncoderDecoderModel ,
20+ chipper .VisionEncoderDecoderModel ,
2121 "from_pretrained" ,
2222 ) as mock_vision_encoder_decoder_model :
23- model = largemodel . UnstructuredLargeModel ()
23+ model = chipper . UnstructuredChipperModel ()
2424 model .initialize ("" , "" , "" )
2525 mock_tokenizer .assert_called_once ()
2626 mock_donut_processor .assert_called_once ()
@@ -44,8 +44,8 @@ def mock_initialize(self, *arg, **kwargs):
4444
4545
4646def test_predict_tokens ():
47- with mock .patch .object (largemodel . UnstructuredLargeModel , "initialize" , mock_initialize ):
48- model = largemodel . UnstructuredLargeModel ()
47+ with mock .patch .object (chipper . UnstructuredChipperModel , "initialize" , mock_initialize ):
48+ model = chipper . UnstructuredChipperModel ()
4949 model .initialize ()
5050 with open ("sample-docs/loremipsum.png" , "rb" ) as fp :
5151 im = Image .open (fp )
@@ -64,9 +64,9 @@ def test_predict_tokens():
6464 ],
6565)
6666def test_postprocess (decoded_str , expected_classes ):
67- with mock .patch .object (largemodel . UnstructuredLargeModel , "initialize" , mock_initialize ):
67+ with mock .patch .object (chipper . UnstructuredChipperModel , "initialize" , mock_initialize ):
6868 pass
69- model = largemodel . UnstructuredLargeModel ()
69+ model = chipper . UnstructuredChipperModel ()
7070 tokenizer_model = "xlm-roberta-large"
7171 pre_trained_model = "nielsr/donut-base"
7272 model .initialize (tokenizer_model , pre_trained_model , None )
@@ -81,13 +81,13 @@ def test_postprocess(decoded_str, expected_classes):
8181
8282def test_predict ():
8383 with mock .patch .object (
84- largemodel . UnstructuredLargeModel ,
84+ chipper . UnstructuredChipperModel ,
8585 "predict_tokens" ,
8686 ) as mock_predict_tokens , mock .patch .object (
87- largemodel . UnstructuredLargeModel ,
87+ chipper . UnstructuredChipperModel ,
8888 "postprocess" ,
8989 ) as mock_postprocess :
90- model = largemodel . UnstructuredLargeModel ()
90+ model = chipper . UnstructuredChipperModel ()
9191 model .predict ("hello" )
9292 mock_predict_tokens .assert_called_once ()
9393 mock_postprocess .assert_called_once ()
0 commit comments