|
1 | | -from typing import Any |
| 1 | +from typing import Any, Dict, List |
2 | 2 |
|
3 | 3 | from django.core.serializers import serialize |
4 | 4 | from django.db.models import Model, QuerySet |
@@ -48,13 +48,51 @@ def _json_serializer(obj): |
48 | 48 | raise TypeError |
49 | 49 |
|
50 | 50 |
|
| 51 | +def _fix_floats(current: Dict, data: Dict = None, paths: List = []) -> None: |
| 52 | + """ |
| 53 | + Recursively change any Python floats to string so that JavaScript |
| 54 | + won't change float to integers. |
| 55 | +
|
| 56 | + Params: |
| 57 | + current: Dictionary in which to check for and fix floats. |
| 58 | + """ |
| 59 | + |
| 60 | + if data is None: |
| 61 | + data = current |
| 62 | + |
| 63 | + if isinstance(current, dict): |
| 64 | + for key, val in current.items(): |
| 65 | + paths.append(key) |
| 66 | + _fix_floats(val, data, paths=paths) |
| 67 | + paths.pop() |
| 68 | + elif isinstance(current, list): |
| 69 | + for (idx, item) in enumerate(current): |
| 70 | + paths.append(idx) |
| 71 | + _fix_floats(item, data, paths=paths) |
| 72 | + paths.pop() |
| 73 | + elif isinstance(current, float): |
| 74 | + _piece = data |
| 75 | + |
| 76 | + for (idx, path) in enumerate(paths): |
| 77 | + if idx == len(paths) - 1: |
| 78 | + # `path` can be a dictionary key or list index, |
| 79 | + # but either way it is retrieved the same way |
| 80 | + _piece[path] = str(current) |
| 81 | + else: |
| 82 | + _piece = _piece[path] |
| 83 | + |
| 84 | + |
51 | 85 | def dumps(data: dict) -> str: |
52 | 86 | """ |
53 | 87 | Converts the passed-in dictionary to a string representation. |
54 | 88 |
|
55 | 89 | Handles the following objects: dataclass, datetime, enum, float, int, numpy, str, uuid, |
56 | 90 | Django Model, Django QuerySet, any object with `to_json` method. |
57 | 91 | """ |
58 | | - dumped_data = orjson.dumps(data, default=_json_serializer).decode("utf-8") |
| 92 | + serialized_data = orjson.dumps(data, default=_json_serializer) |
| 93 | + dict_data = orjson.loads(serialized_data) |
| 94 | + _fix_floats(dict_data) |
| 95 | + |
| 96 | + dumped_data = orjson.dumps(dict_data).decode("utf-8") |
59 | 97 |
|
60 | 98 | return dumped_data |
0 commit comments