
Commit f38c113

attempt to correct API for nested data model retrieval.
1 parent 93c7e82 commit f38c113

File tree

1 file changed: +64 −15 lines changed


opensensor/collection_apis.py

Lines changed: 64 additions & 15 deletions
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta, timezone
-from typing import Generic, List, Optional, Type, TypeVar
+from typing import Generic, List, Optional, Type, TypeVar, get_args, get_origin

 from bson import Binary
 from fastapi import APIRouter, Depends, HTTPException, Path, Query, Response, status
@@ -182,6 +182,59 @@ def get_initial_match_clause(
     return match_clause


+def is_pydantic_model(obj):
+    return isinstance(obj, type) and issubclass(obj, BaseModel)
+
+
+def get_nested_fields(model: Type[BaseModel]):
+    nested_fields = {}
+    for field_name, field in model.__fields__.items():
+        if is_pydantic_model(field.type_):
+            nested_fields[field_name] = field.type_
+        elif get_origin(field.type_) is List and is_pydantic_model(get_args(field.type_)[0]):
+            nested_fields[field_name] = get_args(field.type_)[0]
+    return nested_fields
+
+
+def create_nested_pipeline(model: Type[BaseModel], prefix=""):
+    nested_fields = get_nested_fields(model)
+    pipeline = {}
+
+    for field_name, field_type in model.__fields__.items():
+        full_field_name = f"{prefix}{field_name}"
+
+        if field_name in nested_fields:
+            if get_origin(field_type.type_) is List:
+                pipeline[field_name] = {
+                    "$map": {
+                        "input": f"${full_field_name}",
+                        "as": "item",
+                        "in": create_nested_pipeline(nested_fields[field_name], "$$item."),
+                    }
+                }
+            else:
+                pipeline[field_name] = create_nested_pipeline(
+                    nested_fields[field_name], f"{full_field_name}."
+                )
+        else:
+            pipeline[field_name] = f"${full_field_name}"
+
+    return pipeline
+
+
+def create_model_instance(model: Type[BaseModel], data: dict):
+    nested_fields = get_nested_fields(model)
+    for field_name, nested_model in nested_fields.items():
+        if field_name in data:
+            if isinstance(data[field_name], list):
+                data[field_name] = [
+                    create_model_instance(nested_model, item) for item in data[field_name]
+                ]
+            else:
+                data[field_name] = create_model_instance(nested_model, data[field_name])
+    return model(**data)
+
+
 def get_vpd_pipeline(
     device_ids: List[str],
     device_name: str,
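Editor's note on the new helpers above: create_nested_pipeline walks a Pydantic model's fields and emits a MongoDB $project document, using $map to rebuild list-of-model fields element by element. The sketch below is illustrative only: the Reading and SensorDocument models are hypothetical (not part of opensensor), and expected_projection shows the projection shape the helper is intended to produce for them, assuming the Pydantic v1 field introspection (__fields__, field.type_) that the diff relies on.

    from typing import List

    from pydantic import BaseModel


    class Reading(BaseModel):
        # Hypothetical nested model, used only for illustration.
        value: float
        unit: str


    class SensorDocument(BaseModel):
        # Hypothetical top-level model with a list of nested models.
        name: str
        readings: List[Reading]


    # Intended projection shape: scalar fields become "$<field>", while a list of
    # nested models is rebuilt element by element with $map so that only the
    # declared sub-fields are projected.
    expected_projection = {
        "name": "$name",
        "readings": {
            "$map": {
                "input": "$readings",
                "as": "item",
                "in": {"value": "$$item.value", "unit": "$$item.unit"},
            }
        },
    }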
@@ -264,7 +317,7 @@ def get_vpd_pipeline(

 def get_uniform_sample_pipeline(
     response_model: Type[T],
-    device_ids: List[str],  # Update the type of the device_id parameter to List[str]
+    device_ids: List[str],
     device_name: str,
     start_date: datetime,
     end_date: datetime,
@@ -277,13 +330,10 @@ def get_uniform_sample_pipeline(
     sampling_interval = timedelta(minutes=resolution)
     match_clause = get_initial_match_clause(device_ids, device_name, start_date, end_date)

-    # Determine the $project
-    old_name = get_collection_name(response_model)
-    new_collection_name = new_collections[old_name]
-    project_projection = _get_project_projection(response_model)
-    match_clause[new_collection_name] = {"$exists": True}
+    # Create a generalized project pipeline
+    project_pipeline = create_nested_pipeline(response_model)
+    project_pipeline["timestamp"] = "$timestamp"

-    # Query a uniform sample of documents within the timestamp range
     pipeline = [
         {"$match": match_clause},
         {
@@ -300,10 +350,10 @@ def get_uniform_sample_pipeline(
         },
         {"$group": {"_id": "$group", "doc": {"$first": "$$ROOT"}}},
         {"$replaceRoot": {"newRoot": "$doc"}},
-        {"$project": project_projection},
-        {"$sort": {"timestamp": 1}},  # Sort by timestamp in ascending order
-        # {"$count": "total"}
+        {"$project": project_pipeline},
+        {"$sort": {"timestamp": 1}},
     ]
+
     return pipeline

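For context, get_uniform_sample_pipeline takes one document per fixed time bucket before applying the generalized projection; the bucketing stage itself is not visible in these hunks. The sketch below is an assumption-laden illustration of that idea using plain aggregation-stage dicts, and its $project entry reuses the hypothetical expected_projection from the earlier sketch. It is not the exact stage list opensensor builds.

    from datetime import datetime, timedelta, timezone

    start_date = datetime(2023, 6, 1, tzinfo=timezone.utc)
    sampling_interval = timedelta(minutes=30)  # example 30-minute resolution

    # Bucket documents by interval, keep the first document per bucket, then
    # project and sort, mirroring the $group / $replaceRoot / $project / $sort
    # stages shown in the hunk above.
    sampling_stages = [
        {
            "$addFields": {
                "group": {
                    "$floor": {
                        "$divide": [
                            # $subtract on two dates yields milliseconds in MongoDB.
                            {"$subtract": ["$timestamp", start_date]},
                            sampling_interval.total_seconds() * 1000,
                        ]
                    }
                }
            }
        },
        {"$group": {"_id": "$group", "doc": {"$first": "$$ROOT"}}},
        {"$replaceRoot": {"newRoot": "$doc"}},
        {"$project": {**expected_projection, "timestamp": "$timestamp"}},
        {"$sort": {"timestamp": 1}},
    ]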
@@ -371,10 +421,9 @@ def sample_and_paginate_collection(
         # So, you can directly use it to create the response model instances.
         data = [VPD(**item) for item in raw_data]
     else:
-        data_field = model_class_attributes[response_model]
-        model_class = model_classes[data_field]
-        data = [model_class(**item) for item in raw_data]
-        if data_field == "temp" and unit:
+        data = [create_model_instance(response_model, item) for item in raw_data]
+
+        if hasattr(response_model, "temp") and unit:
             for item in data:
                 convert_temperature(item, unit)

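Finally, create_model_instance is the read-side counterpart: it rehydrates nested Pydantic models from the raw dicts returned by the aggregation. A minimal sketch, reusing the hypothetical Reading/SensorDocument models from the first sketch and assuming Pydantic v1 (the helper depends on __fields__ and field.type_); the sample document is invented for illustration.

    from opensensor.collection_apis import create_model_instance

    raw_item = {
        "name": "greenhouse-1",
        "readings": [
            {"value": 21.5, "unit": "C"},
            {"value": 21.7, "unit": "C"},
        ],
    }

    # Nested dicts are converted into Reading instances before the outer
    # SensorDocument is constructed, so the result is typed end to end.
    doc = create_model_instance(SensorDocument, raw_item)
    assert isinstance(doc.readings[0], Reading)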