Skip to content

Commit 8e1f1ef

Browse files
authored
refactor: decode PlutusData at a lower level to avoid type issues (#107)
Signed-off-by: Aurora Gaffney <[email protected]>
1 parent 678c325 commit 8e1f1ef

File tree

3 files changed

+207
-50
lines changed

3 files changed

+207
-50
lines changed

data/data_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,56 @@ var testDefs = []struct {
4848
),
4949
CborHex: "9f0102ff",
5050
},
51+
{
52+
Data: NewMap(
53+
[][2]PlutusData{
54+
{
55+
NewInteger(big.NewInt(1)),
56+
NewInteger(big.NewInt(2)),
57+
},
58+
},
59+
),
60+
CborHex: "a10102",
61+
},
62+
{
63+
Data: NewMap(
64+
[][2]PlutusData{
65+
{
66+
NewConstr(
67+
0,
68+
NewInteger(big.NewInt(0)),
69+
NewInteger(big.NewInt(406)),
70+
),
71+
NewConstr(
72+
0,
73+
NewInteger(big.NewInt(1725522262478821201)),
74+
),
75+
},
76+
},
77+
),
78+
CborHex: "a1d8799f00190196ffd8799f1b17f2495b03141751ff",
79+
},
80+
{
81+
Data: NewConstr(
82+
0,
83+
NewMap(
84+
[][2]PlutusData{
85+
{
86+
NewConstr(
87+
0,
88+
NewInteger(big.NewInt(0)),
89+
NewInteger(big.NewInt(406)),
90+
),
91+
NewConstr(
92+
0,
93+
NewInteger(big.NewInt(1725522262478821201)),
94+
),
95+
},
96+
},
97+
),
98+
),
99+
CborHex: "d8799fa1d8799f00190196ffd8799f1b17f2495b03141751ffff",
100+
},
51101
}
52102

