Skip to content

Commit 9e82ade

Browse files
committed
More efficient codec derivation
1 parent abe7982 commit 9e82ade

File tree

3 files changed

+273
-214
lines changed

3 files changed

+273
-214
lines changed

jsoniter-scala-macros/jvm/src/test/scala/com/github/plokhotnyuk/jsoniter_scala/macros/JsonCodecMakerJVMSpec.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@ class JsonCodecMakerJVMSpec extends VerifyingSpec {
2828
val json = "{" + "\"n\":{" * 1000000 + "}" * 1000000 + "}"
2929
val readStackTrace = TestUtils.assertStackOverflow(readFromString[Nested](json))
3030
assert(readStackTrace.contains("d0"))
31-
assert(!readStackTrace.contains("d1"))
31+
assert(!readStackTrace.contains("d2"))
3232
assert(!readStackTrace.contains("decodeValue"))
3333
val writeStackTrace = TestUtils.assertStackOverflow(writeToString[Nested](construct()))
34-
assert(writeStackTrace.contains("e0"))
35-
assert(!writeStackTrace.contains("e1"))
34+
assert(writeStackTrace.contains("e1"))
35+
assert(!writeStackTrace.contains("e3"))
3636
assert(!writeStackTrace.contains("encodeValue"))
3737
}
3838
}

jsoniter-scala-macros/shared/src/main/scala-2/com/github/plokhotnyuk/jsoniter_scala/macros/JsonCodecMaker.scala

Lines changed: 142 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -775,57 +775,80 @@ object JsonCodecMaker {
775775
}
776776
}
777777

778+
sealed trait TypeInfo
779+
780+
case class JavaEnumValueInfo(value: Tree, name: String)
781+
782+
case class JavaEnumInfo(valueInfos: List[JavaEnumValueInfo], hasTransformed: Boolean, doEncoding: Boolean) extends TypeInfo {
783+
val doLinearSearch: Boolean = valueInfos.size <= 8 && valueInfos.foldLeft(0)(_ + _.name.length) <= 64
784+
}
785+
786+
case class FieldInfo(symbol: TermSymbol, mappedName: String, tmpName: TermName, getter: MethodSymbol,
787+
defaultValue: Option[Tree], resolvedTpe: Type, isStringified: Boolean)
788+
789+
case class ClassInfo(tpe: Type, paramLists: List[List[FieldInfo]]) extends TypeInfo {
790+
val fields: List[FieldInfo] = paramLists.flatten
791+
}
792+
793+
sealed trait TermNameKey
794+
795+
case class DecoderMethodKey(tpe: Type, isStringified: Boolean, discriminator: Tree) extends TermNameKey
796+
797+
case class EncoderMethodKey(tpe: Type, isStringified: Boolean, discriminator: Tree) extends TermNameKey
798+
799+
case class EqualsMethodKey(tpe: Type) extends TermNameKey
800+
801+
case class FieldIndexMethodKey(tpe: Type) extends TermNameKey
802+
803+
case class NullValueKey(tpe: Type) extends TermNameKey
804+
805+
case class ScalaEnumValueKey(tpe: Type) extends TermNameKey
806+
807+
case class MathContextValueKey(precision: Int) extends TermNameKey
808+
778809
val rootTpe = weakTypeOf[A].dealias
779810
val inferredKeyCodecs = mutable.Map[Type, Tree]((rootTpe, EmptyTree))
811+
val inferredValueCodecs = mutable.Map[Type, Tree]((rootTpe, EmptyTree))
812+
val typeInfos = new mutable.HashMap[Type, TypeInfo]
813+
val termNames = new mutable.HashMap[TermNameKey, TermName]
814+
val trees = new mutable.ArrayBuffer[Tree]
780815

781816
def findImplicitKeyCodec(tpe: Type): Tree = inferredKeyCodecs.getOrElseUpdate(tpe, {
782817
c.inferImplicitValue(c.typecheck(tq"com.github.plokhotnyuk.jsoniter_scala.core.JsonKeyCodec[$tpe]", c.TYPEmode).tpe)
783818
})
784819

