Skip to content

Commit 269891c

Browse files
committed
ocr page
1 parent a1fadd2 commit 269891c

File tree

3 files changed

+102
-6
lines changed

3 files changed

+102
-6
lines changed

xinference/model/image/core.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,12 +159,14 @@ def create_ocr_model_instance(
159159
model_spec: ImageModelFamilyV2,
160160
model_path: Optional[str] = None,
161161
**kwargs,
162-
) -> GotOCR2Model:
162+
):
163163
from .cache_manager import ImageCacheManager
164164

165165
if not model_path:
166166
cache_manager = ImageCacheManager(model_spec)
167167
model_path = cache_manager.cache()
168+
169+
# Use GOT-OCR2 for all OCR models
168170
model = GotOCR2Model(
169171
model_uid,
170172
model_path,

xinference/ui/gradio/media_interface.py

Lines changed: 94 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def __init__(
6363
)
6464

6565
def build(self) -> gr.Blocks:
66-
if self.model_type == "image":
67-
assert "stable_diffusion" in self.model_family
68-
66+
# Remove the stable_diffusion restriction to support OCR models
6967
interface = self.build_main_interface()
7068
interface.queue()
7169
# Gradio initiates the queue during a startup event, but since the app has already been
@@ -1233,9 +1231,98 @@ def tts_generate(
12331231

12341232
return tts_ui
12351233

1234+
def ocr_interface(self) -> "gr.Blocks":
1235+
def extract_text_from_image(
1236+
image: "PIL.Image.Image",
1237+
ocr_type: str = "ocr",
1238+
progress=gr.Progress(),
1239+
) -> str:
1240+
from ...client import RESTfulClient
1241+
1242+
client = RESTfulClient(self.endpoint)
1243+
client._set_token(self.access_token)
1244+
model = client.get_model(self.model_uid)
1245+
assert hasattr(model, "ocr")
1246+
1247+
# Convert PIL image to bytes
1248+
import io
1249+
1250+
buffered = io.BytesIO()
1251+
if image.mode == "RGBA" or image.mode == "CMYK":
1252+
image = image.convert("RGB")
1253+
image.save(buffered, format="PNG")
1254+
image_bytes = buffered.getvalue()
1255+
1256+
progress(0.1, desc="Processing image for OCR")
1257+
1258+
# Call the OCR method with bytes instead of PIL Image
1259+
response = model.ocr(
1260+
image=image_bytes,
1261+
ocr_type=ocr_type,
1262+
)
1263+
1264+
progress(0.8, desc="Extracting text")
1265+
progress(1.0, desc="OCR complete")
1266+
1267+
return response if response else "No text extracted from the image."
1268+
1269+
with gr.Blocks() as ocr_interface:
1270+
gr.Markdown(f"### OCR Text Extraction with {self.model_name}")
1271+
1272+
with gr.Row():
1273+
with gr.Column(scale=1):
1274+
image_input = gr.Image(
1275+
type="pil",
1276+
label="Upload Image for OCR",
1277+
interactive=True,
1278+
height=400,
1279+
)
1280+
1281+
gr.Markdown(f"**Current OCR Model:** {self.model_name}")
1282+
1283+
ocr_type = gr.Dropdown(
1284+
choices=["ocr", "format"],
1285+
value="ocr",
1286+
label="OCR Type",
1287+
info="Choose OCR processing type",
1288+
)
1289+
1290+
extract_btn = gr.Button("Extract Text", variant="primary")
1291+
1292+
with gr.Column(scale=1):
1293+
text_output = gr.Textbox(
1294+
label="Extracted Text",
1295+
lines=20,
1296+
placeholder="Extracted text will appear here...",
1297+
interactive=True,
1298+
show_copy_button=True,
1299+
)
1300+
1301+
# Examples section
1302+
gr.Markdown("### Examples")
1303+
gr.Examples(
1304+
examples=[
1305+
# You can add example image paths here if needed
1306+
],
1307+
inputs=[image_input],
1308+
label="Example Images",
1309+
)
1310+
1311+
# Extract button click event
1312+
extract_btn.click(
1313+
fn=extract_text_from_image,
1314+
inputs=[image_input, ocr_type],
1315+
outputs=[text_output],
1316+
)
1317+
1318+
return ocr_interface
1319+
12361320
def build_main_interface(self) -> "gr.Blocks":
12371321
if self.model_type == "image":
1238-
title = f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨"
1322+
if "ocr" in self.model_ability:
1323+
title = f"🔍 Xinference OCR: {self.model_name} 🔍"
1324+
else:
1325+
title = f"🎨 Xinference Stable Diffusion: {self.model_name} 🎨"
12391326
elif self.model_type == "video":
12401327
title = f"🎨 Xinference Video Generation: {self.model_name} 🎨"
12411328
else:
@@ -1266,6 +1353,9 @@ def build_main_interface(self) -> "gr.Blocks":
12661353
</div>
12671354
"""
12681355
)
1356+
if "ocr" in self.model_ability:
1357+
with gr.Tab("OCR"):
1358+
self.ocr_interface()
12691359
if "text2image" in self.model_ability:
12701360
with gr.Tab("Text to Image"):
12711361
self.text2image_interface()

xinference/ui/web/ui/src/scenes/register_model/registerModel.js

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,10 @@ const model_ability_options = [
5858
'Hybrid',
5959
],
6060
},
61+
{
62+
type: 'image',
63+
options: ['ocr'],
64+
},
6165
{
6266
type: 'audio',
6367
options: ['text2audio', 'audio2text'],
@@ -76,7 +80,7 @@ const messages = [
7680
const model_family_options = [
7781
{
7882
type: 'image',
79-
options: ['stable_diffusion'],
83+
options: ['stable_diffusion', 'ocr'],
8084
},
8185
{
8286
type: 'audio',

0 commit comments

Comments (0)