Skip to content

Commit f81cc87

Browse files
authored
Added Scala BsonIgnore annotation (#584)
JAVA-3814
1 parent 417f97f commit f81cc87

File tree

3 files changed

+109
-21
lines changed

3 files changed

+109
-21
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.mongodb.scala.bson.annotations
2+
3+
import scala.annotation.StaticAnnotation
4+
5+
/**
6+
* Annotation to ignore a property.
7+
*/
8+
case class BsonIgnore() extends StaticAnnotation

bson-scala/src/main/scala/org/mongodb/scala/bson/codecs/macrocodecs/CaseClassCodec.scala

Lines changed: 56 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,9 @@
1717
package org.mongodb.scala.bson.codecs.macrocodecs
1818

1919
import scala.reflect.macros.whitebox
20-
2120
import org.bson.codecs.Codec
2221
import org.bson.codecs.configuration.CodecRegistry
23-
24-
import org.mongodb.scala.bson.annotations.BsonProperty
22+
import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty }
2523

2624
private[codecs] object CaseClassCodec {
2725

@@ -157,6 +155,36 @@ private[codecs] object CaseClassCodec {
157155
.toMap
158156
}
159157

158+
val ignoredFields: Map[Type, Seq[(TermName, Tree)]] = {
159+
knownTypes.map { tpe =>
160+
if (!isCaseClass(tpe)) {
161+
(tpe, Nil)
162+
} else {
163+
val constructor = tpe.decl(termNames.CONSTRUCTOR)
164+
if (!constructor.isMethod) c.abort(c.enclosingPosition, "No constructor, unsupported class type")
165+
166+
val defaults = constructor.asMethod.paramLists.head
167+
.map(_.asTerm)
168+
.zipWithIndex
169+
.filter(_._1.annotations.exists(_.tree.tpe == typeOf[BsonIgnore]))
170+
.map {
171+
case (p, i) =>
172+
if (p.isParamWithDefault) {
173+
val getterName = TermName("apply$default$" + (i + 1))
174+
p.name -> q"${tpe.typeSymbol.companion}.$getterName"
175+
} else {
176+
c.abort(
177+
c.enclosingPosition,
178+
s"Field [${p.name}] with BsonIgnore annotation must have a default value"
179+
)
180+
}
181+
}
182+
183+
tpe -> defaults
184+
}
185+
}.toMap
186+
}
187+
160188
// Data converters
161189
def keyName(t: Type): Literal = Literal(Constant(t.typeSymbol.name.decodedName.toString))
162190
def keyNameTerm(t: TermName): Literal = Literal(classAnnotatedFieldsMap.getOrElse(t, Constant(t.toString)))
@@ -284,12 +312,14 @@ private[codecs] object CaseClassCodec {
284312
* @param fields the list of fields
285313
* @return the tree that writes the case class fields
286314
*/
287-
def writeClassValues(fields: List[(TermName, Type)]): List[Tree] = {
288-
fields.map({
289-
case (name, f) =>
290-
val key = keyNameTerm(name)
291-
f match {
292-
case optional if isOption(optional) => q"""
315+
def writeClassValues(fields: List[(TermName, Type)], ignoredFields: Seq[(TermName, Tree)]): List[Tree] = {
316+
fields
317+
.filterNot { case (name, _) => ignoredFields.exists { case (iname, _) => name == iname } }
318+
.map({
319+
case (name, f) =>
320+
val key = keyNameTerm(name)
321+
f match {
322+
case optional if isOption(optional) => q"""
293323
val localVal = instanceValue.$name
294324
if (localVal.isDefined) {
295325
writer.writeName($key)
@@ -298,13 +328,13 @@ private[codecs] object CaseClassCodec {
298328
writer.writeName($key)
299329
this.writeFieldValue($key, writer, this.bsonNull, encoderContext)
300330
}"""
301-
case _ => q"""
331+
case _ => q"""
302332
val localVal = instanceValue.$name
303333
writer.writeName($key)
304334
this.writeFieldValue($key, writer, localVal, encoderContext)
305335
"""
306-
}
307-
})
336+
}
337+
})
308338
}
309339

