Skip to content

Commit a904d95

Browse files
committed
file uploading fixed
1 parent f12af67 commit a904d95

File tree

4 files changed

+259
-54
lines changed

4 files changed

+259
-54
lines changed

client/src/components/Dragger.js

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,38 @@ import UTIF from 'utif'
88

99
const path = require('path')
1010

11+
const ensureTrailingSeparator = (dirPath) => {
12+
if (!dirPath) return ''
13+
return dirPath.endsWith(path.sep) ? dirPath : dirPath + path.sep
14+
}
15+
16+
const getFolderPath = (uploadFile, originPath) => {
17+
if (uploadFile?.folderPath) {
18+
return ensureTrailingSeparator(uploadFile.folderPath)
19+
}
20+
if (!originPath) return ''
21+
return ensureTrailingSeparator(path.dirname(originPath))
22+
}
23+
24+
const enrichFileMetadata = (uploadFile) => {
25+
const originPath =
26+
uploadFile?.originFileObj?.path ||
27+
uploadFile?.path
28+
const folderPath = getFolderPath(uploadFile, originPath)
29+
30+
const enhancedFile = { ...uploadFile }
31+
if (originPath) {
32+
enhancedFile.path = originPath
33+
}
34+
if (folderPath) {
35+
enhancedFile.folderPath = folderPath
36+
}
37+
return enhancedFile
38+
}
39+
1140
export function Dragger () {
1241
const context = useContext(AppContext)
42+
const { setFiles, files } = context
1343

1444
// const getBase64 = (file) =>
1545
// new Promise((resolve, reject) => {
@@ -22,22 +52,21 @@ export function Dragger () {
2252
const onChange = (info) => {
2353
const { status } = info.file
2454
if (status === 'done') {
25-
console.log('file found at:', info.file.originFileObj.path)
55+
const originPath =
56+
info.file?.originFileObj?.path ||
57+
info.file?.path
58+
console.log('file found at:', originPath)
2659

2760
message.success(`${info.file.name} file uploaded successfully.`)
28-
if (window.require) {
29-
const modifiedFile = { ...info.file, path: info.file.originFileObj.path }
30-
context.setFiles([...context.files, modifiedFile])
31-
} else {
32-
context.setFiles([...info.fileList])
33-
}
61+
const updatedFiles = info.fileList.map(enrichFileMetadata)
62+
setFiles(updatedFiles)
3463
console.log('done')
3564
} else if (status === 'error') {
3665
console.log('error')
3766
message.error(`${info.file.name} file upload failed.`)
3867
} else if (status === 'removed') {
3968
console.log(info.fileList)
40-
context.setFiles([...info.fileList])
69+
setFiles(info.fileList.map(enrichFileMetadata))
4170
}
4271
}
4372

@@ -58,6 +87,22 @@ export function Dragger () {
5887
const [previewFileFolderPath, setPreviewFileFolderPath] = useState('')
5988
const [fileType, setFileType] = useState('Image')
6089

90+
useEffect(() => {
91+
if (!files || files.length === 0) return
92+
const needsFolderPath = files.some(
93+
(file) =>
94+
!file.folderPath &&
95+
(file?.originFileObj?.path || file?.path)
96+
)
97+
if (needsFolderPath) {
98+
setFiles(
99+
files.map((file) =>
100+
!file.folderPath ? enrichFileMetadata(file) : file
101+
)
102+
)
103+
}
104+
}, [files, setFiles])
105+
61106
const handleText = (event) => {
62107
setValue(event.target.value)
63108
}
@@ -177,8 +222,12 @@ export function Dragger () {
177222
.folderPath
178223
)
179224
} else {
180-
// Directory name with trailing slash
181-
setPreviewFileFolderPath(path.dirname(file.originFileObj.path) + '/')
225+
const originPath =
226+
file?.originFileObj?.path ||
227+
file?.path
228+
setPreviewFileFolderPath(
229+
originPath ? ensureTrailingSeparator(path.dirname(originPath)) : ''
230+
)
182231
}
183232
}
184233

client/src/utils/api.js

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,47 @@ import { message } from 'antd'
55
const API_PROTOCOL = process.env.REACT_APP_API_PROTOCOL || 'http'
66
const API_URL = process.env.REACT_APP_API_URL || 'localhost:4242'
77

8+
const buildFilePath = (file) => {
9+
if (!file) return ''
10+
if (file.folderPath) return file.folderPath + file.name
11+
if (file.path) return file.path
12+
if (file.originFileObj && file.originFileObj.path) {
13+
return file.originFileObj.path
14+
}
15+
return file.name
16+
}
17+
18+
const hasBrowserFile = (file) =>
19+
file && file.originFileObj instanceof File
20+
821
export async function getNeuroglancerViewer (image, label, scales) {
922
try {
23+
const url = `${API_PROTOCOL}://${API_URL}/neuroglancer`
24+
if (hasBrowserFile(image)) {
25+
const formData = new FormData()
26+
formData.append(
27+
'image',
28+
image.originFileObj,
29+
image.originFileObj.name || image.name || 'image'
30+
)
31+
if (label && hasBrowserFile(label)) {
32+
formData.append(
33+
'label',
34+
label.originFileObj,
35+
label.originFileObj.name || label.name || 'label'
36+
)
37+
}
38+
formData.append('scales', JSON.stringify(scales))
39+
const res = await axios.post(url, formData)
40+
return res.data
41+
}
42+
1043
const data = JSON.stringify({
11-
image: image.folderPath + image.name,
12-
label: label.folderPath + label.name,
44+
image: buildFilePath(image),
45+
label: buildFilePath(label),
1346
scales
1447
})
15-
const res = await axios.post(
16-
`${API_PROTOCOL}://${API_URL}/neuroglancer`,
17-
data
18-
)
48+
const res = await axios.post(url, data)
1949
return res.data
2050
} catch (error) {
2151
message.error(

server_api/main.py

Lines changed: 86 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import json
12
import pathlib
3+
import shutil
4+
import tempfile
5+
from typing import List, Optional
26

37
import requests
48
import uvicorn
5-
from fastapi import FastAPI, Request
9+
from fastapi import FastAPI, HTTPException, Request, UploadFile
610
from fastapi.middleware.cors import CORSMiddleware
711
from utils.io import readVol
812

@@ -21,11 +25,21 @@
2125

2226

2327
def process_path(path):
24-
# Get the absolute path of the current script"s parent directory
25-
current_script_dir = pathlib.Path(__file__).parent
26-
# The root of the repository is assumed to be one level up from the current script"s directory
27-
repo_root = current_script_dir.parent.absolute()
28-
return repo_root / path
28+
if not path:
29+
return None
30+
candidate = pathlib.Path(path).expanduser()
31+
if candidate.is_absolute():
32+
return candidate
33+
return candidate.resolve(strict=False)
34+
35+
36+
def save_upload_to_tempfile(upload: UploadFile) -> pathlib.Path:
37+
suffix = pathlib.Path(upload.filename or "").suffix
38+
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
39+
upload.file.seek(0)
40+
shutil.copyfileobj(upload.file, tmp)
41+
temp_path = pathlib.Path(tmp.name)
42+
return temp_path
2943

3044

3145
@app.get("/hello")
@@ -37,35 +51,73 @@ def hello():
3751
async def neuroglancer(req: Request):
3852
import neuroglancer
3953

40-
req = await req.json()
41-
image = process_path(req["image"])
42-
label = process_path(req["label"])
43-
scales = req["scales"]
44-
print(image, label, scales)
45-
# neuroglancer setting -- bind to this to make accessible outside of container
46-
ip = "0.0.0.0"
47-
port = 4244
48-
neuroglancer.set_server_bind_address(ip, port)
49-
viewer = neuroglancer.Viewer()
50-
# SNEMI (# 3d vol dim: z,y,x)
51-
res = neuroglancer.CoordinateSpace(
52-
names=["z", "y", "x"], units=["nm", "nm", "nm"], scales=scales
53-
)
54-
im = readVol(image, image_type="im")
55-
gt = readVol(label, image_type="im")
56-
57-
def ngLayer(data, res, oo=[0, 0, 0], tt="segmentation"):
58-
return neuroglancer.LocalVolume(
59-
data, dimensions=res, volume_type=tt, voxel_offset=oo
54+
cleanup_paths: List[pathlib.Path] = []
55+
try:
56+
content_type = req.headers.get("content-type", "")
57+
if "multipart/form-data" in content_type:
58+
form = await req.form()
59+
image_upload = form.get("image")
60+
if not image_upload or not getattr(image_upload, "filename", None):
61+
raise HTTPException(status_code=400, detail="Image file is required.")
62+
scales_raw = form.get("scales")
63+
if scales_raw is None:
64+
raise HTTPException(status_code=400, detail="Scales are required.")
65+
try:
66+
scales = json.loads(scales_raw)
67+
except json.JSONDecodeError:
68+
raise HTTPException(status_code=400, detail="Scales payload is invalid.")
69+
70+
image = save_upload_to_tempfile(image_upload)
71+
cleanup_paths.append(image)
72+
73+
label_upload = form.get("label")
74+
label: Optional[pathlib.Path] = None
75+
if label_upload and getattr(label_upload, "filename", None):
76+
label = save_upload_to_tempfile(label_upload)
77+
cleanup_paths.append(label)
78+
else:
79+
payload = await req.json()
80+
image = process_path(payload["image"])
81+
label = process_path(payload.get("label"))
82+
scales = payload["scales"]
83+
84+
print(image, label, scales)
85+
86+
if image is None:
87+
raise HTTPException(status_code=400, detail="Image path or file is required.")
88+
89+
# neuroglancer setting -- bind to this to make accessible outside of container
90+
ip = "0.0.0.0"
91+
port = 4244
92+
neuroglancer.set_server_bind_address(ip, port)
93+
viewer = neuroglancer.Viewer()
94+
# SNEMI (# 3d vol dim: z,y,x)
95+
res = neuroglancer.CoordinateSpace(
96+
names=["z", "y", "x"], units=["nm", "nm", "nm"], scales=scales
6097
)
61-
62-
with viewer.txn() as s:
63-
s.layers.append(name="im", layer=ngLayer(im, res, tt="image"))
64-
if label:
65-
s.layers.append(name="gt", layer=ngLayer(gt, res, tt="segmentation"))
66-
67-
print(viewer)
68-
return str(viewer)
98+
im = readVol(image, image_type="im")
99+
gt = readVol(label, image_type="im") if label else None
100+
101+
def ngLayer(data, res, oo=[0, 0, 0], tt="segmentation"):
102+
return neuroglancer.LocalVolume(
103+
data, dimensions=res, volume_type=tt, voxel_offset=oo
104+
)
105+
106+
with viewer.txn() as s:
107+
s.layers.append(name="im", layer=ngLayer(im, res, tt="image"))
108+
if gt is not None:
109+
s.layers.append(name="gt", layer=ngLayer(gt, res, tt="segmentation"))
110+
111+
print(viewer)
112+
return str(viewer)
113+
finally:
114+
for path in cleanup_paths:
115+
try:
116+
path.unlink()
117+
except FileNotFoundError:
118+
pass
119+
except PermissionError:
120+
pass
69121

70122

71123
@app.post("/start_model_training")

tests/api.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,68 @@
1+
import io
2+
import json
3+
import sys
4+
15
import httpx
26
import pytest
7+
import numpy as np
38
from fastapi import FastAPI
49
from httpx import AsyncClient
510
from server_api.main import app as fastapi_app
11+
from types import SimpleNamespace
12+
13+
14+
class DummyLayers(list):
15+
def append(self, name=None, layer=None): # type: ignore[override]
16+
super().append({"name": name, "layer": layer})
17+
18+
19+
class DummyTxn:
20+
def __init__(self, viewer):
21+
self.viewer = viewer
22+
23+
def __enter__(self):
24+
return SimpleNamespace(layers=self.viewer.layers)
25+
26+
def __exit__(self, exc_type, exc_val, exc_tb):
27+
return False
28+
29+
30+
class DummyViewer:
31+
def __init__(self):
32+
self.layers = DummyLayers()
33+
34+
def txn(self):
35+
return DummyTxn(self)
36+
37+
def __str__(self):
38+
return "dummy_viewer"
39+
40+
41+
class DummyCoordinateSpace:
42+
def __init__(self, names, units, scales):
43+
self.names = names
44+
self.units = units
45+
self.scales = scales
46+
47+
48+
class DummyLocalVolume:
49+
def __init__(self, data, dimensions, volume_type, voxel_offset):
50+
self.data = data
51+
self.dimensions = dimensions
52+
self.volume_type = volume_type
53+
self.voxel_offset = voxel_offset
54+
55+
56+
@pytest.fixture(autouse=True)
57+
def stub_neuroglancer(monkeypatch):
58+
dummy_module = SimpleNamespace(
59+
set_server_bind_address=lambda ip, port: None,
60+
Viewer=DummyViewer,
61+
CoordinateSpace=DummyCoordinateSpace,
62+
LocalVolume=DummyLocalVolume,
63+
)
64+
monkeypatch.setitem(sys.modules, "neuroglancer", dummy_module)
65+
yield
666

767

868
@pytest.fixture
@@ -25,8 +85,22 @@ async def test_hello(client: AsyncClient) -> None:
2585

2686
@pytest.mark.asyncio
2787
async def test_neuroglancer(client: AsyncClient) -> None:
28-
data = {"image": "test_image", "label": "test_label", "scales": [4, 4, 4]}
29-
# The actual testing of this route may need to be adjusted based on your actual functionality.
30-
# You might need to mock calls to `neuroglancer` or other external libraries.
31-
response = await client.post("/neuroglancer", json=data)
88+
array = np.zeros((2, 2), dtype=np.uint8)
89+
image_buffer = io.BytesIO()
90+
label_buffer = io.BytesIO()
91+
np.save(image_buffer, array)
92+
np.save(label_buffer, array)
93+
image_buffer.seek(0)
94+
label_buffer.seek(0)
95+
96+
files = {
97+
"image": ("image.npy", image_buffer, "application/octet-stream"),
98+
"label": ("label.npy", label_buffer, "application/octet-stream"),
99+
}
100+
101+
response = await client.post(
102+
"/neuroglancer",
103+
data={"scales": json.dumps([4, 4, 4])},
104+
files=files,
105+
)
32106
assert response.status_code == 200

0 commit comments

Comments
 (0)