11from fastapi import FastAPI , Request
22import uvicorn
3+ from contextlib import asynccontextmanager
34import aiohttp
45
5- app = FastAPI ()
66
7+ @asynccontextmanager
8+ async def lifespan (app : FastAPI ):
9+ app .state .session = aiohttp .ClientSession ()
10+ yield
11+ await app .state .session .close ()
12+ app = FastAPI (lifespan = lifespan )
713
8- async def call_driver (url , payload ):
9- async with aiohttp .ClientSession () as session :
10- async with session .post (url , json = payload ) as response :
11- if response .status != 200 :
12- text = await response .text ()
13- raise Exception (f"driver call failed with status: { response .status } error: { text } " )
14- content_type = response .headers .get ("Content-Type" , "" )
15- if "application/json" not in content_type :
16- text = await response .text ()
17- raise Exception (f"driver returns unexpected Content-Type: { content_type } , response: { text } " )
18- return await response .json ()
14+
15+ async def call_driver (url , request : Request ):
16+ session : aiohttp .ClientSession = request .app .state .session
17+ body = await request .json ()
18+ payload = body .get ("template" , {}).get ("plugin" , {}).get ("driver-plugin" , {}).get ("args" , {})
19+ async with session .post (url , json = payload ) as response :
20+ if response .status != 200 :
21+ text = await response .text ()
22+ raise Exception (f"driver call failed with status: { response .status } error: { text } " )
23+ content_type = response .headers .get ("Content-Type" , "" )
24+ if "application/json" not in content_type :
25+ text = await response .text ()
26+ raise Exception (f"driver returns unexpected Content-Type: { content_type } , response: { text } " )
27+ return await response .json ()
1928
2029
2130@app .post ("/api/v1/template.execute" )
2231async def execute_plugin (request : Request ):
23- body = await request .json ()
24- payload = body .get ("template" , {}).get ("plugin" , {}).get ("driver-plugin" , {}).get ("args" , {})
25- print ("request payload:" + str (payload ))
26- response = await call_driver ("http://ml-pipeline-kfp-driver.kubeflow.svc.cluster.local:2948/api/v1/execute" , payload )
27- print ("response:" , response )
32+ response = await call_driver ("http://ml-pipeline-kfp-driver.kubeflow.svc:2948/api/v1/execute" , request )
2833 return response
2934
3035
3136if __name__ == "__main__" :
32- uvicorn .run ("main:app" , host = "0.0.0.0" , port = 2948 , reload = True )
37+ uvicorn .run ("main:app" , host = "0.0.0.0" , port = 2948 , reload = True )
0 commit comments