Skip to content

Commit 0b3bb93

Browse files
authored
Fix the source task generation (#98)
1 parent f3d2805 commit 0b3bb93

File tree

7 files changed

+380
-126
lines changed

7 files changed

+380
-126
lines changed

src/main/scala/com/sap/kafka/connect/source/GenericSourceTask.scala

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,15 @@ import com.sap.kafka.connect.source.querier.{BulkTableQuerier, IncrColTableQueri
99
import com.sap.kafka.utils.ExecuteWithExceptions
1010
import org.apache.kafka.common.config.ConfigException
1111
import org.apache.kafka.common.utils.{SystemTime, Time}
12-
import org.apache.kafka.connect.errors.ConnectException
1312
import org.apache.kafka.connect.source.{SourceRecord, SourceTask}
1413
import org.slf4j.LoggerFactory
1514

1615
import scala.collection.JavaConverters._
16+
1717
import scala.collection.mutable
1818

1919
abstract class GenericSourceTask extends SourceTask {
20+
protected var configRawProperties: Option[util.Map[String, String]] = None
2021
protected var config: BaseConfig = _
2122
private val tableQueue = new mutable.Queue[TableQuerier]()
2223
protected var time: Time = new SystemTime()
@@ -28,6 +29,7 @@ abstract class GenericSourceTask extends SourceTask {
2829

2930
override def start(props: util.Map[String, String]): Unit = {
3031
log.info("Read records from HANA")
32+
configRawProperties = Some(props)
3133

3234
ExecuteWithExceptions[Unit, ConfigException, HANAConfigMissingException] (
3335
new HANAConfigMissingException("Couldn't start HANASourceTask due to configuration error")) { () =>
@@ -40,54 +42,31 @@ abstract class GenericSourceTask extends SourceTask {
4042

4143
val topics = config.topics
4244

43-
var tables: List[(String, String)] = Nil
44-
if (topics.forall(topic => config.topicProperties(topic).keySet.contains("table.name"))) {
45-
tables = topics.map(topic =>
46-
(config.topicProperties(topic)("table.name"), topic))
47-
}
48-
49-
var query: List[(String, String)] = Nil
50-
if (topics.forall(topic => config.topicProperties(topic).keySet.contains("query"))) {
51-
query = topics.map(topic =>
52-
(config.topicProperties(topic)("query"), topic))
53-
}
54-
55-
if (tables.isEmpty && query.isEmpty) {
56-
throw new ConnectException("Invalid configuration: each HANASourceTask must have" +
57-
" one table assigned to it")
58-
}
59-
6045
val queryMode = config.queryMode
61-
62-
val tableOrQueryInfos = queryMode match {
63-
case BaseConfigConstants.QUERY_MODE_TABLE =>
64-
getTables(tables)
65-
case BaseConfigConstants.QUERY_MODE_SQL =>
66-
getQueries(query)
67-
}
46+
val tableOrQueryInfos = getTableOrQueryInfos()
6847

6948
val mode = config.mode
7049
var offsets: util.Map[util.Map[String, String], util.Map[String, Object]] = null
7150
var incrementingCols: List[String] = List()
7251

7352
if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
7453
val partitions =
75-
new util.ArrayList[util.Map[String, String]](tables.length)
54+
new util.ArrayList[util.Map[String, String]](tableOrQueryInfos.length)
7655

7756
queryMode match {
7857
case BaseConfigConstants.QUERY_MODE_TABLE =>
7958
tableOrQueryInfos.foreach(tableInfo => {
8059
val partition = new util.HashMap[String, String]()
81-
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableInfo._3)
60+
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableInfo._1}${tableInfo._2}")
8261
partitions.add(partition)
83-
incrementingCols :+= config.topicProperties(tableInfo._4)("incrementing.column.name")
62+
incrementingCols :+= config.topicProperties(tableInfo._3)("incrementing.column.name")
8463
})
8564
case BaseConfigConstants.QUERY_MODE_SQL =>
8665
tableOrQueryInfos.foreach(queryInfo => {
8766
val partition = new util.HashMap[String, String]()
8867
partition.put(SourceConnectorConstants.QUERY_NAME_KEY, queryInfo._1)
8968
partitions.add(partition)
90-
incrementingCols :+= config.topicProperties(queryInfo._4)("incrementing.column.name")
69+
incrementingCols :+= config.topicProperties(queryInfo._3)("incrementing.column.name")
9170
})
9271

9372
}
@@ -99,7 +78,7 @@ abstract class GenericSourceTask extends SourceTask {
9978
val partition = new util.HashMap[String, String]()
10079
queryMode match {
10180
case BaseConfigConstants.QUERY_MODE_TABLE =>
102-
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, tableOrQueryInfo._3)
81+
partition.put(SourceConnectorConstants.TABLE_NAME_KEY, s"${tableOrQueryInfo._1}${tableOrQueryInfo._2}")
10382
case BaseConfigConstants.QUERY_MODE_SQL =>
10483
partition.put(SourceConnectorConstants.QUERY_NAME_KEY, tableOrQueryInfo._1)
10584
case _ =>
@@ -108,13 +87,11 @@ abstract class GenericSourceTask extends SourceTask {
10887

10988
val offset = if (offsets == null) null else offsets.get(partition)
11089

111-
val topic = tableOrQueryInfo._4
112-
11390
if (mode.equals(BaseConfigConstants.MODE_BULK)) {
114-
tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
91+
tableQueue += new BulkTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
11592
config, Some(jdbcClient))
11693
} else if (mode.equals(BaseConfigConstants.MODE_INCREMENTING)) {
117-
tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, topic,
94+
tableQueue += new IncrColTableQuerier(queryMode, tableOrQueryInfo._1, tableOrQueryInfo._2, tableOrQueryInfo._3,
11895
incrementingCols(count),
11996
if (offset == null) null else offset.asScala.toMap,
12097
config, Some(jdbcClient))
@@ -182,11 +159,14 @@ abstract class GenericSourceTask extends SourceTask {
182159
null
183160
}
184161

185-
protected def getTables(tables: List[Tuple2[String, String]])
186-
: List[Tuple4[String, Int, String, String]]
187-
188-
protected def getQueries(query: List[(String, String)])
189-
: List[Tuple4[String, Int, String, String]]
162+
def getTableOrQueryInfos(): List[Tuple3[String, Int, String]] = {
163+
val props = configRawProperties.get
164+
props.asScala.filter(p => p._1.startsWith("_tqinfos.") && p._1.endsWith(".name")).map(
165+
t => Tuple3(
166+
t._2,
167+
props.get(t._1.replace("name", "partition")).toInt,
168+
props.get(t._1.replace("name", "topic")))).toList
169+
}
190170

191171
protected def createJdbcClient(): HANAJdbcClient
192172
}
Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,60 @@
11
package com.sap.kafka.connect.source.hana
22

3+
import com.sap.kafka.client.hana.{HANAConfigInvalidInputException, HANAConfigMissingException, HANAJdbcClient}
4+
import com.sap.kafka.connect.config.BaseConfigConstants
5+
import com.sap.kafka.connect.config.hana.{HANAConfig, HANAParameters}
6+
import com.sap.kafka.utils.ExecuteWithExceptions
7+
38
import java.util
4-
import org.apache.kafka.common.config.ConfigDef
9+
import org.apache.kafka.common.config.{ConfigDef, ConfigException}
510
import org.apache.kafka.connect.connector.Task
6-
import org.apache.kafka.connect.source.SourceConnector
11+
import org.apache.kafka.connect.errors.ConnectException
12+
import org.apache.kafka.connect.source.{SourceConnector, SourceConnectorContext}
713

814
import scala.collection.JavaConverters._
915

1016
class HANASourceConnector extends SourceConnector {
11-
private var configProperties: Option[util.Map[String, String]] = None
12-
17+
private var configRawProperties: Option[util.Map[String, String]] = None
18+
private var hanaClient: HANAJdbcClient = _
19+
private var tableOrQueryInfos: List[Tuple3[String, Int, String]] = _
20+
private var configProperties: HANAConfig = _
21+
override def context(): SourceConnectorContext = super.context()
1322
override def version(): String = getClass.getPackage.getImplementationVersion
1423

1524
override def start(properties: util.Map[String, String]): Unit = {
16-
configProperties = Some(properties)
25+
configRawProperties = Some(properties)
26+
configProperties = HANAParameters.getConfig(properties)
27+
hanaClient = new HANAJdbcClient(configProperties)
28+
29+
val topics = configProperties.topics
30+
var tables: List[(String, String)] = Nil
31+
if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("table.name"))) {
32+
tables = topics.map(topic =>
33+
(configProperties.topicProperties(topic)("table.name"), topic))
34+
}
35+
var query: List[(String, String)] = Nil
36+
if (topics.forall(topic => configProperties.topicProperties(topic).keySet.contains("query"))) {
37+
query = topics.map(topic =>
38+
(configProperties.topicProperties(topic)("query"), topic))
39+
}
40+
41+
if (tables.isEmpty && query.isEmpty) {
42+
throw new ConnectException("Invalid configuration: each HANAConnector must have one table or query associated")
43+
}
44+
45+
tableOrQueryInfos = configProperties.queryMode match {
46+
case BaseConfigConstants.QUERY_MODE_TABLE =>
47+
getTables(hanaClient, tables)
48+
case BaseConfigConstants.QUERY_MODE_SQL =>
49+
getQueries(query)
50+
}
1751
}
1852

1953
override def taskClass(): Class[_ <: Task] = classOf[HANASourceTask]
2054

2155
override def taskConfigs(maxTasks: Int): util.List[util.Map[String, String]] = {
22-
(1 to maxTasks).map(c => configProperties.get).toList.asJava
56+
val tableOrQueryGroups = createTableOrQueryGroups(tableOrQueryInfos, maxTasks)
57+
createTaskConfigs(tableOrQueryGroups, configRawProperties.get).asJava
2358
}
2459

2560
override def stop(): Unit = {
@@ -29,4 +64,92 @@ class HANASourceConnector extends SourceConnector {
2964
override def config(): ConfigDef = {
3065
new ConfigDef
3166
}
67+
68+
private def getTables(hanaClient: HANAJdbcClient, tables: List[Tuple2[String, String]]) : List[Tuple3[String, Int, String]] = {
69+
val connection = hanaClient.getConnection
70+
71+
// contains fullTableName, partitionNum, topicName
72+
var tableInfos: List[Tuple3[String, Int, String]] = List()
73+
val noOfTables = tables.size
74+
var tablecount = 1
75+
76+
var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE "
77+
tables.foreach(table => {
78+
if (!(configProperties.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) {
79+
table._1 match {
80+
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) =>
81+
stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')"
82+
83+
if (tablecount < noOfTables) {
84+
stmtToFetchPartitions += " OR "
85+
}
86+
tablecount = tablecount + 1
87+
case _ =>
88+
throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions")
89+
}
90+
}
91+
})
92+
93+
if (tablecount > 1) {
94+
val stmt = connection.createStatement()
95+
val partitionRs = stmt.executeQuery(stmtToFetchPartitions)
96+
97+
while (partitionRs.next()) {
98+
val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\""
99+
tableInfos :+= Tuple3(tableName, partitionRs.getInt(3),
100+
tables.filter(table => table._1 == tableName).map(table => table._2).head.toString)
101+
}
102+
}
103+
104+
// fill tableInfo for tables whose entry is not in M_CS_PARTITIONS
105+
val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1)
106+
val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1))
107+
108+
tablesToBeAdded.foreach(tableToBeAdded => {
109+
if (configProperties.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) {
110+
tableInfos :+= Tuple3(getTableName(tableToBeAdded._1)._2, 0, tableToBeAdded._2)
111+
} else {
112+
tableInfos :+= Tuple3(tableToBeAdded._1, 0, tableToBeAdded._2)
113+
}
114+
})
115+
116+
tableInfos
117+
}
118+
119+
private def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String)] =
120+
queryTuple.map(query => (query._1, 0, query._2))
121+
122+
private def createTableOrQueryGroups(tableOrQueryInfos: List[Tuple3[String, Int, String]], count: Int)
123+
: List[List[Tuple3[String, Int, String]]] = {
124+
val groupSize = count match {
125+
case c if c > tableOrQueryInfos.size => 1
126+
case _ => ((tableOrQueryInfos.size + count - 1) / count)
127+
}
128+
tableOrQueryInfos.grouped(groupSize).toList
129+
}
130+
131+
private def createTaskConfigs(tableOrQueryGroups: List[List[Tuple3[String, Int, String]]], config: java.util.Map[String, String])
132+
: List[java.util.Map[String, String]] = {
133+
tableOrQueryGroups.map(g => {
134+
var gconfig = new java.util.HashMap[String,String](config)
135+
for ((t, i) <- g.zipWithIndex) {
136+
gconfig.put(s"_tqinfos.$i.name", t._1)
137+
gconfig.put(s"_tqinfos.$i.partition", t._2.toString)
138+
gconfig.put(s"_tqinfos.$i.topic", t._3)
139+
}
140+
gconfig
141+
})
142+
}
143+
144+
private def getTableName(tableName: String): (Option[String], String) = {
145+
tableName match {
146+
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) =>
147+
(Some(schema), table)
148+
case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) =>
149+
(None, table)
150+
case _ =>
151+
throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." +
152+
s" Does not follow naming conventions")
153+
}
154+
}
32155
}

