Skip to content

Commit 121e003

Browse files
author
arpechenin
committed
inor enhancement of the plugin mock
Signed-off-by: arpechenin <[email protected]>
1 parent 1935e0e commit 121e003

File tree

1 file changed

+23
-18
lines changed
  • proposals/separate-standalone-driver/src/driver-plugin

1 file changed

+23
-18
lines changed
Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,37 @@
11
from fastapi import FastAPI, Request
22
import uvicorn
3+
from contextlib import asynccontextmanager
34
import aiohttp
45

5-
app = FastAPI()
66

7+
@asynccontextmanager
8+
async def lifespan(app: FastAPI):
9+
app.state.session = aiohttp.ClientSession()
10+
yield
11+
await app.state.session.close()
12+
app = FastAPI(lifespan=lifespan)
713

8-
async def call_driver(url, payload):
9-
async with aiohttp.ClientSession() as session:
10-
async with session.post(url, json=payload) as response:
11-
if response.status != 200:
12-
text = await response.text()
13-
raise Exception(f"driver call failed with status: {response.status} error: {text}")
14-
content_type = response.headers.get("Content-Type", "")
15-
if "application/json" not in content_type:
16-
text = await response.text()
17-
raise Exception(f"driver returns unexpected Content-Type: {content_type}, response: {text}")
18-
return await response.json()
14+
15+
async def call_driver(url, request: Request):
16+
session: aiohttp.ClientSession = request.app.state.session
17+
body = await request.json()
18+
payload = body.get("template", {}).get("plugin", {}).get("driver-plugin", {}).get("args", {})
19+
async with session.post(url, json=payload) as response:
20+
if response.status != 200:
21+
text = await response.text()
22+
raise Exception(f"driver call failed with status: {response.status} error: {text}")
23+
content_type = response.headers.get("Content-Type", "")
24+
if "application/json" not in content_type:
25+
text = await response.text()
26+
raise Exception(f"driver returns unexpected Content-Type: {content_type}, response: {text}")
27+
return await response.json()
1928

2029

2130
@app.post("/api/v1/template.execute")
2231
async def execute_plugin(request: Request):
23-
body = await request.json()
24-
payload = body.get("template", {}).get("plugin", {}).get("driver-plugin", {}).get("args", {})
25-
print("request payload:" + str(payload))
26-
response = await call_driver("http://ml-pipeline-kfp-driver.kubeflow.svc.cluster.local:2948/api/v1/execute", payload)
27-
print("response:", response)
32+
response = await call_driver("http://ml-pipeline-kfp-driver.kubeflow.svc:2948/api/v1/execute", request)
2833
return response
2934

3035

3136
if __name__ == "__main__":
32-
uvicorn.run("main:app", host="0.0.0.0", port=2948, reload=True)
37+
uvicorn.run("main:app", host="0.0.0.0", port=2948, reload=True)

0 commit comments

Comments
 (0)