Skip to content

Commit 6b8909b

Browse files
committed
#807 Add field name to AST object memoization for better performance of raw record processing.
1 parent cd292b4 commit 6b8909b

File tree

1 file changed

+67
-50
lines changed
  • cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser

1 file changed

+67
-50
lines changed

cobol-parser/src/main/scala/za/co/absa/cobrix/cobol/parser/Copybook.scala

Lines changed: 67 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -21,19 +21,30 @@ import za.co.absa.cobrix.cobol.parser.CopybookParser.CopybookAST
2121
import za.co.absa.cobrix.cobol.parser.ast.{Group, Primitive, Statement}
2222
import za.co.absa.cobrix.cobol.parser.asttransform.BinaryPropertiesAdder
2323

24+
import java.util.concurrent.ConcurrentHashMap
2425
import scala.collection.mutable
2526
import scala.collection.mutable.ArrayBuffer
2627

2728

2829
class Copybook(val ast: CopybookAST) extends Logging with Serializable {
2930
import Copybook._
3031

31-
def getCobolSchema: CopybookAST = ast
32+
private val cachePrimitives = new ConcurrentHashMap[String, Primitive]()
33+
private val cacheStatements = new ConcurrentHashMap[String, Statement]()
34+
35+
val isFlatCopybook: Boolean = ast.children.exists(f => f.isInstanceOf[Primitive])
3236

3337
lazy val getRecordSize: Int = {
3438
ast.binaryProperties.offset + ast.binaryProperties.actualSize
3539
}
3640

41+
/**
42+
* Returns true if there is at least 1 parent-child relationship defined in any of segment redefines.
43+
*/
44+
lazy val isHierarchical: Boolean = getAllSegmentRedefines.exists(_.parentSegment.nonEmpty)
45+
46+
def getCobolSchema: CopybookAST = ast
47+
3748
def isRecordFixedSize: Boolean = true
3849

3950
/**
@@ -57,13 +68,6 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
5768
def getRootSegmentIds(segmentIdRedefineMap: Map[String, String], fieldParentMap: Map[String, String]): List[String] =
5869
CopybookParser.getRootSegmentIds(segmentIdRedefineMap, fieldParentMap)
5970

60-
/**
61-
* Returns true if there at least 1 parent-child relationships defined in any of segment redefines.
62-
*/
63-
lazy val isHierarchical: Boolean = getAllSegmentRedefines.exists(_.parentSegment.nonEmpty)
64-
65-
val isFlatCopybook: Boolean = ast.children.exists(f => f.isInstanceOf[Primitive])
66-
6771
def getRootRecords: scala.collection.Seq[Statement] = {
6872
if (isFlatCopybook) {
6973
scala.collection.Seq(ast)
@@ -84,11 +88,9 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
8488
* @return The value of the field
8589
*/
8690
def getFieldValueByName(fieldName: String, recordBytes: Array[Byte], startOffset: Int = 0): Any = {
87-
val ast = getFieldByName(fieldName)
88-
ast match {
89-
case s: Primitive => extractPrimitiveField(s, recordBytes, startOffset)
90-
case _ => throw new IllegalStateException(s"$fieldName is not a primitive field, cannot extract its value.")
91-
}
91+
val primitive = getPrimitiveFieldByName(fieldName)
92+
93+
extractPrimitiveField(primitive, recordBytes, startOffset)
9294
}
9395

9496
/**
@@ -105,11 +107,9 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
105107
* @param startOffset An offset where the record starts in the data (in bytes)
106108
*/
107109
def setFieldValueByName(fieldName: String, recordBytes: Array[Byte], value: Any, startOffset: Int = 0): Unit = {
108-
val ast = getFieldByName(fieldName)
109-
ast match {
110-
case s: Primitive => setPrimitiveField(s, recordBytes, value, startOffset)
111-
case _ => throw new IllegalStateException(s"$fieldName is not a primitive field, cannot set its value.")
112-
}
110+
val primitive = getPrimitiveFieldByName(fieldName)
111+
112+
setPrimitiveField(primitive, recordBytes, value, startOffset)
113113
}
114114

115115
/**
@@ -126,12 +126,10 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
126126

127127
def getFieldByNameInGroup(group: Group, fieldName: String): Seq[Statement] = {
128128
val groupMatch = if (group.name.equalsIgnoreCase(fieldName)) Seq(group) else Seq()
129-
groupMatch ++ group.children.flatMap(child => {
130-
child match {
131-
case g: Group => getFieldByNameInGroup(g, fieldName)
132-
case st: Primitive => if (st.name.equalsIgnoreCase(fieldName)) Seq(st) else Seq()
133-
}
134-
})
129+
groupMatch ++ group.children.flatMap {
130+
case g: Group => getFieldByNameInGroup(g, fieldName)
131+
case st: Primitive => if (st.name.equalsIgnoreCase(fieldName)) Seq(st) else Seq()
132+
}
135133
}
136134

137135
def getFieldByUniqueName(schema: CopybookAST, fieldName: String): Seq[Statement] = {
@@ -143,18 +141,16 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
143141
if (path.length == 0) {
144142
throw new IllegalStateException(s"'$fieldName' is a GROUP and not a primitive field. Cannot extract it's value.")
145143
} else {
146-
group.children.flatMap(child => {
147-
child match {
148-
case g: Group =>
149-
if (g.name.equalsIgnoreCase(path.head))
150-
getFieldByPathInGroup(g, path.drop(1))
151-
else scala.collection.Seq.empty[Statement]
152-
case st: Primitive =>
153-
if (st.name.equalsIgnoreCase(path.head))
154-
Seq(st)
155-
else scala.collection.Seq.empty[Statement]
156-
}
157-
})
144+
group.children.flatMap {
145+
case g: Group =>
146+
if (g.name.equalsIgnoreCase(path.head))
147+
getFieldByPathInGroup(g, path.drop(1))
148+
else scala.collection.Seq.empty[Statement]
149+
case st: Primitive =>
150+
if (st.name.equalsIgnoreCase(path.head))
151+
Seq(st)
152+
else scala.collection.Seq.empty[Statement]
153+
}
158154
}
159155
}
160156

@@ -181,21 +177,27 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
181177
)
182178
}
183179

184-
val schema = getCobolSchema
180+
val cachedStatement = cacheStatements.get(fieldName)
185181

186-
val foundFields = if (fieldName.contains('.')) {
187-
getFieldByPathName(schema, fieldName)
188-
} else {
189-
getFieldByUniqueName(schema, fieldName)
190-
}
182+
if (cachedStatement == null) {
183+
val schema = getCobolSchema
191184

192-
if (foundFields.isEmpty) {
193-
throw new IllegalStateException(s"Field '$fieldName' is not found in the copybook.")
194-
} else if (foundFields.lengthCompare(1) == 0) {
195-
foundFields.head
185+
val foundFields = if (fieldName.contains('.')) {
186+
getFieldByPathName(schema, fieldName)
187+
} else {
188+
getFieldByUniqueName(schema, fieldName)
189+
}
190+
191+
if (foundFields.isEmpty) {
192+
throw new IllegalStateException(s"Field '$fieldName' is not found in the copybook.")
193+
} else if (foundFields.lengthCompare(1) == 0) {
194+
foundFields.head
195+
} else {
196+
throw new IllegalStateException(s"Multiple fields with name '$fieldName' found in the copybook. Please specify the exact field using '.' " +
197+
s"notation.")
198+
}
196199
} else {
197-
throw new IllegalStateException(s"Multiple fields with name '$fieldName' found in the copybook. Please specify the exact field using '.' " +
198-
s"notation.")
200+
cachedStatement
199201
}
200202
}
201203

@@ -344,8 +346,23 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
344346
}
345347
visitGroup(ast)
346348
}
347-
}
348349

350+
private def getPrimitiveFieldByName(fieldName: String): Primitive = {
351+
val cachedPrimitive = cachePrimitives.get(fieldName)
352+
353+
if (cachedPrimitive == null) {
354+
val ast = getFieldByName(fieldName)
355+
ast match {
356+
case s: Primitive =>
357+
cachePrimitives.put(fieldName, s)
358+
s
359+
case _ => throw new IllegalStateException(s"$fieldName is not a primitive field, cannot extract its value.")
360+
}
361+
} else {
362+
cachedPrimitive
363+
}
364+
}
365+
}
349366

350367
object Copybook {
351368
def merge(copybooks: Seq[Copybook]): Copybook = {
@@ -443,4 +460,4 @@ object Copybook {
443460
throw new IllegalStateException(s"Cannot set value for field '${field.name}' because it does not have an encoder defined.")
444461
}
445462
}
446-
}
463+
}

0 commit comments

Comments
 (0)