310340
/*
@@ -314,7 +344,7 @@ private[codecs] object CaseClassCodec {
314344
val cases: Seq[Tree] = {
315345
fields.map(field => cq""" ${keyName(field._1)} =>
316346
val instanceValue = value.asInstanceOf[${field._1}]
317-
..${writeClassValues(field._2)}""").toSeq
347+
..${writeClassValues(field._2, ignoredFields(field._1))}""").toSeq
318348
}
319349

320350
q"""
@@ -325,23 +355,29 @@ private[codecs] object CaseClassCodec {
325355
"""
326356
}
327357

328-
def fieldSetters(fields: List[(TermName, Type)]) = {
358+
def fieldSetters(fields: List[(TermName, Type)], ignoredFields: Seq[(TermName, Tree)]) = {
329359
fields.map({
330360
case (name, f) =>
331361
val key = keyNameTerm(name)
332362
val missingField = Literal(Constant(s"Missing field: $key"))
333-
f match {
334-
case optional if isOption(optional) =>
335-
q"$name = (if (fieldData.contains($key)) Option(fieldData($key)) else None).asInstanceOf[$f]"
336-
case _ =>
337-
q"""$name = fieldData.getOrElse($key, throw new BsonInvalidOperationException($missingField)).asInstanceOf[$f]"""
363+
364+
ignoredFields.find { case (iname, _) => name == iname }.map(_._2) match {
365+
case Some(default) =>
366+
q"$name = $default"
367+
case None =>
368+
f match {
369+
case optional if isOption(optional) =>
370+
q"$name = (if (fieldData.contains($key)) Option(fieldData($key)) else None).asInstanceOf[$f]"
371+
case _ =>
372+
q"""$name = fieldData.getOrElse($key, throw new BsonInvalidOperationException($missingField)).asInstanceOf[$f]"""
373+
}
338374
}
339375
})
340376
}
341377

342378
def getInstance = {
343379
val cases = knownTypes.map { st =>
344-
cq"${keyName(st)} => new $st(..${fieldSetters(fields(st))})"
380+
cq"${keyName(st)} => new $st(..${fieldSetters(fields(st), ignoredFields(st))})"
345381
} :+ cq"""_ => throw new BsonInvalidOperationException("Unexpected class type: " + className)"""
346382
q"className match { case ..$cases }"
347383
}

bson-scala/src/test/scala/org/mongodb/scala/bson/codecs/MacrosSpec.scala

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import org.bson.codecs.{ Codec, DecoderContext, EncoderContext }
2626
import org.bson.io.{ BasicOutputBuffer, ByteBufferBsonInput, OutputBuffer }
2727
import org.bson.types.ObjectId
2828
import org.mongodb.scala.bson.BaseSpec
29-
import org.mongodb.scala.bson.annotations.BsonProperty
29+
import org.mongodb.scala.bson.annotations.{ BsonIgnore, BsonProperty }
3030
import org.mongodb.scala.bson.codecs.Macros.{ createCodecProvider, createCodecProviderIgnoreNone }
3131
import org.mongodb.scala.bson.codecs.Registry.DEFAULT_CODEC_REGISTRY
3232
import org.mongodb.scala.bson.collection.immutable.Document
@@ -43,6 +43,7 @@ class MacrosSpec extends BaseSpec {
4343
case class SeqOfStrings(name: String, value: Seq[String])
4444
case class RecursiveSeq(name: String, value: Seq[RecursiveSeq])
4545
case class AnnotatedClass(@BsonProperty("annotated_name") name: String)
46+
case class IgnoredFieldClass(name: String, @BsonIgnore meta: String = "ignored_default")
4647

4748
case class Binary(binary: Array[Byte]) {
4849

@@ -103,6 +104,11 @@ class MacrosSpec extends BaseSpec {
103104
case class Branch(@BsonProperty("l1") b1: Tree, @BsonProperty("r1") b2: Tree, value: Int) extends Tree
104105
case class Leaf(value: Int) extends Tree
105106

107+
sealed trait WithIgnored
108+
case class MetaIgnoredField(data: String, @BsonIgnore meta: Seq[String] = Vector("ignore_me")) extends WithIgnored
109+
case class LeafCountIgnoredField(branchCount: Int, @BsonIgnore leafCount: Int = 100) extends WithIgnored
110+
case class ContainsIgnoredField(list: Seq[WithIgnored])
111+
106112
case class ContainsADT(name: String, tree: Tree)
107113
case class ContainsSeqADT(name: String, trees: Seq[Tree])
108114
case class ContainsNestedSeqADT(name: String, trees: Seq[Seq[Tree]])
@@ -270,6 +276,23 @@ class MacrosSpec extends BaseSpec {
270276
)
271277
}
272278

279+
it should "be able to ignore fields" in {
280+
roundTrip(
281+
IgnoredFieldClass("Bob", "singer"),
282+
IgnoredFieldClass("Bob"),
283+
"""{name: "Bob"}""",
284+
classOf[IgnoredFieldClass]
285+
)
286+
287+
roundTrip(
288+
ContainsIgnoredField(Vector(MetaIgnoredField("Bob", List("singer")), LeafCountIgnoredField(1, 10))),
289+
ContainsIgnoredField(Vector(MetaIgnoredField("Bob"), LeafCountIgnoredField(1))),
290+
"""{"list" : [{"_t" : "MetaIgnoredField", "data" : "Bob" }, {"_t" : "LeafCountIgnoredField", "branchCount": 1}]}""",
291+
classOf[ContainsIgnoredField],
292+
classOf[WithIgnored]
293+
)
294+
}
295+
273296
it should "be able to round trip polymorphic nested case classes in a sealed class" in {
274297
roundTrip(
275298
ContainsSealedClass(List(SealedClassA("test"), SealedClassB(12))),
@@ -657,6 +680,15 @@ class MacrosSpec extends BaseSpec {
657680
roundTripCodec(value, Document(expected), codec)
658681
}
659682

683+
def roundTrip[T](value: T, decodedValue: T, expected: String, provider: CodecProvider, providers: CodecProvider*)(
684+
implicit ct: ClassTag[T]
685+
): Unit = {
686+
val codecProviders: util.List[CodecProvider] = (provider +: providers).asJava
687+
val registry = CodecRegistries.fromRegistries(CodecRegistries.fromProviders(codecProviders), DEFAULT_CODEC_REGISTRY)
688+
val codec = registry.get(ct.runtimeClass).asInstanceOf[Codec[T]]
689+
roundTripCodec(value, decodedValue, Document(expected), codec)
690+
}
691+
660692
def roundTripCodec[T](value: T, expected: Document, codec: Codec[T]): Unit = {
661693
val encoded = encode(codec, value)
662694
val actual = decode(documentCodec, encoded)
@@ -666,6 +698,18 @@ class MacrosSpec extends BaseSpec {
666698
assert(roundTripped == value, s"Round Tripped case class: ($roundTripped) did not equal the original: ($value)")
667699
}
668700

701+
def roundTripCodec[T](value: T, decodedValue: T, expected: Document, codec: Codec[T]): Unit = {
702+
val encoded = encode(codec, value)
703+
val actual = decode(documentCodec, encoded)
704+
assert(expected == actual, s"Encoded document: (${actual.toJson()}) did not equal: (${expected.toJson()})")
705+
706+
val roundTripped = decode(codec, encode(codec, value))
707+
assert(
708+
roundTripped == decodedValue,
709+
s"Round Tripped case class: ($roundTripped) did not equal the expected: ($decodedValue)"
710+
)
711+
}
712+
669713
def encode[T](codec: Codec[T], value: T): OutputBuffer = {
670714
val buffer = new BasicOutputBuffer()
671715
val writer = new BsonBinaryWriter(buffer)

0 commit comments

Comments
 (0)