53103
func TestPlutusDataEncode(t *testing.T) {

data/decode.go

Lines changed: 139 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,62 +8,98 @@ import (
88
"github.com/fxamacker/cbor/v2"
99
)
1010

11+
const (
12+
CborTypeByteString uint8 = 0x40
13+
CborTypeArray uint8 = 0x80
14+
CborTypeMap uint8 = 0xa0
15+
CborTypeTag uint8 = 0xc0
16+
17+
// Only the top 3 bytes are used to specify the type
18+
CborTypeMask uint8 = 0xe0
19+
)
20+
1121
// Decode decodes a CBOR-encoded byte slice into a PlutusData value.
1222
// It returns an error if the input is invalid or not a valid PlutusData encoding.
1323
func Decode(b []byte) (PlutusData, error) {
14-
var raw cbor.RawMessage = b
15-
var v any
16-
17-
if err := cbor.Unmarshal(raw, &v); err != nil {
24+
v, err := decodeCborRaw(b)
25+
if err != nil {
1826
return nil, fmt.Errorf("failed to decode CBOR: %w", err)
1927
}
2028

2129
return decodeRaw(v)
2230
}
2331

32+
// decodeCborRaw is an alternative to cbor.Unmarshal() that converts cbor.Tag to Constr
33+
// This is needed because cbor.Tag with a slice as the content (such as in a Constr) is
34+
// not hashable and cannot be used as a map key
35+
func decodeCborRaw(data []byte) (any, error) {
36+
cborType := data[0] & CborTypeMask
37+
switch cborType {
38+
case CborTypeByteString:
39+
var tmpData cbor.ByteString
40+
if err := cbor.Unmarshal(data, &tmpData); err != nil {
41+
return nil, err
42+
}
43+
return tmpData, nil
44+
case CborTypeArray:
45+
return decodeCborRawList(data)
46+
case CborTypeMap:
47+
return decodeCborRawMap(data)
48+
case CborTypeTag:
49+
var tmpTag cbor.RawTag
50+
if err := cbor.Unmarshal(data, &tmpTag); err != nil {
51+
return nil, err
52+
}
53+
return decodeRawTag(tmpTag)
54+
default:
55+
// Decode using default representation
56+
var tmpData any
57+
if err := cbor.Unmarshal(data, &tmpData); err != nil {
58+
return nil, err
59+
}
60+
return tmpData, nil
61+
}
62+
}
63+
64+
func decodeCborRawList(data []byte) ([]any, error) {
65+
var tmpData []cbor.RawMessage
66+
if err := cbor.Unmarshal(data, &tmpData); err != nil {
67+
return nil, err
68+
}
69+
ret := make([]any, len(tmpData))
70+
for i, item := range tmpData {
71+
tmp, err := decodeCborRaw(item)
72+
if err != nil {
73+
return nil, err
74+
}
75+
ret[i] = tmp
76+
}
77+
return ret, nil
78+
}
79+
80+
func decodeCborRawMap(data []byte) (map[any]any, error) {
81+
var tmpData map[RawMessageStr]RawMessageStr
82+
if err := cbor.Unmarshal(data, &tmpData); err != nil {
83+
return nil, err
84+
}
85+
ret := make(map[any]any, len(tmpData))
86+
for k, v := range tmpData {
87+
tmpKey, err := decodeCborRaw(k.Bytes())
88+
if err != nil {
89+
return nil, err
90+
}
91+
tmpVal, err := decodeCborRaw(v.Bytes())
92+
if err != nil {
93+
return nil, err
94+
}
95+
ret[tmpKey] = tmpVal
96+
}
97+
return ret, nil
98+
}
99+
24100
// decodeRaw converts a raw CBOR-decoded value into PlutusData.
25101
func decodeRaw(v any) (PlutusData, error) {
26102
switch x := v.(type) {
27-
case cbor.Tag:
28-
// Handle tagged data (Constr, Bignum).
29-
switch x.Number {
30-
// Constr with tag 0..6.
31-
case 121, 122, 123, 124, 125, 126, 127:
32-
return decodeConstr(x.Number-121, x.Content)
33-
34-
// Constr with tag 7..127.
35-
case 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287,
36-
1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295,
37-
1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303,
38-
1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311,
39-
1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319,
40-
1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327,
41-
1328, 1329, 1330, 1331, 1332, 1333, 1334, 1335,
42-
1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343,
43-
1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351,
44-
1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359,
45-
1360, 1361, 1362, 1363, 1364, 1365, 1366, 1367,
46-
1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375,
47-
1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383,
48-
1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391,
49-
1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400:
50-
51-
return decodeConstr((x.Number-1280)+7, x.Content)
52-
53-
// PosBignum
54-
case 2:
55-
return decodeBignum(x.Content, false)
56-
57-
// NegBignum
58-
case 3:
59-
return decodeBignum(x.Content, true)
60-
61-
case 102:
62-
return nil, errors.New("tagged data (tag 102) not implemented")
63-
64-
default:
65-
return nil, fmt.Errorf("unknown CBOR tag for PlutusData: %d", x.Number)
66-
}
67103
// Handle List (untagged array).
68104
case []any:
69105
items := make([]PlutusData, len(x))
@@ -113,14 +149,69 @@ func decodeRaw(v any) (PlutusData, error) {
113149
case uint64:
114150
return NewInteger(new(big.Int).SetUint64(x)), nil
115151

152+
case *Constr:
153+
return x, nil
154+
155+
case *Integer:
156+
return x, nil
157+
116158
default:
117159
return nil, fmt.Errorf("unsupported CBOR type for PlutusData: %T", x)
118160
}
119161
}
120162

163+
func decodeRawTag(tag cbor.RawTag) (PlutusData, error) {
164+
var ret PlutusData
165+
var retErr error
166+
// Handle tagged data (Constr, Bignum).
167+
switch tag.Number {
168+
// Constr with tag 0..6.
169+
case 121, 122, 123, 124, 125, 126, 127:
170+
ret, retErr = decodeConstr(tag.Number-121, tag.Content)
171+
172+
// Constr with tag 7..127.
173+
case 1280, 1281, 1282, 1283, 1284, 1285, 1286, 1287,
174+
1288, 1289, 1290, 1291, 1292, 1293, 1294, 1295,
175+
1296, 1297, 1298, 1299, 1300, 1301, 1302, 1303,
176+
1304, 1305, 1306, 1307, 1308, 1309, 1310, 1311,
177+
1312, 1313, 1314, 1315, 1316, 1317, 1318, 1319,
178+
1320, 1321, 1322, 1323, 1324, 1325, 1326, 1327,
179+
1328, 1329, 1330, 1331, 1332, 1333, 1334, 1335,
180+
1336, 1337, 1338, 1339, 1340, 1341, 1342, 1343,
181+
1344, 1345, 1346, 1347, 1348, 1349, 1350, 1351,
182+
1352, 1353, 1354, 1355, 1356, 1357, 1358, 1359,
183+
1360, 1361, 1362, 1363, 1364, 1365, 1366, 1367,
184+
1368, 1369, 1370, 1371, 1372, 1373, 1374, 1375,
185+
1376, 1377, 1378, 1379, 1380, 1381, 1382, 1383,
186+
1384, 1385, 1386, 1387, 1388, 1389, 1390, 1391,
187+
1392, 1393, 1394, 1395, 1396, 1397, 1398, 1399, 1400:
188+
189+
ret, retErr = decodeConstr((tag.Number-1280)+7, tag.Content)
190+
191+
// PosBignum
192+
case 2:
193+
ret, retErr = decodeBignum(tag.Content, false)
194+
195+
// NegBignum
196+
case 3:
197+
ret, retErr = decodeBignum(tag.Content, true)
198+
199+
case 102:
200+
return nil, errors.New("tagged data (tag 102) not implemented")
201+
202+
default:
203+
return nil, fmt.Errorf("unknown CBOR tag for PlutusData: %d", tag.Number)
204+
}
205+
return ret, retErr
206+
}
207+
121208
// decodeConstr decodes a Constr from a CBOR tag content (expected to be an array).
122-
func decodeConstr(tag uint64, content any) (PlutusData, error) {
123-
arr, ok := content.([]any)
209+
func decodeConstr(tag uint64, content cbor.RawMessage) (PlutusData, error) {
210+
tmpData, err := decodeCborRaw(content)
211+
if err != nil {
212+
return nil, err
213+
}
214+
arr, ok := tmpData.([]any)
124215
if !ok {
125216
return nil, fmt.Errorf(
126217
"expected array for Constr tag %d, got %T",
@@ -149,13 +240,13 @@ func decodeConstr(tag uint64, content any) (PlutusData, error) {
149240

150241
// decodeBignum decodes a big integer from CBOR tag content (expected to be bytes).
151242
func decodeBignum(content any, negative bool) (PlutusData, error) {
152-
bytes, ok := content.([]byte)
243+
bytes, ok := content.(cbor.RawMessage)
153244
if !ok {
154245
return nil, fmt.Errorf("expected bytes for Bignum, got %T", content)
155246
}
156247

157248
// Convert bytes to big.Int (assuming big-endian, as in Rust's rug::Integer::from_digits).
158-
n := new(big.Int).SetBytes(bytes)
249+
n := new(big.Int).SetBytes([]byte(bytes))
159250

160251
if negative {
161252
n.Neg(n)

data/encode.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func encodeConstr(c *Constr) (any, error) {
6666
// End indefinite-length list
6767
0xff,
6868
)
69-
fields = cbor.RawMessage(tmpData)
69+
fields = RawMessageStr(string(tmpData))
7070
}
7171

7272
// Determine CBOR tag based on Constr tag value
@@ -181,5 +181,21 @@ func encodeList(l *List) (any, error) {
181181
// End indefinite-length list
182182
0xff,
183183
)
184-
return cbor.RawMessage(tmpData), nil
184+
return RawMessageStr(string(tmpData)), nil
185+
}
186+
187+
// RawMessageStr is a hashable variant of cbor.RawMessage
188+
type RawMessageStr string
189+
190+
func (r *RawMessageStr) UnmarshalCBOR(data []byte) error {
191+
*r = RawMessageStr(string(data))
192+
return nil
193+
}
194+
195+
func (r RawMessageStr) MarshalCBOR() ([]byte, error) {
196+
return []byte(r), nil
197+
}
198+
199+
func (r RawMessageStr) Bytes() []byte {
200+
return []byte(r)
185201
}

0 commit comments

Comments
 (0)