@@ -63,9 +63,7 @@ def __init__(
         )

     def build(self) -> gr.Blocks:
-        if self.model_type == "image":
-            assert "stable_diffusion" in self.model_family
-
+        # Remove the stable_diffusion restriction to support OCR models
         interface = self.build_main_interface()
         interface.queue()
         # Gradio initiates the queue during a startup event, but since the app has already been
@@ -1233,9 +1231,98 @@ def tts_generate(

         return tts_ui

+    def ocr_interface(self) -> "gr.Blocks":
+        def extract_text_from_image(
+            image: "PIL.Image.Image",
+            ocr_type: str = "ocr",
+            progress=gr.Progress(),
+        ) -> str:
+            from ...client import RESTfulClient
+
+            client = RESTfulClient(self.endpoint)
+            client._set_token(self.access_token)
+            model = client.get_model(self.model_uid)
+            assert hasattr(model, "ocr")
+
+            # Convert PIL image to bytes
+            import io
+
+            buffered = io.BytesIO()
+            if image.mode == "RGBA" or image.mode == "CMYK":
+                image = image.convert("RGB")
+            image.save(buffered, format="PNG")
+            image_bytes = buffered.getvalue()
+
+            progress(0.1, desc="Processing image for OCR")
+
+            # Call the OCR method with bytes instead of PIL Image
+            response = model.ocr(
+                image=image_bytes,
+                ocr_type=ocr_type,
+            )
+
+            progress(0.8, desc="Extracting text")
+            progress(1.0, desc="OCR complete")
+
+            return response if response else "No text extracted from the image."
+
+        with gr.Blocks() as ocr_interface:
+            gr.Markdown(f"### OCR Text Extraction with {self.model_name}")
+
+            with gr.Row():
+                with gr.Column(scale=1):
+                    image_input = gr.Image(
+                        type="pil",
+                        label="Upload Image for OCR",
+                        interactive=True,
+                        height=400,
+                    )
+
+                    gr.Markdown(f"**Current OCR Model:** {self.model_name}")
+
+                    ocr_type = gr.Dropdown(
+                        choices=["ocr", "format"],
+                        value="ocr",
+                        label="OCR Type",
+                        info="Choose OCR processing type",
+                    )
+
+                    extract_btn = gr.Button("Extract Text", variant="primary")
+
+                with gr.Column(scale=1):
+                    text_output = gr.Textbox(
+                        label="Extracted Text",
+                        lines=20,
+                        placeholder="Extracted text will appear here...",
+                        interactive=True,
+                        show_copy_button=True,
+                    )
+
+            # Examples section
+            gr.Markdown("### Examples")
+            gr.Examples(
+                examples=[
+                    # You can add example image paths here if needed
+                ],
+                inputs=[image_input],
+                label="Example Images",
+            )
+
+            # Extract button click event
+            extract_btn.click(
+                fn=extract_text_from_image,
+                inputs=[image_input, ocr_type],
+                outputs=[text_output],
+            )
+
+        return ocr_interface
+
     def build_main_interface(self) -> "gr.Blocks":
         if self.model_type == "image":
-            title = f"🎨 Xinference Stable Diffusion: {self.model_name}"
+            if "ocr" in self.model_ability:
+                title = f"🔍 Xinference OCR: {self.model_name}"
+            else:
+                title = f"🎨 Xinference Stable Diffusion: {self.model_name}"
         elif self.model_type == "video":
             title = f"🎨 Xinference Video Generation: {self.model_name}"
         else:
@@ -1266,6 +1353,9 @@ def build_main_interface(self) -> "gr.Blocks":
                     </div>
                     """
             )
+            if "ocr" in self.model_ability:
+                with gr.Tab("OCR"):
+                    self.ocr_interface()
             if "text2image" in self.model_ability:
                 with gr.Tab("Text to Image"):
                     self.text2image_interface()
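For reference, below is a minimal sketch of exercising the new OCR ability outside the Gradio tab, going through the same RESTful client path the handler above uses. It assumes a model exposing the "ocr" ability is already launched; the endpoint URL, model UID, and image path are placeholder values, not part of this change.

from xinference.client import RESTfulClient

# Placeholder endpoint and model UID; substitute the values of your deployment.
client = RESTfulClient("http://127.0.0.1:9997")
model = client.get_model("my-ocr-model")

# The UI converts the uploaded PIL image to PNG bytes before calling ocr(),
# so plain file bytes work here as well.
with open("sample.png", "rb") as f:  # placeholder image path
    image_bytes = f.read()

# ocr_type mirrors the dropdown choices above ("ocr" or "format").
text = model.ocr(image=image_bytes, ocr_type="ocr")
print(text)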