Skip to content

Commit 6e83de7

Browse files
committed
GODRIVER-1199 Add UnmarshalJSON and MarshalJSON for primitive.Decimal128 (#556)
1 parent ba6725f commit 6e83de7

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

bson/primitive/decimal.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
package primitive
1111

1212
import (
13+
"encoding/json"
1314
"errors"
1415
"fmt"
1516
"math/big"
@@ -211,6 +212,49 @@ func (d Decimal128) IsZero() bool {
211212
return d.h == 0 && d.l == 0
212213
}
213214

215+
// MarshalJSON returns Decimal128 as a string.
216+
func (d Decimal128) MarshalJSON() ([]byte, error) {
217+
return json.Marshal(d.String())
218+
}
219+
220+
// UnmarshalJSON creates a primitive.Decimal128 from a JSON string, an extended JSON $numberDecimal value, or the string
221+
// "null". If b is a JSON string or extended JSON value, d will have the value of that string, and if b is "null", d will
222+
// be unchanged.
223+
func (d *Decimal128) UnmarshalJSON(b []byte) error {
224+
// Ignore "null" to keep parity with the standard library. Decoding a JSON null into a non-pointer Decimal128 field
225+
// will leave the field unchanged. For pointer values, encoding/json will set the pointer to nil and will not
226+
// enter the UnmarshalJSON hook.
227+
if string(b) == "null" {
228+
return nil
229+
}
230+
231+
var res interface{}
232+
err := json.Unmarshal(b, &res)
233+
if err != nil {
234+
return err
235+
}
236+
str, ok := res.(string)
237+
238+
// Extended JSON
239+
if !ok {
240+
m, ok := res.(map[string]interface{})
241+
if !ok {
242+
return errors.New("not an extended JSON Decimal128: expected document")
243+
}
244+
d128, ok := m["$numberDecimal"]
245+
if !ok {
246+
return errors.New("not an extended JSON Decimal128: expected key $numberDecimal")
247+
}
248+
str, ok = d128.(string)
249+
if !ok {
250+
return errors.New("not an extended JSON Decimal128: expected decimal to be string")
251+
}
252+
}
253+
254+
*d, err = ParseDecimal128(str)
255+
return err
256+
}
257+
214258
func divmod(h, l uint64, div uint32) (qh, ql uint64, rem uint32) {
215259
div64 := uint64(div)
216260
a := h >> 32

bson/primitive/decimal_test.go

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package primitive
22

33
import (
4-
"github.com/stretchr/testify/require"
4+
"encoding/json"
5+
"fmt"
56
"math/big"
67
"testing"
8+
9+
"github.com/stretchr/testify/require"
10+
"go.mongodb.org/mongo-driver/internal/testutil/assert"
711
)
812

913
type bigIntTestCase struct {
@@ -144,3 +148,61 @@ func TestParseDecimal128(t *testing.T) {
144148
}
145149
}
146150
}
151+
152+
func TestDecimal128_JSON(t *testing.T) {
153+
t.Run("roundTrip", func(t *testing.T) {
154+
decimal := NewDecimal128(0x3040000000000000, 12345)
155+
bytes, err := json.Marshal(decimal)
156+
assert.Nil(t, err, "json.Marshal error: %v", err)
157+
got := NewDecimal128(0, 0)
158+
err = json.Unmarshal([]byte(bytes), &got)
159+
assert.Nil(t, err, "json.Unmarshal error: %v", err)
160+
assert.Equal(t, decimal.h, got.h, "expected h: %v got: %v", decimal.h, got.h)
161+
assert.Equal(t, decimal.l, got.l, "expected l: %v got: %v", decimal.l, got.l)
162+
})
163+
t.Run("unmarshal extendedJSON", func(t *testing.T) {
164+
want := NewDecimal128(0x3040000000000000, 12345)
165+
extJSON := fmt.Sprintf(`{"$numberDecimal": %q}`, want.String())
166+
167+
got := NewDecimal128(0, 0)
168+
err := json.Unmarshal([]byte(extJSON), &got)
169+
assert.Nil(t, err, "json.Unmarshal error: %v", err)
170+
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
171+
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
172+
})
173+
t.Run("unmarshal null", func(t *testing.T) {
174+
want := NewDecimal128(0, 0)
175+
extJSON := `null`
176+
177+
got := NewDecimal128(0, 0)
178+
err := json.Unmarshal([]byte(extJSON), &got)
179+
assert.Nil(t, err, "json.Unmarshal error: %v", err)
180+
assert.Equal(t, want.h, got.h, "expected h: %v got: %v", want.h, got.h)
181+
assert.Equal(t, want.l, got.l, "expected l: %v got: %v", want.l, got.l)
182+
})
183+
t.Run("unmarshal", func(t *testing.T) {
184+
cases := append(bigIntTestCases,
185+
[]bigIntTestCase{
186+
{s: "-0001231.453454000000565600000000E-21", h: 0xafe6000003faa269, l: 0x81cfeceaabdb1800},
187+
{s: "12345E+21", h: 0x306a000000000000, l: 12345},
188+
{s: "0.10000000000000000000000000000000000000000001", remark: "parse fail"},
189+
{s: ".125e1", h: 0x303c000000000000, l: 125},
190+
{s: ".125", h: 0x303a000000000000, l: 125},
191+
}...)
192+
for _, c := range cases {
193+
input := fmt.Sprintf(`{"foo": %q}`, c.s)
194+
var got map[string]Decimal128
195+
err := json.Unmarshal([]byte(input), &got)
196+
197+
switch c.remark {
198+
case "overflow", "parse fail":
199+
assert.NotNil(t, err, "expected Unmarshal error, got nil")
200+
default:
201+
assert.Nil(t, err, "Unmarshal error: %v", err)
202+
gotDecimal := got["foo"]
203+
assert.Equal(t, c.h, gotDecimal.h, "expected h: %v got: %v", c.h, gotDecimal.l)
204+
assert.Equal(t, c.l, gotDecimal.l, "expected l: %v got: %v", c.l, gotDecimal.h)
205+
}
206+
}
207+
})
208+
}

0 commit comments

Comments
 (0)