Skip to content

Commit cbd09bf

Browse files
committed
feat(server):注册新模型、实现车票识别 API 及任务处理逻辑
1 parent 272ef70 commit cbd09bf

File tree

4 files changed

+178
-5
lines changed

4 files changed

+178
-5
lines changed

package/server/app/api/train_ticket.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,104 @@
2828
from app.db.models.trip import TrainTicket
2929
from app.schemas.train_ticket import TrainTicketResponse, TrainTicketCreate, TrainTicketListResponse, TrainTicketUpdate
3030
from app.dependencies import get_db
31+
from app.core.config_manager import config_manager
32+
import aiohttp
33+
from aiohttp import FormData
3134

3235
router = APIRouter()
3336

3437
# ------------------- 火车票接口 -------------------
3538

39+
@router.post("/recognize", summary="识别车票图片")
40+
async def recognize_ticket(
41+
file: UploadFile = File(..., description="车票图片"),
42+
):
43+
"""
44+
上传车票图片并调用AI服务进行识别
45+
返回识别到的结构化数据
46+
"""
47+
try:
48+
# 1. 读取文件内容
49+
file_content = await file.read()
50+
51+
# 2. 构造请求发送给AI服务
52+
async with aiohttp.ClientSession() as session:
53+
form_data = FormData()
54+
form_data.add_field(
55+
name='file',
56+
value=file_content,
57+
filename=file.filename,
58+
content_type=file.content_type or 'image/jpeg'
59+
)
60+
61+
api_url = f"{config_manager.config.ai.ai_api_url}/tickets/predict"
62+
63+
async with session.post(api_url, data=form_data) as response:
64+
if response.status != 200:
65+
raise HTTPException(status_code=500, detail=f"AI服务请求失败: {response.status}")
66+
67+
result = await response.json()
68+
69+
if not result or 'tickets' not in result or not result['tickets']:
70+
raise HTTPException(status_code=400, detail="未能识别出车票信息")
71+
72+
# 获取第一张识别出的车票
73+
ticket_info = result['tickets'][0]
74+
75+
# 3. 数据格式化
76+
processed_data = {}
77+
78+
# 映射字段
79+
field_mapping = {
80+
'train_code': 'train_code',
81+
'departure_station': 'departure_station',
82+
'arrival_station': 'arrival_station',
83+
'seat_num': 'seat_num',
84+
'seat_type': 'seat_type',
85+
'name': 'name',
86+
'carriage': 'carriage'
87+
}
88+
89+
for k, v in field_mapping.items():
90+
if ticket_info.get(k):
91+
processed_data[v] = ticket_info[k]
92+
93+
# 处理日期时间
94+
if ticket_info.get('datetime'):
95+
dt_str = ticket_info.get('datetime')
96+
dt = None
97+
formats = [
98+
"%Y年%m月%d日 %H:%M",
99+
"%Y年%m月%d日%H:%M",
100+
"%Y-%m-%d %H:%M",
101+
"%Y/%m/%d %H:%M"
102+
]
103+
for fmt in formats:
104+
try:
105+
dt = datetime.strptime(dt_str, fmt)
106+
# 转换为前端友好的 ISO 格式 (YYYY-MM-DDTHH:mm)
107+
processed_data['datetime'] = dt.strftime("%Y-%m-%dT%H:%M")
108+
break
109+
except ValueError:
110+
continue
111+
112+
# 处理价格
113+
if ticket_info.get('price'):
114+
price_str = str(ticket_info.get('price')).replace('元', '').replace('¥', '').strip()
115+
try:
116+
processed_data['price'] = float(price_str)
117+
except:
118+
processed_data['price'] = 0
119+
120+
return processed_data
121+
122+
except Exception as e:
123+
# 如果是 HTTPException 直接抛出
124+
if isinstance(e, HTTPException):
125+
raise e
126+
raise HTTPException(status_code=500, detail=f"识别处理失败: {str(e)}")
127+
128+
36129
@router.post("/import", summary="导入车票数据")
37130
async def import_tickets(
38131
file: UploadFile = File(..., description="数据文件(支持JSON/CSV)"),

package/server/app/db/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@
1010
from .album_photos import AlbumPhoto
1111
from .image_vector import ImageVector
1212
from .ocr import OCR
13+
from .trip import TrainTicket

package/server/app/service/task_worker.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@
3333
TaskType.SCAN_FOLDER,
3434
TaskType.RECOGNIZE_FACE,
3535
TaskType.OCR,
36-
TaskType.CLASSIFY_IMAGE
36+
TaskType.CLASSIFY_IMAGE,
37+
TaskType.RECOGNIZE_TICKET
3738
}
3839