785-
val inferredValueCodecs = mutable.Map[Type, Tree]((rootTpe, EmptyTree))
786-
787820
def findImplicitValueCodec(tpe: Type): Tree = inferredValueCodecs.getOrElseUpdate(tpe, {
788821
c.inferImplicitValue(c.typecheck(tq"com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec[$tpe]", c.TYPEmode).tpe)
789822
})
790823

791-
val mathContexts = new mutable.LinkedHashMap[Int, (TermName, Tree)]
792-
793824
def withMathContextFor(precision: Int): Tree =
794825
if (precision == java.math.MathContext.DECIMAL128.getPrecision) q"_root_.java.math.MathContext.DECIMAL128"
795826
else if (precision == java.math.MathContext.DECIMAL64.getPrecision) q"_root_.java.math.MathContext.DECIMAL64"
796827
else if (precision == java.math.MathContext.DECIMAL32.getPrecision) q"_root_.java.math.MathContext.DECIMAL32"
797828
else if (precision == java.math.MathContext.UNLIMITED.getPrecision) q"_root_.java.math.MathContext.UNLIMITED"
798-
else Ident(mathContexts.getOrElseUpdate(precision, {
799-
val name = TermName(s"mc${mathContexts.size}")
800-
(name, q"private[this] val $name = new _root_.java.math.MathContext(${cfg.bigDecimalPrecision}, _root_.java.math.RoundingMode.HALF_EVEN)")
801-
})._1)
802-
803-
val scalaEnumCaches = new mutable.LinkedHashMap[Type, (TermName, Tree)]
804-
805-
def withScalaEnumCacheFor(tpe: Type): Tree = Ident(scalaEnumCaches.getOrElseUpdate(tpe, {
806-
val name = TermName(s"ec${scalaEnumCaches.size}")
807-
val keyTpe =
808-
if (cfg.useScalaEnumValueId) tq"Int"
809-
else tq"String"
810-
(name, q"private[this] val $name = new _root_.java.util.concurrent.ConcurrentHashMap[$keyTpe, $tpe]")
811-
})._1)
812-
813-
sealed trait TypeInfo
814-
815-
case class JavaEnumValueInfo(value: Tree, name: String)
816-
817-
case class JavaEnumInfo(valueInfos: List[JavaEnumValueInfo], hasTransformed: Boolean, doEncoding: Boolean) extends TypeInfo {
818-
val doLinearSearch: Boolean = valueInfos.size <= 8 && valueInfos.foldLeft(0)(_ + _.name.length) <= 64
819-
}
820-
821-
case class FieldInfo(symbol: TermSymbol, mappedName: String, tmpName: TermName, getter: MethodSymbol,
822-
defaultValue: Option[Tree], resolvedTpe: Type, isStringified: Boolean)
829+
else Ident({
830+
val termNameKey = new MathContextValueKey(precision)
831+
termNames.getOrElse(termNameKey, {
832+
val name = TermName(s"mc${termNames.size}")
833+
termNames.update(termNameKey, name)
834+
trees += q"private[this] val $name = new _root_.java.math.MathContext(${cfg.bigDecimalPrecision}, _root_.java.math.RoundingMode.HALF_EVEN)"
835+
name
836+
})
837+
})
823838

824-
case class ClassInfo(tpe: Type, paramLists: List[List[FieldInfo]]) extends TypeInfo {
825-
val fields: List[FieldInfo] = paramLists.flatten
826-
}
839+
def withScalaEnumCacheFor(tpe: Type): Tree = Ident({
840+
val termNameKey = new ScalaEnumValueKey(tpe)
841+
termNames.getOrElse(termNameKey, {
842+
val name = TermName(s"ec${termNames.size}")
843+
termNames.update(termNameKey, name)
844+
val keyTpe =
845+
if (cfg.useScalaEnumValueId) tq"Int"
846+
else tq"String"
847+
trees += q"private[this] val $name = new _root_.java.util.concurrent.ConcurrentHashMap[$keyTpe, $tpe]"
848+
name
849+
})
850+
})
827851

