|
5 | 5 | import traceback
|
6 | 6 | from copy import deepcopy
|
7 | 7 | from multiprocessing import Queue
|
| 8 | +from tempfile import TemporaryDirectory |
8 | 9 | from threading import Event, Lock, Thread
|
9 | 10 | from typing import Dict, List, Mapping, Optional
|
10 | 11 |
|
11 | 12 | import uvicorn
|
12 | 13 | from deepdiff import DeepDiff, Delta
|
13 |
| -from fastapi import FastAPI, HTTPException, Request, Response, status, WebSocket |
| 14 | +from fastapi import FastAPI, File, HTTPException, Request, Response, status, UploadFile, WebSocket |
14 | 15 | from fastapi.middleware.cors import CORSMiddleware
|
15 | 16 | from fastapi.params import Header
|
16 | 17 | from fastapi.responses import HTMLResponse, JSONResponse
|
|
23 | 24 | from lightning_app.api.request_types import DeltaRequest
|
24 | 25 | from lightning_app.core.constants import FRONTEND_DIR
|
25 | 26 | from lightning_app.core.queues import RedisQueue
|
| 27 | +from lightning_app.storage import Drive |
26 | 28 | from lightning_app.utilities.app_helpers import InMemoryStateStore, Logger, StateStore
|
27 | 29 | from lightning_app.utilities.enum import OpenAPITags
|
28 | 30 | from lightning_app.utilities.imports import _is_redis_available, _is_starsessions_available
|
@@ -234,6 +236,29 @@ async def post_state(
|
234 | 236 | api_app_delta_queue.put(DeltaRequest(delta=Delta(deep_diff)))
|
235 | 237 |
|
236 | 238 |
|
| 239 | +@fastapi_service.put("/api/v1/upload_file/{filename}") |
| 240 | +async def upload_file(filename: str, uploaded_file: UploadFile = File(...)): |
| 241 | + with TemporaryDirectory() as tmp: |
| 242 | + drive = Drive( |
| 243 | + "lit://uploaded_files", |
| 244 | + component_name="file_server", |
| 245 | + allow_duplicates=True, |
| 246 | + root_folder=tmp, |
| 247 | + ) |
| 248 | + tmp_file = os.path.join(tmp, filename) |
| 249 | + |
| 250 | + with open(tmp_file, "wb") as f: |
| 251 | + done = False |
| 252 | + while not done: |
| 253 | + # Note: The 8192 number doesn't have a strong reason. |
| 254 | + content = await uploaded_file.read(8192) |
| 255 | + f.write(content) |
| 256 | + done = content == b"" |
| 257 | + |
| 258 | + drive.put(filename) |
| 259 | + return f"Successfully uploaded '{filename}' to the Drive" |
| 260 | + |
| 261 | + |
237 | 262 | @fastapi_service.get("/healthz", status_code=200)
|
238 | 263 | async def healthz(response: Response):
|
239 | 264 | """Health check endpoint used in the cloud FastAPI servers to check the status periodically. This requires
|
|
0 commit comments