Skip to content

Commit d13ec50

Browse files
committed
clean up more
1 parent 834fa67 commit d13ec50

File tree

3 files changed

+82
-92
lines changed

3 files changed

+82
-92
lines changed

formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/CborParserInterface.kt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,4 @@ internal sealed interface CborParserInterface {
3737

3838
// Tag verification
3939
fun verifyTagsAndThrow(expected: ULongArray, actual: ULongArray?)
40-
41-
// Additional methods needed for CborTreeReader
42-
fun nextTag(): ULong
4340
}

formats/cbor/commonMain/src/kotlinx/serialization/cbor/internal/Decoder.kt

Lines changed: 81 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ internal open class CborReader(override val cbor: Cbor, protected val parser: Cb
1616
CborDecoder {
1717

1818
override fun decodeCborElement(): CborElement =
19-
when(parser) {
20-
is CborParser -> CborTreeReader(cbor.configuration, parser).read()
21-
is StructuredCborParser -> parser.element
19+
when (parser) {
20+
is CborParser -> CborTreeReader(cbor.configuration, parser).read()
21+
is StructuredCborParser -> parser.element
2222
}
2323

2424

@@ -58,7 +58,7 @@ internal open class CborReader(override val cbor: Cbor, protected val parser: Cb
5858
}
5959

6060
override fun endStructure(descriptor: SerialDescriptor) {
61-
if (!finiteMode) parser.end()
61+
if (!finiteMode || parser is StructuredCborParser) parser.end()
6262
}
6363

6464
override fun decodeElementIndex(descriptor: SerialDescriptor): Int {
@@ -158,7 +158,8 @@ internal open class CborReader(override val cbor: Cbor, protected val parser: Cb
158158
}
159159
}
160160

161-
internal class CborParser(private val input: ByteArrayInput, private val verifyObjectTags: Boolean) : CborParserInterface {
161+
internal class CborParser(private val input: ByteArrayInput, private val verifyObjectTags: Boolean) :
162+
CborParserInterface {
162163
var curByte: Int = -1
163164

164165
init {
@@ -196,7 +197,7 @@ internal class CborParser(private val input: ByteArrayInput, private val verifyO
196197
}
197198
}
198199

199-
override fun nextTag(): ULong {
200+
fun nextTag(): ULong {
200201
if ((curByte shr 5) != 6) {
201202
throw CborDecodingException("Expected tag (major type 6), got major type ${curByte shr 5}")
202203
}
@@ -554,165 +555,156 @@ private fun Iterable<ByteArray>.flatten(): ByteArray {
554555
return output
555556
}
556557

557-
private typealias ElementHolder = Pair<MutableList<ULong>, CborElement>
558-
private val ElementHolder.tags: MutableList<ULong> get() = first
559-
private val ElementHolder.element: CborElement get() = second
560-
internal class StructuredCborParser(val element: CborElement, private val verifyObjectTags: Boolean) : CborParserInterface {
561-
562558

563-
internal var current: ElementHolder = element.tags.toMutableList() to element
564-
private var listIterator: Iterator<CborElement>? = null
559+
internal class StructuredCborParser(internal val element: CborElement, private val verifyObjectTags: Boolean) :
560+
CborParserInterface {
565561

566-
// Implementation of methods needed for CborTreeReader
567-
override fun nextTag(): ULong {
568-
if (current.tags.isEmpty()) {
569-
throw CborDecodingException("Expected tag, but no tags found on current element")
570-
}
571-
return current.tags.removeFirst()
572-
}
573-
574-
override fun isNull() : Boolean {
575-
//TODO this is a bit wonky! if we are inside a map, we want to skip over the key, and check the value,
576-
// so the below call is not what it should be!
577-
processTags(null)
578-
return current.element is CborNull
562+
563+
internal var currentElement = element
564+
private var listIterator: ListIterator<CborElement>? = null
565+
private var isMap = false
566+
private val isMapStack = ArrayDeque<Boolean>()
567+
private val layerStack = ArrayDeque<ListIterator<CborElement>?>()
568+
569+
override fun isNull(): Boolean {
570+
return if (isMap) {
571+
val isNull = listIterator!!.next() is CborNull
572+
listIterator!!.previous()
573+
isNull
574+
} else currentElement is CborNull
579575
}
580-
576+
581577
override fun isEnd() = when {
582578
listIterator != null -> !listIterator!!.hasNext()
583579
else -> false
584580
}
585-
581+
586582
override fun end() {
587583
// Reset iterators when ending a structure
588-
listIterator = null
584+
isMap = isMapStack.removeLast()
585+
listIterator = layerStack.removeLast()
589586
}
590587

591588
override fun startArray(tags: ULongArray?): Int {
592589
processTags(tags)
593-
if (current.element !is CborList) {
594-
throw CborDecodingException("Expected array, got ${current.element::class.simpleName}")
595-
}
596-
597-
val list = current.element as CborList
598-
listIterator = list.iterator()
590+
if (currentElement !is CborList) {
591+
throw CborDecodingException("Expected array, got ${currentElement::class.simpleName}")
592+
}
593+
isMapStack+=isMap
594+
layerStack+=listIterator
595+
isMap = false
596+
val list = currentElement as CborList
597+
listIterator = list.listIterator()
599598
return list.size
600599
}
601-
600+
602601
override fun startMap(tags: ULongArray?): Int {
603602
processTags(tags)
604-
if (current.element !is CborMap) {
605-
throw CborDecodingException("Expected map, got ${current.element::class.simpleName}")
603+
if (currentElement !is CborMap) {
604+
throw CborDecodingException("Expected map, got ${currentElement::class.simpleName}")
606605
}
607-
608-
val map = current.element as CborMap
606+
layerStack+=listIterator
607+
isMapStack+=isMap
608+
isMap = true
609+
610+
val map = currentElement as CborMap
609611
//zip key, value, key, value, ... pairs to mirror byte-layout of CBOR map
610-
listIterator = map.entries.flatMap { listOf(it.key, it.value) }.iterator()
612+
listIterator = map.entries.flatMap { listOf(it.key, it.value) }.listIterator()
611613
return map.size //cbor map size is the size of the map, not the doubled size of the flattened pairs
612614
}
613-
615+
614616
override fun nextNull(tags: ULongArray?): Nothing? {
615617
processTags(tags)
616-
if (current.element !is CborNull) {
617-
throw CborDecodingException("Expected null, got ${current.element::class.simpleName}")
618+
if (currentElement !is CborNull) {
619+
throw CborDecodingException("Expected null, got ${currentElement::class.simpleName}")
618620
}
619621
return null
620622
}
621-
623+
622624
override fun nextBoolean(tags: ULongArray?): Boolean {
623625
processTags(tags)
624-
if (current.element !is CborBoolean) {
625-
throw CborDecodingException("Expected boolean, got ${current.element::class.simpleName}")
626+
if (currentElement !is CborBoolean) {
627+
throw CborDecodingException("Expected boolean, got ${currentElement::class.simpleName}")
626628
}
627-
return (current.element as CborBoolean).value
629+
return (currentElement as CborBoolean).value
628630
}
629-
631+
630632
override fun nextNumber(tags: ULongArray?): Long {
631633
processTags(tags)
632-
return when (current.element) {
633-
is CborPositiveInt -> (current.element as CborPositiveInt).value.toLong()
634-
is CborNegativeInt -> (current.element as CborNegativeInt).value
635-
else -> throw CborDecodingException("Expected number, got ${current.element::class.simpleName}")
634+
return when (currentElement) {
635+
is CborPositiveInt -> (currentElement as CborPositiveInt).value.toLong()
636+
is CborNegativeInt -> (currentElement as CborNegativeInt).value
637+
else -> throw CborDecodingException("Expected number, got ${currentElement::class.simpleName}")
636638
}
637639
}
638-
640+
639641
override fun nextString(tags: ULongArray?): String {
640642
processTags(tags)
641-
642-
// Special handling for polymorphic serialization
643-
// If we have a CborList with a string as first element, return that string
644-
if (current.element is CborList && (current.element as CborList).isNotEmpty() && (current.element as CborList)[0] is CborString) {
645-
val stringElement = (current.element as CborList)[0] as CborString
646-
// Move to the next element (the map) for subsequent operations
647-
current = (current.element as CborList)[1].tags.toMutableList() to (current.element as CborList)[1]
648-
return stringElement.value
649-
}
650-
651-
if (current.element !is CborString) {
652-
throw CborDecodingException("Expected string, got ${current.element::class.simpleName}")
653-
}
654-
return (current.element as CborString).value
655-
}
656-
643+
if (currentElement !is CborString) {
644+
throw CborDecodingException("Expected string, got ${currentElement::class.simpleName}")
645+
}
646+
return (currentElement as CborString).value
647+
}
648+
657649
override fun nextByteString(tags: ULongArray?): ByteArray {
658650
processTags(tags)
659-
if (current.element !is CborByteString) {
660-
throw CborDecodingException("Expected byte string, got ${current.element::class.simpleName}")
651+
if (currentElement !is CborByteString) {
652+
throw CborDecodingException("Expected byte string, got ${currentElement::class.simpleName}")
661653
}
662-
return (current.element as CborByteString).value
654+
return (currentElement as CborByteString).value
663655
}
664-
656+
665657
override fun nextDouble(tags: ULongArray?): Double {
666658
processTags(tags)
667-
return when (current.element) {
668-
is CborDouble -> (current.element as CborDouble).value
669-
else -> throw CborDecodingException("Expected double, got ${current.element::class.simpleName}")
659+
return when (currentElement) {
660+
is CborDouble -> (currentElement as CborDouble).value
661+
else -> throw CborDecodingException("Expected double, got ${currentElement::class.simpleName}")
670662
}
671663
}
672-
664+
673665
override fun nextFloat(tags: ULongArray?): Float {
674666
return nextDouble(tags).toFloat()
675667
}
676-
668+
677669
override fun nextTaggedStringOrNumber(): Triple<String?, Long?, ULongArray?> {
678670
val tags = processTags(null)
679-
680-
return when (val key = current.element) {
671+
672+
return when (val key = currentElement) {
681673
is CborString -> Triple(key.value, null, tags)
682674
is CborPositiveInt -> Triple(null, key.value.toLong(), tags)
683675
is CborNegativeInt -> Triple(null, key.value, tags)
684676
else -> throw CborDecodingException("Expected string or number key, got ${key?.let { it::class.simpleName } ?: "null"}")
685677
}
686678
}
687-
679+
688680
private fun processTags(tags: ULongArray?): ULongArray? {
689681

690682
// If we're in a list, advance to the next element
691683
if (listIterator != null && listIterator!!.hasNext()) {
692-
listIterator!!.next().let { current = it.tags.toMutableList() to it }
684+
currentElement= listIterator!!.next()
693685
}
694-
686+
695687
// Store collected tags for verification
696-
val collectedTags = if (current.tags.isEmpty()) null else current.tags.toULongArray()
697-
688+
val collectedTags = if (currentElement.tags.isEmpty()) null else currentElement.tags
689+
698690
// Verify tags if needed
699691
if (verifyObjectTags) {
700692
tags?.let {
701693
verifyTagsAndThrow(it, collectedTags)
702694
}
703695
}
704-
696+
705697
return collectedTags
706698
}
707-
699+
708700
override fun verifyTagsAndThrow(expected: ULongArray, actual: ULongArray?) {
709701
if (!expected.contentEquals(actual)) {
710702
throw CborDecodingException(
711703
"CBOR tags ${actual?.contentToString()} do not match expected tags ${expected.contentToString()}"
712704
)
713705
}
714706
}
715-
707+
716708
override fun skipElement(tags: ULongArray?) {
717709
// Process tags but don't do anything with the element
718710
processTags(tags)

formats/cbor/commonTest/src/kotlinx/serialization/cbor/CborDecoderTest.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ class CborDecoderTest {
4747
)
4848

4949
val struct = Cbor.decodeFromHexString<CborElement>(hex)
50+
assertEquals(Cbor.encodeToCbor(test), struct)
5051
assertEquals(test, Cbor.decodeFromCbor(TypesUmbrella.serializer(), struct))
5152

5253

0 commit comments

Comments
 (0)