diff --git a/big/int.go b/big/int.go index 4b9cdb80..39ab9d15 100644 --- a/big/int.go +++ b/big/int.go @@ -218,6 +218,40 @@ func (bi *Int) UnmarshalJSON(b []byte) error { return nil } +func (bi *Int) WriteBytes(out io.Writer) (int, error) { + if bi.Int == nil { + return 0, fmt.Errorf("failed to convert to bytes, big is nil") + } + + var wrote int + switch { + case bi.Sign() > 0: + w, err := out.Write([]byte{0}) + if err != nil { + return wrote, err + } + wrote += w + w, err = out.Write(bi.Int.Bytes()) + if err != nil { + return wrote, err + } + wrote += w + case bi.Sign() < 0: + w, err := out.Write([]byte{1}) + if err != nil { + return wrote, err + } + wrote += w + w, err = out.Write(bi.Int.Bytes()) + if err != nil { + return wrote, err + } + wrote += w + default: // bi.Sign() == 0: + } + return wrote, nil +} + func (bi *Int) Bytes() ([]byte, error) { if bi.Int == nil { return []byte{}, fmt.Errorf("failed to convert to bytes, big is nil") @@ -233,6 +267,14 @@ func (bi *Int) Bytes() ([]byte, error) { } } +func (bi *Int) byteLength() int { + if bi.Int == nil || bi.Sign() == 0 { + return 0 + } + + return 1 + (bi.Int.BitLen()+7)/8 +} + func FromBytes(buf []byte) (Int, error) { if len(buf) == 0 { return NewInt(0), nil @@ -281,12 +323,7 @@ func (bi *Int) MarshalCBOR(w io.Writer) error { return zero.MarshalCBOR(w) } - enc, err := bi.Bytes() - if err != nil { - return err - } - - encLen := len(enc) + encLen := bi.byteLength() if encLen > BigIntMaxSerializedLen { return fmt.Errorf("big integer byte array too long (%d bytes)", encLen) } @@ -296,10 +333,11 @@ func (bi *Int) MarshalCBOR(w io.Writer) error { return err } - if _, err := w.Write(enc); err != nil { + if wrote, err := bi.WriteBytes(w); err != nil { return err + } else if wrote != encLen { + return fmt.Errorf("failed to write full big int byte array (%d bytes written, %d expected)", wrote, encLen) } - return nil }