828-
val typeInfos = new mutable.HashMap[Type, TypeInfo]
829852

830853
def getJavaEnumInfo(tpe: Type): JavaEnumInfo = typeInfos.getOrElseUpdate(tpe, {
831854
val javaEnumValueNameMapper: String => String = n => cfg.javaEnumValueNameMapper.lift(n).getOrElse(n)
@@ -948,38 +971,49 @@ object JsonCodecMaker {
948971
(tpe.typeSymbol.asClass.isDerivedValueClass || cfg.inlineOneValueClasses && !isCollection(tpe) && getClassInfo(tpe).fields.size == 1)
949972

950973
def adtLeafClasses(adtBaseTpe: Type): Seq[Type] = {
951-
def collectRecursively(tpe: Type): Seq[Type] = {
974+
val seen = new mutable.HashSet[Type]
975+
val subTypes = new mutable.ListBuffer[Type]
976+
implicit val subTypeOrdering: Ordering[Symbol] = (x: Symbol, y: Symbol) => x.fullName.compareTo(y.fullName)
977+
978+
def collectRecursively(tpe: Type): Unit = {
979+
val tpeTypeArgs = typeArgs(tpe)
952980
val tpeClass = tpe.typeSymbol.asClass
953-
val leafTpes = tpeClass.knownDirectSubclasses.toSeq.sortBy(_.fullName).flatMap { s =>
981+
var typeParamsAndArgs = Map.empty[String, Type]
982+
if (tpeTypeArgs ne Nil) tpeClass.typeParams.zip(tpeTypeArgs).foreach { case (typeParam, typeArg) =>
983+
typeParamsAndArgs = typeParamsAndArgs.updated(typeParam.toString, typeArg)
984+
}
985+
val subClasses = tpeClass.knownDirectSubclasses.toArray
986+
scala.util.Sorting.stableSort(subClasses)
987+
subClasses.foreach { s =>
954988
val classSymbol = s.asClass
955-
val typeParams = classSymbol.typeParams
956-
val subTpe =
957-
if (typeParams eq Nil) classSymbol.toType
958-
else {
959-
val typeParamsAndArgs = tpeClass.typeParams.map(_.toString).zip(typeArgs(tpe)).toMap
960-
classSymbol.toType.substituteTypes(typeParams, typeParams.map(tp => typeParamsAndArgs.getOrElse(tp.toString, fail {
961-
s"Cannot resolve generic type(s) for '${classSymbol.toType}'. Please provide a custom implicitly accessible codec for it."
962-
})))
963-
}
989+
var subTpe = classSymbol.toType
990+
if (tpeTypeArgs ne Nil) {
991+
val typeParams = classSymbol.typeParams
992+
subTpe = subTpe.substituteTypes(typeParams, typeParams.map(tp => typeParamsAndArgs.getOrElse(tp.toString, fail {
993+
s"Cannot resolve generic type(s) for '$subTpe'. Please provide a custom implicitly accessible codec for it."
994+
})))
995+
}
964996
if (isSealedClass(subTpe)) collectRecursively(subTpe)
965997
else if (isValueClass(subTpe)) {
966998
fail("'AnyVal' and one value classes with 'CodecMakerConfig.withInlineOneValueClasses(true)' are not " +
967999
s"supported as leaf classes for ADT with base '$adtBaseTpe'.")
968-
} else if (isNonAbstractScalaClass(subTpe)) Seq(subTpe)
969-
else fail((if (s.isAbstract) {
1000+
} else if (isNonAbstractScalaClass(subTpe)) {
1001+
if (seen.add(subTpe)) subTypes += subTpe
1002+
} else fail((if (s.isAbstract) {
9701003
"Only sealed intermediate traits or abstract classes are supported."
9711004
} else {
9721005
"Only concrete (no free type parameters) Scala classes & objects are supported for ADT leaf classes."
9731006
}) + s" Please consider using of them for ADT with base '$adtBaseTpe' or provide a custom implicitly accessible codec for the ADT base.")
9741007
}
975-
if (isNonAbstractScalaClass(tpe)) leafTpes :+ tpe
976-
else leafTpes
1008+
if (isNonAbstractScalaClass(tpe)) {
1009+
if (seen.add(tpe)) subTypes += tpe
1010+
}
9771011
}
9781012

979-
val classes = distinct(collectRecursively(adtBaseTpe))
980-
if (classes.isEmpty) fail(s"Cannot find leaf classes for ADT base '$adtBaseTpe'. " +
1013+
collectRecursively(adtBaseTpe)
1014+
if (subTypes.isEmpty) fail(s"Cannot find leaf classes for ADT base '$adtBaseTpe'. " +
9811015
"Please add them or provide a custom implicitly accessible codec for the ADT base.")
982-
classes
1016+
subTypes.toList
9831017
}
9841018

9851019
def genReadKey(types: List[Type]): Tree = {
@@ -1337,44 +1371,49 @@ object JsonCodecMaker {
13371371
}
13381372
}
13391373

1340-
val nullValues = new mutable.LinkedHashMap[Type, (TermName, Tree)]
1341-
1342-
def withNullValueFor(tpe: Type)(f: => Tree): Tree = Ident(nullValues.getOrElseUpdate(tpe, {
1343-
val name = TermName(s"c${nullValues.size}")
1344-
(name, q"private[this] val $name: $tpe = $f")
1345-
})._1)
1346-
1347-
val fields = new mutable.LinkedHashMap[Type, (TermName, Tree)]
1348-
1349-
def withFieldsFor(tpe: Type)(f: => Seq[String]): Tree = Ident(fields.getOrElseUpdate(tpe, {
1350-
val name = TermName(s"f${fields.size}")
1351-
val cases = f.map {
1352-
var i = -1
1353-
n =>
1354-
i += 1
1355-
cq"$i => $n"
1356-
}
1357-
(name,
1358-
q"""private[this] def $name(i: Int): String =
1359-
(i: @_root_.scala.annotation.switch @_root_.scala.unchecked) match {
1360-
case ..$cases
1361-
}""")
1362-
})._1)
1363-
1364-
val equalsMethods = new mutable.LinkedHashMap[Type, (TermName, Tree)]
1374+
def withNullValueFor(tpe: Type)(f: => Tree): Tree = Ident({
1375+
val termNameKey = new NullValueKey(tpe)
1376+
termNames.getOrElse(termNameKey, {
1377+
val name = TermName(s"c${termNames.size}")
1378+
termNames.update(termNameKey, name)
1379+
trees += q"private[this] val $name: $tpe = $f"
1380+
name
1381+
})
1382+
})
13651383

1366-
def withEqualsFor(tpe: Type, arg1: Tree, arg2: Tree)(f: => Tree): Tree = {
1367-
val equalsMethodName = equalsMethods.getOrElseUpdate(tpe, {
1368-
val name = TermName(s"q${equalsMethods.size}")
1369-
(name, q"private[this] def $name(x1: $tpe, x2: $tpe): _root_.scala.Boolean = $f")
1370-
})._1
1384+
def withFieldsByIndexFor(termNameKey: FieldIndexMethodKey)(f: => Seq[String]): Tree =
1385+
Ident(termNames.getOrElse(termNameKey, {
1386+
val name = TermName(s"f${termNames.size}")
1387+
termNames.update(termNameKey, name)
1388+
val cases = f.map {
1389+
var i = -1
1390+
n =>
1391+
i += 1
1392+
cq"$i => $n"
1393+
}
1394+
trees +=
1395+
q"""private[this] def $name(i: Int): String =
1396+
(i: @_root_.scala.annotation.switch @_root_.scala.unchecked) match {
1397+
case ..$cases
1398+
}"""
1399+
name
1400+
}))
1401+
1402+
def withEqualsFor(termNameKey: EqualsMethodKey, arg1: Tree, arg2: Tree)(f: => Tree): Tree = {
1403+
val equalsMethodName = termNames.getOrElse(termNameKey, {
1404+
val name = TermName(s"q${termNames.size}")
1405+
termNames.update(termNameKey, name)
1406+
val mTpe = termNameKey.tpe
1407+
trees += q"private[this] def $name(x1: $mTpe, x2: $mTpe): _root_.scala.Boolean = $f"
1408+
name
1409+
})
13711410
q"$equalsMethodName($arg1, $arg2)"
13721411
}
13731412

13741413
def genArrayEquals(tpe: Type): Tree = {
13751414
val tpe1 = typeArg1(tpe)
13761415
if (tpe1 <:< typeOf[Array[?]]) {
1377-
val equals = withEqualsFor(tpe1, q"x1(i)", q"x2(i)")(genArrayEquals(tpe1))
1416+
val equals = withEqualsFor(new EqualsMethodKey(tpe1), q"x1(i)", q"x2(i)")(genArrayEquals(tpe1))
13781417
q"""(x1 eq x2) || ((x1 ne null) && (x2 ne null) && {
13791418
val l = x1.length
13801419
(x2.length == l) && {
@@ -1386,31 +1425,24 @@ object JsonCodecMaker {
13861425
} else q"_root_.java.util.Arrays.equals(x1, x2)"
13871426
}
13881427

1389-
case class MethodKey(tpe: Type, isStringified: Boolean, discriminator: Tree)
1390-
1391-
val decodeMethodNames = new mutable.HashMap[MethodKey, TermName]
1392-
val methodTrees = new mutable.ArrayBuffer[Tree]
1393-
1394-
def withDecoderFor(methodKey: MethodKey, arg: Tree)(f: => Tree): Tree = {
1395-
val decodeMethodName = decodeMethodNames.getOrElse(methodKey, {
1396-
val name = TermName(s"d${decodeMethodNames.size}")
1397-
val mtpe = methodKey.tpe
1398-
decodeMethodNames.update(methodKey, name)
1399-
methodTrees +=
1400-
q"private[this] def $name(in: _root_.com.github.plokhotnyuk.jsoniter_scala.core.JsonReader, default: $mtpe): $mtpe = $f"
1428+
def withDecoderFor(termNameKey: DecoderMethodKey, arg: Tree)(f: => Tree): Tree = {
1429+
val decodeMethodName = termNames.getOrElse(termNameKey, {
1430+
val name = TermName(s"d${termNames.size}")
1431+
termNames.update(termNameKey, name)
1432+
val mTpe = termNameKey.tpe
1433+
trees +=
1434+
q"private[this] def $name(in: _root_.com.github.plokhotnyuk.jsoniter_scala.core.JsonReader, default: $mTpe): $mTpe = $f"
14011435
name
14021436
})
14031437
q"$decodeMethodName(in, $arg)"
14041438
}
14051439

1406-
val encodeMethodNames = new mutable.HashMap[MethodKey, TermName]
1407-
1408-
def withEncoderFor(methodKey: MethodKey, arg: Tree)(f: => Tree): Tree = {
1409-
val encodeMethodName = encodeMethodNames.getOrElse(methodKey, {
1410-
val name = TermName(s"e${encodeMethodNames.size}")
1411-
encodeMethodNames.update(methodKey, name)
1412-
methodTrees +=
1413-
q"private[this] def $name(x: ${methodKey.tpe}, out: _root_.com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter): _root_.scala.Unit = $f"
1440+
def withEncoderFor(termNameKey: EncoderMethodKey, arg: Tree)(f: => Tree): Tree = {
1441+
val encodeMethodName = termNames.getOrElse(termNameKey, {
1442+
val name = TermName(s"e${termNames.size}")
1443+
termNames.update(termNameKey, name)
1444+
val mTpe = termNameKey.tpe
1445+
trees += q"private[this] def $name(x: $mTpe, out: _root_.com.github.plokhotnyuk.jsoniter_scala.core.JsonWriter): _root_.scala.Unit = $f"
14141446
name
14151447
})
14161448
q"$encodeMethodName($arg, out)"
@@ -1510,7 +1542,7 @@ object JsonCodecMaker {
15101542
val checkReqVars =
15111543
if (required.isEmpty) Nil
15121544
else {
1513-
val names = withFieldsFor(tpe)(mappedNames)
1545+
val names = withFieldsByIndexFor(new FieldIndexMethodKey(tpe))(mappedNames)
15141546
val reqMasks = fields.grouped(32).toArray.map(_.foldLeft(0) {
15151547
var i = -1
15161548
(acc, fieldInfo) =>
@@ -1695,7 +1727,7 @@ object JsonCodecMaker {
16951727
q"new $tpe(${genReadVal(types1, genNullValue(types1), isStringified, EmptyTree)})"
16961728
} else {
16971729
val isColl = isCollection(tpe)
1698-
val methodKey = new MethodKey(tpe, isColl & isStringified, discriminator)
1730+
val methodKey = new DecoderMethodKey(tpe, isColl & isStringified, discriminator)
16991731
if (isColl) {
17001732
if (tpe <:< typeOf[Array[?]] || isImmutableArraySeq(tpe) || isMutableArraySeq(tpe)) withDecoderFor(methodKey, default) {
17011733
val tpe1 = typeArg1(tpe)
@@ -2074,10 +2106,11 @@ object JsonCodecMaker {
20742106
..${genWriteVal(q"v.get", typeArg1(fTpe) :: allTypes, fieldInfo.isStringified, EmptyTree)}
20752107
}"""
20762108
} else if (fTpe <:< typeOf[Array[?]]) {
2109+
val methodKey = new EqualsMethodKey(fTpe)
20772110
val cond =
20782111
if (cfg.transientEmpty) {
2079-
q"v.length != 0 && !${withEqualsFor(fTpe, q"v", d)(genArrayEquals(fTpe))}"
2080-
} else q"!${withEqualsFor(fTpe, q"v", d)(genArrayEquals(fTpe))}"
2112+
q"v.length != 0 && !${withEqualsFor(methodKey, q"v", d)(genArrayEquals(fTpe))}"
2113+
} else q"!${withEqualsFor(methodKey, q"v", d)(genArrayEquals(fTpe))}"
20812114
q"""val v = x.${fieldInfo.getter}
20822115
if ($cond) {
20832116
..${genWriteConstantKey(fieldInfo.mappedName)}
@@ -2169,7 +2202,7 @@ object JsonCodecMaker {
21692202
genWriteVal(q"$m.${valueClassValueSymbol(tpe)}", valueClassValueType(tpe) :: types, isStringified, EmptyTree)
21702203
} else {
21712204
val isColl = isCollection(tpe)
2172-
val methodKey = new MethodKey(tpe, isColl & isStringified, discriminator)
2205+
val methodKey = new EncoderMethodKey(tpe, isColl & isStringified, discriminator)
21732206
if (isColl) {
21742207
if (tpe <:< typeOf[Array[?]] || isImmutableArraySeq(tpe) || isMutableArraySeq(tpe)) withEncoderFor(methodKey, m) {
21752208
val tpe1 = typeArg1(tpe)
@@ -2363,12 +2396,8 @@ object JsonCodecMaker {
23632396
if (cfg.decodingOnly) q"_root_.scala.Predef.???"
23642397
else genWriteVal(q"x", types, cfg.isStringified, EmptyTree)
23652398
}
2366-
..$methodTrees
2367-
..${fields.values.map(_._2)}
2368-
..${equalsMethods.values.map(_._2)}
2369-
..${nullValues.values.map(_._2)}
2370-
..${mathContexts.values.map(_._2)}
2371-
..${scalaEnumCaches.values.map(_._2)}
2399+
2400+
..$trees
23722401
}
23732402
x
23742403
}"""
@@ -2397,9 +2426,4 @@ object JsonCodecMaker {
23972426
val seen = new mutable.HashSet[A]
23982427
x => !seen.add(x)
23992428
}
2400-
2401-
private[this] def distinct[A](xs: Seq[A]): Seq[A] = xs.filter {
2402-
val seen = new mutable.HashSet[A]
2403-
x => seen.add(x)
2404-
}
24052429
}

0 commit comments

Comments
 (0)