|
1 | 1 | package primitive
|
2 | 2 |
|
3 | 3 | import (
|
4 |
| - "github.com/stretchr/testify/require" |
| 4 | + "encoding/json" |
| 5 | + "fmt" |
5 | 6 | "math/big"
|
6 | 7 | "testing"
|
| 8 | + |
| 9 | + "github.com/stretchr/testify/require" |
| 10 | + "go.mongodb.org/mongo-driver/internal/testutil/assert" |
7 | 11 | )
|
8 | 12 |
|
9 | 13 | type bigIntTestCase struct {
|
@@ -144,3 +148,61 @@ func TestParseDecimal128(t *testing.T) {
|
144 | 148 | }
|
145 | 149 | }
|
146 | 150 | }
|
| 151 | + |
| 152 | +func TestDecimal128_JSON(t *testing.T) { |
| 153 | + t.Run("roundTrip", func(t *testing.T) { |
| 154 | + decimal := NewDecimal128(0x3040000000000000, 12345) |
| 155 | + bytes, err := json.Marshal(decimal) |
| 156 | + assert.Nil(t, err, "json.Marshal error: %v", err) |
| 157 | + got := NewDecimal128(0, 0) |
| 158 | + err = json.Unmarshal([]byte(bytes), &got) |
| 159 | + assert.Nil(t, err, "json.Unmarshal error: %v", err) |
| 160 | + assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h) |
| 161 | + assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l) |
| 162 | + }) |
| 163 | + t.Run("unmarshal extendedJSON", func(t *testing.T) { |
| 164 | + want := NewDecimal128(0x3040000000000000, 12345) |
| 165 | + extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String()) |
| 166 | + |
| 167 | + got := NewDecimal128(0, 0) |
| 168 | + err := json.Unmarshal([]byte(extJSON), &got) |
| 169 | + assert.Nil(t, err, "json.Unmarshal error: %v", err) |
| 170 | + assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h) |
| 171 | + assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l) |
| 172 | + }) |
| 173 | + t.Run("unmarshal null", func(t *testing.T) { |
| 174 | + want := NewDecimal128(0, 0) |
| 175 | + extJSON := `null` |
| 176 | + |
| 177 | + got := NewDecimal128(0, 0) |
| 178 | + err := json.Unmarshal([]byte(extJSON), &got) |
| 179 | + assert.Nil(t, err, "json.Unmarshal error: %v", err) |
| 180 | + assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h) |
| 181 | + assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l) |
| 182 | + }) |
| 183 | + t.Run("unmarshal", func(t *testing.T) { |
| 184 | + cases := append(bigIntTestCases, |
| 185 | + []bigIntTestCase{ |
| 186 | + {s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800}, |
| 187 | + {s: "12345E+21", h: 0x306a000000000000, l: 12345}, |
| 188 | + {s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"}, |
| 189 | + {s: ".125e1", h: 0x303c000000000000, l: 125}, |
| 190 | + {s: ".125", h: 0x303a000000000000, l: 125}, |
| 191 | + }...) |
| 192 | + for _, c := range cases { |
| 193 | + input := fmt.Sprintf(`{"foo": %q}`, c.s) |
| 194 | + var got map[string]Decimal128 |
| 195 | + err := json.Unmarshal([]byte(input), &got) |
| 196 | + |
| 197 | + switch c.remark { |
| 198 | + case "overflow", "parse fail": |
| 199 | + assert.NotNil(t, err, "expected Unmarshal error, got nil") |
| 200 | + default: |
| 201 | + assert.Nil(t, err, "Unmarshal error: %v", err) |
| 202 | + gotDecimal := got["foo"] |
| 203 | + assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.l) |
| 204 | + assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.h) |
| 205 | + } |
| 206 | + } |
| 207 | + }) |
| 208 | +} |
0 commit comments