|
1 | 1 | import unittest |
| 2 | +from unittest.mock import MagicMock |
2 | 3 |
|
3 | 4 | import orjson |
4 | 5 | import pytest |
@@ -123,3 +124,67 @@ def test_tensor_special_values(self): |
123 | 124 | orjson.dumps(data, option=orjson.OPT_SERIALIZE_NUMPY), |
124 | 125 | b'{"nan":NaN,"inf":Infinity,"neg_inf":-Infinity,"mixed":[1.0,NaN,Infinity,-Infinity]}' |
125 | 126 | ) |
| 127 | + |
| 128 | + def test_tensor_in_list(self): |
| 129 | + """PyTorch tensor as element in a Python list""" |
| 130 | + assert orjson.dumps([torch.tensor([1, 2])], option=orjson.OPT_SERIALIZE_NUMPY) == b'[[1,2]]' |
| 131 | + |
| 132 | + def test_tensor_3d(self): |
| 133 | + """3D tensor""" |
| 134 | + tensor = torch.zeros(2, 3, 4) |
| 135 | + result = orjson.loads(orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY)) |
| 136 | + assert len(result) == 2 and len(result[0]) == 3 and len(result[0][0]) == 4 |
| 137 | + |
| 138 | + def test_tensor_dtypes(self): |
| 139 | + """Various tensor dtypes""" |
| 140 | + for dtype in [torch.float16, torch.float64, torch.int8, torch.int16, torch.int32]: |
| 141 | + tensor = torch.tensor([1, 2, 3], dtype=dtype) |
| 142 | + result = orjson.loads(orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY)) |
| 143 | + for i, v in enumerate(result): |
| 144 | + assert abs(v - [1, 2, 3][i]) < 0.01 |
| 145 | + |
| 146 | + def test_non_torch_duck_type(self): |
| 147 | + """Object with numpy/cpu/detach but __module__ not 'torch' is not treated as tensor""" |
| 148 | + class FakeTensor: |
| 149 | + def numpy(self): return [1, 2] |
| 150 | + def cpu(self): return self |
| 151 | + def detach(self): return self |
| 152 | + with self.assertRaises(orjson.JSONEncodeError): |
| 153 | + orjson.dumps(FakeTensor(), option=orjson.OPT_SERIALIZE_NUMPY) |
| 154 | + |
| 155 | + def test_magicmock_not_tensor(self): |
| 156 | + """MagicMock not detected as PyTorch tensor (post4 fix)""" |
| 157 | + with self.assertRaises(orjson.JSONEncodeError): |
| 158 | + orjson.dumps(MagicMock(), option=orjson.OPT_SERIALIZE_NUMPY) |
| 159 | + |
| 160 | + def test_tensor_pretty(self): |
| 161 | + """PyTorch tensor with OPT_INDENT_2""" |
| 162 | + tensor = torch.tensor([[1, 2], [3, 4]]) |
| 163 | + result = orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_INDENT_2) |
| 164 | + assert result == b'[\n [\n 1,\n 2\n ],\n [\n 3,\n 4\n ]\n]' |
| 165 | + |
| 166 | + def test_tensor_conversion_failure(self): |
| 167 | + """Sparse tensor fails numpy conversion - PyTorchTensorConversion error""" |
| 168 | + t = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.tensor([1.0, 2.0]), (3,)) |
| 169 | + with self.assertRaises(orjson.JSONEncodeError) as cm: |
| 170 | + orjson.dumps(t, option=orjson.OPT_SERIALIZE_NUMPY) |
| 171 | + assert "failed to convert PyTorch tensor to numpy array" in str(cm.exception) |
| 172 | + |
| 173 | + def test_tensor_conversion_failure_with_default(self): |
| 174 | + """Sparse tensor with default callback falls back to default""" |
| 175 | + t = torch.sparse_coo_tensor(torch.tensor([[0, 1]]), torch.tensor([1.0, 2.0]), (3,)) |
| 176 | + result = orjson.dumps(t, option=orjson.OPT_SERIALIZE_NUMPY, default=lambda x: "fallback") |
| 177 | + assert result == b'"fallback"' |
| 178 | + |
| 179 | + def test_tensor_unsupported_numpy_dtype(self): |
| 180 | + """Complex tensor: numpy() succeeds but numpy dtype is unsupported""" |
| 181 | + tensor = torch.tensor([1+2j, 3+4j]) |
| 182 | + with self.assertRaises(orjson.JSONEncodeError) as cm: |
| 183 | + orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY) |
| 184 | + assert "unsupported datatype in numpy array" in str(cm.exception) |
| 185 | + |
| 186 | + def test_tensor_unsupported_numpy_dtype_with_default(self): |
| 187 | + """Complex tensor with default: falls back to default via numpy unsupported path""" |
| 188 | + tensor = torch.tensor([1+2j, 3+4j]) |
| 189 | + result = orjson.dumps(tensor, option=orjson.OPT_SERIALIZE_NUMPY, default=lambda x: str(x)) |
| 190 | + assert len(result) > 0 |
0 commit comments