@@ -21,19 +21,30 @@ import za.co.absa.cobrix.cobol.parser.CopybookParser.CopybookAST
2121import za .co .absa .cobrix .cobol .parser .ast .{Group , Primitive , Statement }
2222import za .co .absa .cobrix .cobol .parser .asttransform .BinaryPropertiesAdder
2323
24+ import java .util .concurrent .ConcurrentHashMap
2425import scala .collection .mutable
2526import scala .collection .mutable .ArrayBuffer
2627
2728
2829class Copybook (val ast : CopybookAST ) extends Logging with Serializable {
2930 import Copybook ._
3031
31- def getCobolSchema : CopybookAST = ast
32+ private val cachePrimitives = new ConcurrentHashMap [String , Primitive ]()
33+ private val cacheStatements = new ConcurrentHashMap [String , Statement ]()
34+
35+ val isFlatCopybook : Boolean = ast.children.exists(f => f.isInstanceOf [Primitive ])
3236
3337 lazy val getRecordSize : Int = {
3438 ast.binaryProperties.offset + ast.binaryProperties.actualSize
3539 }
3640
41+ /**
42+ * Returns true if there is at least 1 parent-child relationship defined in any of segment redefines.
43+ */
44+ lazy val isHierarchical : Boolean = getAllSegmentRedefines.exists(_.parentSegment.nonEmpty)
45+
46+ def getCobolSchema : CopybookAST = ast
47+
3748 def isRecordFixedSize : Boolean = true
3849
3950 /**
@@ -57,13 +68,6 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
5768 def getRootSegmentIds (segmentIdRedefineMap : Map [String , String ], fieldParentMap : Map [String , String ]): List [String ] =
5869 CopybookParser .getRootSegmentIds(segmentIdRedefineMap, fieldParentMap)
5970
60- /**
61- * Returns true if there at least 1 parent-child relationships defined in any of segment redefines.
62- */
63- lazy val isHierarchical : Boolean = getAllSegmentRedefines.exists(_.parentSegment.nonEmpty)
64-
65- val isFlatCopybook : Boolean = ast.children.exists(f => f.isInstanceOf [Primitive ])
66-
6771 def getRootRecords : scala.collection.Seq [Statement ] = {
6872 if (isFlatCopybook) {
6973 scala.collection.Seq (ast)
@@ -84,11 +88,9 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
8488 * @return The value of the field
8589 */
8690 def getFieldValueByName (fieldName : String , recordBytes : Array [Byte ], startOffset : Int = 0 ): Any = {
87- val ast = getFieldByName(fieldName)
88- ast match {
89- case s : Primitive => extractPrimitiveField(s, recordBytes, startOffset)
90- case _ => throw new IllegalStateException (s " $fieldName is not a primitive field, cannot extract its value. " )
91- }
91+ val primitive = getPrimitiveFieldByName(fieldName)
92+
93+ extractPrimitiveField(primitive, recordBytes, startOffset)
9294 }
9395
9496 /**
@@ -105,11 +107,9 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
105107 * @param startOffset An offset where the record starts in the data (in bytes)
106108 */
107109 def setFieldValueByName (fieldName : String , recordBytes : Array [Byte ], value : Any , startOffset : Int = 0 ): Unit = {
108- val ast = getFieldByName(fieldName)
109- ast match {
110- case s : Primitive => setPrimitiveField(s, recordBytes, value, startOffset)
111- case _ => throw new IllegalStateException (s " $fieldName is not a primitive field, cannot set its value. " )
112- }
110+ val primitive = getPrimitiveFieldByName(fieldName)
111+
112+ setPrimitiveField(primitive, recordBytes, value, startOffset)
113113 }
114114
115115 /**
@@ -126,12 +126,10 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
126126
127127 def getFieldByNameInGroup (group : Group , fieldName : String ): Seq [Statement ] = {
128128 val groupMatch = if (group.name.equalsIgnoreCase(fieldName)) Seq (group) else Seq ()
129- groupMatch ++ group.children.flatMap(child => {
130- child match {
131- case g : Group => getFieldByNameInGroup(g, fieldName)
132- case st : Primitive => if (st.name.equalsIgnoreCase(fieldName)) Seq (st) else Seq ()
133- }
134- })
129+ groupMatch ++ group.children.flatMap {
130+ case g : Group => getFieldByNameInGroup(g, fieldName)
131+ case st : Primitive => if (st.name.equalsIgnoreCase(fieldName)) Seq (st) else Seq ()
132+ }
135133 }
136134
137135 def getFieldByUniqueName (schema : CopybookAST , fieldName : String ): Seq [Statement ] = {
@@ -143,18 +141,16 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
143141 if (path.length == 0 ) {
144142 throw new IllegalStateException (s " ' $fieldName' is a GROUP and not a primitive field. Cannot extract it's value. " )
145143 } else {
146- group.children.flatMap(child => {
147- child match {
148- case g : Group =>
149- if (g.name.equalsIgnoreCase(path.head))
150- getFieldByPathInGroup(g, path.drop(1 ))
151- else scala.collection.Seq .empty[Statement ]
152- case st : Primitive =>
153- if (st.name.equalsIgnoreCase(path.head))
154- Seq (st)
155- else scala.collection.Seq .empty[Statement ]
156- }
157- })
144+ group.children.flatMap {
145+ case g : Group =>
146+ if (g.name.equalsIgnoreCase(path.head))
147+ getFieldByPathInGroup(g, path.drop(1 ))
148+ else scala.collection.Seq .empty[Statement ]
149+ case st : Primitive =>
150+ if (st.name.equalsIgnoreCase(path.head))
151+ Seq (st)
152+ else scala.collection.Seq .empty[Statement ]
153+ }
158154 }
159155 }
160156
@@ -181,21 +177,27 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
181177 )
182178 }
183179
184- val schema = getCobolSchema
180+ val cachedStatement = cacheStatements.get(fieldName)
185181
186- val foundFields = if (fieldName.contains('.' )) {
187- getFieldByPathName(schema, fieldName)
188- } else {
189- getFieldByUniqueName(schema, fieldName)
190- }
182+ if (cachedStatement == null ) {
183+ val schema = getCobolSchema
191184
192- if (foundFields.isEmpty) {
193- throw new IllegalStateException (s " Field ' $fieldName' is not found in the copybook. " )
194- } else if (foundFields.lengthCompare(1 ) == 0 ) {
195- foundFields.head
185+ val foundFields = if (fieldName.contains('.' )) {
186+ getFieldByPathName(schema, fieldName)
187+ } else {
188+ getFieldByUniqueName(schema, fieldName)
189+ }
190+
191+ if (foundFields.isEmpty) {
192+ throw new IllegalStateException (s " Field ' $fieldName' is not found in the copybook. " )
193+ } else if (foundFields.lengthCompare(1 ) == 0 ) {
194+ foundFields.head
195+ } else {
196+ throw new IllegalStateException (s " Multiple fields with name ' $fieldName' found in the copybook. Please specify the exact field using '.' " +
197+ s " notation. " )
198+ }
196199 } else {
197- throw new IllegalStateException (s " Multiple fields with name ' $fieldName' found in the copybook. Please specify the exact field using '.' " +
198- s " notation. " )
200+ cachedStatement
199201 }
200202 }
201203
@@ -344,8 +346,23 @@ class Copybook(val ast: CopybookAST) extends Logging with Serializable {
344346 }
345347 visitGroup(ast)
346348 }
347- }
348349
350+ private def getPrimitiveFieldByName (fieldName : String ): Primitive = {
351+ val cachedPrimitive = cachePrimitives.get(fieldName)
352+
353+ if (cachedPrimitive == null ) {
354+ val ast = getFieldByName(fieldName)
355+ ast match {
356+ case s : Primitive =>
357+ cachePrimitives.put(fieldName, s)
358+ s
359+ case _ => throw new IllegalStateException (s " $fieldName is not a primitive field, cannot extract its value. " )
360+ }
361+ } else {
362+ cachedPrimitive
363+ }
364+ }
365+ }
349366
350367object Copybook {
351368 def merge (copybooks : Seq [Copybook ]): Copybook = {
@@ -443,4 +460,4 @@ object Copybook {
443460 throw new IllegalStateException (s " Cannot set value for field ' ${field.name}' because it does not have an encoder defined. " )
444461 }
445462 }
446- }
463+ }
0 commit comments