src/main/scala/com/sap/kafka/connect/source/hana/HANASourceTask.scala

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,79 +16,11 @@ class HANASourceTask extends GenericSourceTask {
1616

1717
override def version(): String = getClass.getPackage.getImplementationVersion
1818

19-
2019
override def createJdbcClient(): HANAJdbcClient = {
2120
config match {
2221
case hanaConfig: HANAConfig => new HANAJdbcClient(hanaConfig)
2322
case _ => throw new RuntimeException("Cannot create HANA Jdbc Client")
2423
}
2524
}
2625

27-
override def getTables(tables: List[Tuple2[String, String]])
28-
: List[Tuple4[String, Int, String, String]] = {
29-
val connection = jdbcClient.getConnection
30-
31-
// contains fullTableName, partitionNum, fullTableName + partitionNum, topicName
32-
var tableInfos: List[Tuple4[String, Int, String, String]] = List()
33-
val noOfTables = tables.size
34-
var tablecount = 1
35-
36-
var stmtToFetchPartitions = s"SELECT SCHEMA_NAME, TABLE_NAME, PARTITION FROM SYS.M_CS_PARTITIONS WHERE "
37-
tables.foreach(table => {
38-
if (!(config.topicProperties(table._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE)) {
39-
table._1 match {
40-
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, tablename) =>
41-
stmtToFetchPartitions += s"(SCHEMA_NAME = '$schema' AND TABLE_NAME = '$tablename')"
42-
43-
if (tablecount < noOfTables) {
44-
stmtToFetchPartitions += " OR "
45-
}
46-
tablecount = tablecount + 1
47-
case _ =>
48-
throw new HANAConfigInvalidInputException("The table name is invalid. Does not follow naming conventions")
49-
}
50-
}
51-
})
52-
53-
if (tablecount > 1) {
54-
val stmt = connection.createStatement()
55-
val partitionRs = stmt.executeQuery(stmtToFetchPartitions)
56-
57-
while (partitionRs.next()) {
58-
val tableName = "\"" + partitionRs.getString(1) + "\".\"" + partitionRs.getString(2) + "\""
59-
tableInfos :+= Tuple4(tableName, partitionRs.getInt(3), tableName + partitionRs.getInt(3),
60-
tables.filter(table => table._1 == tableName)
61-
.map(table => table._2).head.toString)
62-
}
63-
}
64-
65-
// fill tableInfo for tables whose entry is not in M_CS_PARTITIONS
66-
val tablesInInfo = tableInfos.map(tableInfo => tableInfo._1)
67-
val tablesToBeAdded = tables.filterNot(table => tablesInInfo.contains(table._1))
68-
69-
tablesToBeAdded.foreach(tableToBeAdded => {
70-
if (config.topicProperties(tableToBeAdded._2)("table.type") == BaseConfigConstants.COLLECTION_TABLE_TYPE) {
71-
tableInfos :+= Tuple4(getTableName(tableToBeAdded._1)._2, 0, getTableName(tableToBeAdded._1)._2 + "0", tableToBeAdded._2)
72-
} else {
73-
tableInfos :+= Tuple4(tableToBeAdded._1, 0, tableToBeAdded._1 + "0", tableToBeAdded._2)
74-
}
75-
})
76-
77-
tableInfos
78-
}
79-
80-
override protected def getQueries(queryTuple: List[(String, String)]): List[(String, Int, String, String)] =
81-
queryTuple.map(query => (query._1, 0, null, query._2))
82-
83-
private def getTableName(tableName: String): (Option[String], String) = {
84-
tableName match {
85-
case BaseConfigConstants.TABLE_NAME_FORMAT(schema, table) =>
86-
(Some(schema), table)
87-
case BaseConfigConstants.COLLECTION_NAME_FORMAT(table) =>
88-
(None, table)
89-
case _ =>
90-
throw new HANAConfigInvalidInputException(s"The table name mentioned in `{topic}.table.name` is invalid." +
91-
s" Does not follow naming conventions")
92-
}
93-
}
9426
}

src/test/scala/com/sap/kafka/connect/source/HANASourceTaskConversionTest.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package com.sap.kafka.connect.source
22

33
import com.sap.kafka.client.MetaSchema
4+
import com.sap.kafka.connect.source.hana.HANASourceConnector
45
import org.apache.kafka.connect.data.Schema.Type
56
import org.apache.kafka.connect.data.{Field, Schema, Struct}
67
import org.apache.kafka.connect.source.SourceRecord
@@ -11,11 +12,14 @@ class HANASourceTaskConversionTest extends HANASourceTaskTestBase {
1112

1213
override def beforeAll(): Unit = {
1314
super.beforeAll()
14-
task.start(singleTableConfig())
15+
connector = new HANASourceConnector
16+
connector.start(singleTableConfig())
17+
task.start(connector.taskConfigs(1).get(0))
1518
}
1619

1720
override def afterAll(): Unit = {
1821
task.stop()
22+
connector.stop()
1923
super.afterAll()
2024
}
2125

0 commit comments

Comments
 (0)