3940
class TaskWorker:

package/server/app/service/tasks/tickets.py

Lines changed: 82 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55
from sqlalchemy.orm import Session
66
from app.db.models.task import Task, TaskType
77
from app.db.models.photo import Photo, FileType
8+
from app.db.models.trip import TrainTicket
89
from typing import Dict, Any, List
10+
from datetime import datetime
911
from app.core.config_manager import config_manager
1012
from app.service import storage
13+
import re
1114

1215
logger = logging.getLogger(__name__)
1316

@@ -100,15 +103,90 @@ async def process_single_photo(task_manager, photo: Photo, db: Session) -> Dict[
100103
if response.status == 200:
101104
result = await response.json()
102105

106+
# === Auto-add tickets to database ===
107+
if result and 'tickets' in result:
108+
tickets_data = result['tickets']
109+
added_count = 0
110+
for t_info in tickets_data:
111+
try:
112+
# Validation
113+
if not t_info.get('train_code') or not t_info.get('datetime'):
114+
continue
115+
116+
# Parse datetime
117+
dt_str = t_info.get('datetime')
118+
dt = None
119+
# Try standard formats
120+
formats = [
121+
"%Y年%m月%d日 %H:%M",
122+
"%Y年%m月%d日%H:%M",
123+
"%Y-%m-%d %H:%M",
124+
"%Y/%m/%d %H:%M"
125+
]
126+
for fmt in formats:
127+
try:
128+
dt = datetime.strptime(dt_str, fmt)
129+
break
130+
except ValueError:
131+
continue
132+
133+
if not dt:
134+
logger.warning(f"Skipping ticket due to invalid datetime: {dt_str}")
135+
continue
136+
137+
# Parse Price
138+
price_val = 0.0
139+
price_str = str(t_info.get('price', '0')).replace('元', '').replace('¥', '').strip()
140+
try:
141+
price_val = float(price_str)
142+
except:
143+
pass
144+
145+
# Check duplicate
146+
existing = db.query(TrainTicket).filter(
147+
TrainTicket.train_code == t_info['train_code'],
148+
TrainTicket.date_time == dt,
149+
TrainTicket.seat_num == (t_info.get('seat_num') or '无座')
150+
).first()
151+
152+
if existing:
153+
logger.info(f"Duplicate ticket found: {t_info['train_code']} {dt}")
154+
continue
155+
156+
# Create Ticket
157+
new_ticket = TrainTicket(
158+
train_code=t_info['train_code'],
159+
departure_station=t_info.get('departure_station', '未知'),
160+
arrival_station=t_info.get('arrival_station', '未知'),
161+
date_time=dt,
162+
carriage=t_info.get('carriage') or '无',
163+
seat_num=t_info.get('seat_num') or '无座',
164+
berth_type=t_info.get('berth_type') or '无',
165+
price=price_val,
166+
seat_type=t_info.get('seat_type') or '二等座',
167+
name=t_info.get('name') or '未知',
168+
discount_type=t_info.get('discount_type') or '全价票',
169+
total_mileage=0,
170+
total_running_time=0,
171+
stop_stations="[]",
172+
comments=f"自动识别自图片: {photo.filename}"
173+
)
174+
db.add(new_ticket)
175+
db.flush()
176+
t_info['saved_id'] = new_ticket.id
177+
added_count += 1
178+
179+
except Exception as ex:
180+
logger.error(f"Error saving ticket to DB: {ex}")
181+
182+
if added_count > 0:
183+
logger.info(f"Successfully added {added_count} tickets from photo {photo.id}")
184+
103185
# Update processed status
104186
tasks_status = photo.processed_tasks or {}
105187
tasks_status['tickets'] = True
106188
photo.processed_tasks = tasks_status
107189

108-
# NOTE: Here we could save the structured ticket info to a DB table if one existed.
109-
# For now, the result is returned and will be stored in Task.result
110-
# logger.info(f"Ticket task result for photo {photo.id}: {result}")
111-
112190
db.add(photo)
113191
db.commit()
114192

0 commit comments

Comments
 (0)