Skip to content

Commit 0db96b5

Browse files
committed
add tests for ndjson
1 parent 6c9fa4e commit 0db96b5

File tree

2 files changed

+57
-31
lines changed

2 files changed

+57
-31
lines changed

tests/routes/test_items.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,9 +561,7 @@ def test_output_response_type(app):
561561
body = response.json()
562562
assert len(body) == 10
563563
feat = body[0]
564-
assert ["collectionId", "itemId", "id", "pr", "row", "path", "ogc_fid"] == list(
565-
feat.keys()
566-
)
564+
assert "geometry" not in feat.keys()
567565

568566
response = app.get(
569567
"/collections/public.landsat_wrs/items",
@@ -573,3 +571,30 @@ def test_output_response_type(app):
573571
assert response.headers["content-type"] == "application/json"
574572
body = response.json()
575573
assert len(body) == 10
574+
575+
# ndjson output
576+
response = app.get("/collections/public.landsat_wrs/items?f=ndjson")
577+
assert response.status_code == 200
578+
assert response.headers["content-type"] == "application/ndjson"
579+
body = response.text.splitlines()
580+
assert len(body) == 10
581+
feat = json.loads(body[0])
582+
assert [
583+
"collectionId",
584+
"itemId",
585+
"id",
586+
"pr",
587+
"row",
588+
"path",
589+
"ogc_fid",
590+
"geometry",
591+
] == list(feat.keys())
592+
593+
response = app.get(
594+
"/collections/public.landsat_wrs/items",
595+
headers={"accept": "application/ndjson"},
596+
)
597+
assert response.status_code == 200
598+
assert response.headers["content-type"] == "application/ndjson"
599+
body = response.text.splitlines()
600+
assert len(body) == 10

tifeatures/factory.py

Lines changed: 29 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import csv
44
import json
55
from dataclasses import dataclass, field
6-
from typing import Any, Callable, Dict, Iterable, List, Optional
6+
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional
77

88
import jinja2
99
from pygeofilter.ast import AstType
@@ -39,30 +39,6 @@
3939
from starlette.responses import StreamingResponse
4040
from starlette.templating import Jinja2Templates, _TemplateResponse
4141

42-
43-
class DummyWriter:
44-
"""Dummy writer that implements write for use with csv.writer."""
45-
46-
def write(self, line: str):
47-
"""Return line."""
48-
return line
49-
50-
51-
def iter_csv(data: Iterable[Dict]):
52-
"""Creates an iterator that returns lines of csv from an iterable of dicts."""
53-
54-
initial = True
55-
writer = None
56-
for row in data:
57-
if initial:
58-
fieldnames = row.keys()
59-
writer = csv.DictWriter(DummyWriter(), fieldnames=fieldnames)
60-
yield writer.writeheader()
61-
initial = False
62-
if writer:
63-
yield writer.writerow(row)
64-
65-
6642
settings = APISettings()
6743

6844
# custom template directory
@@ -112,6 +88,29 @@ def create_html_response(
11288
)
11389

11490

91+
def create_csv_rows(data: Iterable[Dict]) -> Generator[str, None, None]:
92+
"""Create Template response."""
93+
94+
class DummyWriter:
95+
"""Dummy writer that implements write for use with csv.writer."""
96+
97+
def write(self, line: str):
98+
"""Return line."""
99+
return line
100+
101+
"""Creates an iterator that returns lines of csv from an iterable of dicts."""
102+
initial = True
103+
writer = None
104+
for row in data:
105+
if initial:
106+
fieldnames = row.keys()
107+
writer = csv.DictWriter(DummyWriter(), fieldnames=fieldnames)
108+
yield writer.writeheader()
109+
initial = False
110+
if writer:
111+
yield writer.writerow(row)
112+
113+
115114
@dataclass
116115
class Endpoints:
117116
"""Endpoints Factory."""
@@ -534,6 +533,7 @@ def queryables(
534533
MediaType.csv.value: {},
535534
MediaType.json.value: {},
536535
MediaType.geojsonseq.value: {},
536+
MediaType.ndjson.value: {},
537537
}
538538
},
539539
},
@@ -621,7 +621,7 @@ async def items(
621621
MediaType.json,
622622
MediaType.ndjson,
623623
):
624-
if items[0].geometry is not None:
624+
if items and items[0].geometry is not None:
625625
rows = (
626626
{
627627
"collectionId": collection.id,
@@ -631,6 +631,7 @@ async def items(
631631
}
632632
for f in items
633633
)
634+
634635
else:
635636
rows = (
636637
{
@@ -644,7 +645,7 @@ async def items(
644645
# CSV Response
645646
if output_type == MediaType.csv:
646647
return StreamingResponse(
647-
iter_csv(rows),
648+
create_csv_rows(rows),
648649
media_type=MediaType.csv,
649650
headers={
650651
"Content-Disposition": "attachment;filename=items.csv"
@@ -658,7 +659,7 @@ async def items(
658659
# NDJSON Response
659660
if output_type == MediaType.ndjson:
660661
return StreamingResponse(
661-
(row + "\n" for row in rows),
662+
(json.dumps(row) + "\n" for row in rows),
662663
media_type=MediaType.ndjson,
663664
headers={
664665
"Content-Disposition": "attachment;filename=items.ndjson"

0 commit comments

Comments
 (0)