Skip to content

Commit 2aec744

Browse files
authored
feat: decode for PROTO_PAIR and PROTO_LIST (#127)
Fixes #126 Signed-off-by: Aurora Gaffney <[email protected]>
1 parent 7030e4a commit 2aec744

File tree

2 files changed

+99
-14
lines changed

2 files changed

+99
-14
lines changed

cek/runtime.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ func unwrapList[T syn.Eval](
311311
switch c := v.Constant.(type) {
312312
case *syn.ProtoList:
313313
if typ != nil && !reflect.DeepEqual(typ, c.LTyp) {
314-
return nil, fmt.Errorf("Value not a List of type %v", typ)
314+
return nil, fmt.Errorf("Value not a List of type %T", typ)
315315
}
316316

317317
i = c

syn/flat_decode.go

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package syn
22

33
import (
4+
"bytes"
45
"errors"
56
"fmt"
67
"math/big"
@@ -189,12 +190,19 @@ func DecodeConstant(d *decoder) (IConstant, error) {
189190
if err != nil {
190191
return nil, err
191192
}
193+
typ, err := decodeConstantType(bytes.NewBuffer(tags))
194+
if err != nil {
195+
return nil, err
196+
}
197+
return decodeConstantValue(d, typ)
198+
}
192199

200+
func decodeConstantValue(d *decoder, typ Typ) (IConstant, error) {
193201
var constant IConstant
194202

195-
switch {
203+
switch t := typ.(type) {
196204
// Integer
197-
case len(tags) == 1 && tags[0] == IntegerTag:
205+
case *TInteger:
198206
i, err := d.integer()
199207
if err != nil {
200208
return nil, err
@@ -203,7 +211,7 @@ func DecodeConstant(d *decoder) (IConstant, error) {
203211
constant = &Integer{i}
204212

205213
// ByteString
206-
case len(tags) == 1 && tags[0] == ByteStringTag:
214+
case *TByteString:
207215
b, err := d.bytes()
208216
if err != nil {
209217
return nil, err
@@ -212,7 +220,7 @@ func DecodeConstant(d *decoder) (IConstant, error) {
212220
constant = &ByteString{b}
213221

214222
// String
215-
case len(tags) == 1 && tags[0] == StringTag:
223+
case *TString:
216224
s, err := d.utf8()
217225
if err != nil {
218226
return nil, err
@@ -221,11 +229,11 @@ func DecodeConstant(d *decoder) (IConstant, error) {
221229
constant = &String{s}
222230

223231
// Unit
224-
case len(tags) == 1 && tags[0] == UnitTag:
232+
case *TUnit:
225233
constant = &Unit{}
226234

227235
// Bool
228-
case len(tags) == 1 && tags[0] == BoolTag:
236+
case *TBool:
229237
v, err := d.bit()
230238
if err != nil {
231239
return nil, err
@@ -234,17 +242,35 @@ func DecodeConstant(d *decoder) (IConstant, error) {
234242
constant = &Bool{v}
235243

236244
// ProtoList
237-
case len(tags) >= 2 && tags[0] == ProtoListOneTag && tags[1] == ProtoListTwoTag:
238-
// Handle PROTO_LIST_ONE, PROTO_LIST_TWO, rest...
239-
return nil, errors.New("unimplemented: PROTO_LIST")
245+
case *TList:
246+
items, err := DecodeList(d, func(d *decoder) (IConstant, error) { return decodeConstantValue(d, t.Typ) })
247+
if err != nil {
248+
return nil, err
249+
}
250+
constant = &ProtoList{
251+
LTyp: t.Typ,
252+
List: items,
253+
}
240254

241255
// ProtoPair
242-
case len(tags) >= 3 && tags[0] == ProtoPairOneTag && tags[1] == ProtoPairTwoTag && tags[2] == ProtoPairThreeTag:
243-
// Handle PROTO_PAIR_ONE, PROTO_PAIR_TWO, PROTO_PAIR_THREE, rest...
244-
return nil, errors.New("unimplemented: PROTO_PAIR")
256+
case *TPair:
257+
first, err := decodeConstantValue(d, t.First)
258+
if err != nil {
259+
return nil, err
260+
}
261+
second, err := decodeConstantValue(d, t.Second)
262+
if err != nil {
263+
return nil, err
264+
}
265+
constant = &ProtoPair{
266+
FstType: t.First,
267+
SndType: t.Second,
268+
First: first,
269+
Second: second,
270+
}
245271

246272
// Data
247-
case len(tags) == 1 && tags[0] == DataTag:
273+
case *TData:
248274
cborBytes, err := d.bytes()
249275
if err != nil {
250276
return nil, err
@@ -264,6 +290,65 @@ func DecodeConstant(d *decoder) (IConstant, error) {
264290
return constant, nil
265291
}
266292

293+
func decodeConstantType(tags *bytes.Buffer) (Typ, error) {
294+
next, err := tags.ReadByte()
295+
if err != nil {
296+
return nil, err
297+
}
298+
switch next {
299+
case IntegerTag:
300+
return &TInteger{}, nil
301+
case ByteStringTag:
302+
return &TByteString{}, nil
303+
case StringTag:
304+
return &TString{}, nil
305+
case UnitTag:
306+
return &TUnit{}, nil
307+
case BoolTag:
308+
return &TBool{}, nil
309+
case DataTag:
310+
return &TData{}, nil
311+
// NOTE: this also covers ProtoPairOneTag, but it's the same value as ProtoListOneTag, and
312+
// the compiler doesn't like them both being present in the 'case'
313+
case ProtoListOneTag:
314+
next, err = tags.ReadByte()
315+
if err != nil {
316+
return nil, err
317+
}
318+
switch next {
319+
case ProtoListTwoTag:
320+
subType, err := decodeConstantType(tags)
321+
if err != nil {
322+
return nil, err
323+
}
324+
return &TList{
325+
Typ: subType,
326+
}, nil
327+
case ProtoPairTwoTag:
328+
next, err = tags.ReadByte()
329+
if err != nil {
330+
return nil, err
331+
}
332+
switch next {
333+
case ProtoPairThreeTag:
334+
subType1, err := decodeConstantType(tags)
335+
if err != nil {
336+
return nil, err
337+
}
338+
subType2, err := decodeConstantType(tags)
339+
if err != nil {
340+
return nil, err
341+
}
342+
return &TPair{
343+
First: subType1,
344+
Second: subType2,
345+
}, nil
346+
}
347+
}
348+
}
349+
return nil, errors.New("unknown type tag")
350+
}
351+
267352
func decodeConstantTags(d *decoder) ([]byte, error) {
268353
return DecodeList(d, decodeConstantTag)
269354
}

0 commit comments

Comments
 (0)