Skip to content

Commit 7c50190

Browse files
committed
#62: Make composition of SQL expression easier and safer
* Introduced a new class `SqlEntry` which represents a part of and SQL expression * `SqlEntry` brings a number of methods and helper classes * `SqlEntryComposition` offers a way to compose a full SQL expression as `SqlEntry` in SQL format (almost) * `SqlEntry` is now leveraged in `SqlItem`, `DBFunction` and `DBTable`
1 parent b73ff73 commit 7c50190

File tree

11 files changed

+424
-68
lines changed

11 files changed

+424
-68
lines changed

balta/src/main/scala/za/co/absa/db/balta/classes/DBFunction.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ package za.co.absa.db.balta.classes
1919
import za.co.absa.db.balta.classes.DBFunction.{DBFunctionWithNamedParamsToo, DBFunctionWithPositionedParamsOnly}
2020
import za.co.absa.db.balta.typeclasses.QueryParamType
2121
import za.co.absa.db.balta.classes.inner.Params.{NamedParams, OrderedParams}
22+
import za.co.absa.db.mag.core.SqlEntry
23+
import za.co.absa.db.mag.core.SqlEntryComposition._
2224

2325
/**
2426
* A class that represents a database function call. It can be used to execute a function and verify the result.
@@ -32,18 +34,18 @@ import za.co.absa.db.balta.classes.inner.Params.{NamedParams, OrderedParams}
3234
* @param namedParams - the list of parameters identified by their name (following the positioned parameters)
3335
*
3436
*/
35-
sealed abstract class DBFunction private(functionName: String,
37+
sealed abstract class DBFunction private(functionName: SqlEntry,
3638
orderedParams: OrderedParams,
3739
namedParams: NamedParams) extends DBQuerySupport {
3840

39-
private def sql(orderBy: String): String = {
41+
private def sql(orderBy: Option[SqlEntry]): SqlEntry = {
4042
val positionedParamEntries = orderedParams.values.map(_.sqlEntry)
4143
val namedParamEntries = namedParams.items.map{ case (columnName, queryParamValue) =>
42-
columnName.sqlEntry + " := " + queryParamValue.sqlEntry
44+
columnName.sqlEntry := queryParamValue.sqlEntry
4345
}
4446
val paramEntries = positionedParamEntries ++ namedParamEntries
4547
val paramsLine = paramEntries.mkString(",")
46-
s"SELECT * FROM $functionName($paramsLine) $orderBy"
48+
SELECT(ALL) FROM functionName(paramsLine) ORDER BY (orderBy)
4749
}
4850

4951
/**
@@ -92,7 +94,7 @@ sealed abstract class DBFunction private(functionName: String,
9294
* @return - the result of the verify function
9395
*/
9496
def execute[R](orderBy: String)(verify: QueryResult => R /* Assertion */)(implicit connection: DBConnection): R = {
95-
val orderByPart = if (orderBy.nonEmpty) {s"ORDER BY $orderBy"} else ""
97+
val orderByPart = SqlEntry(orderBy).toOption
9698
runQuery(sql(orderByPart), orderedParams.values ++ namedParams.values)(verify)
9799
}
98100

@@ -140,15 +142,15 @@ object DBFunction {
140142
* @return - a new instance of the DBFunction class
141143
*/
142144
def apply(functionName: String): DBFunctionWithPositionedParamsOnly = {
143-
DBFunctionWithPositionedParamsOnly(functionName)
145+
DBFunctionWithPositionedParamsOnly(SqlEntry(functionName))
144146
}
145147

146148
def apply(functionName: String, params: NamedParams): DBFunctionWithNamedParamsToo = {
147-
DBFunctionWithNamedParamsToo(functionName, OrderedParams(), params)
149+
DBFunctionWithNamedParamsToo(SqlEntry(functionName), OrderedParams(), params)
148150
}
149151

150152
def apply(functionName: String, params: OrderedParams): DBFunctionWithPositionedParamsOnly = {
151-
DBFunctionWithPositionedParamsOnly(functionName, params, NamedParams())
153+
DBFunctionWithPositionedParamsOnly(SqlEntry(functionName), params, NamedParams())
152154
}
153155

154156
/**
@@ -159,7 +161,7 @@ object DBFunction {
159161
* @param orderedParams - the list of parameters identified by their position (preceding the named parameters)
160162
* @param namedParams - the list of parameters identified by their name (following the positioned parameters)
161163
*/
162-
sealed case class DBFunctionWithPositionedParamsOnly private(functionName: String,
164+
sealed case class DBFunctionWithPositionedParamsOnly private(functionName: SqlEntry,
163165
orderedParams: OrderedParams = OrderedParams(),
164166
namedParams: NamedParams = NamedParams()
165167
) extends DBFunction(functionName, orderedParams, namedParams) {
@@ -195,7 +197,7 @@ object DBFunction {
195197
* @param orderedParams - the list of parameters identified by their position (preceding the named parameters)
196198
* @param namedParams - the list of parameters identified by their name (following the positioned parameters)
197199
*/
198-
sealed case class DBFunctionWithNamedParamsToo private(functionName: String,
200+
sealed case class DBFunctionWithNamedParamsToo private(functionName: SqlEntry,
199201
orderedParams: OrderedParams = OrderedParams(),
200202
namedParams: NamedParams = NamedParams()
201203
) extends DBFunction(functionName, orderedParams, namedParams)

balta/src/main/scala/za/co/absa/db/balta/classes/DBQuerySupport.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,17 @@
1717
package za.co.absa.db.balta.classes
1818

1919
import za.co.absa.db.balta.typeclasses.QueryParamValue
20+
import za.co.absa.db.mag.core.SqlEntry
2021

2122
/**
2223
* This is a based trait providing the ability to run an SQL query and verify the result via a provided function.
2324
*/
2425
trait DBQuerySupport {
2526

26-
protected def runQuery[R](sql: String, queryValues: Vector[QueryParamValue])
27+
protected def runQuery[R](sql: SqlEntry, queryValues: Vector[QueryParamValue])
2728
(verify: QueryResult => R /* Assertion */)
2829
(implicit connection: DBConnection): R = {
29-
val preparedStatement = connection.connection.prepareStatement(sql)
30+
val preparedStatement = connection.connection.prepareStatement(sql.entry)
3031

3132
queryValues.foldLeft(1) { case (parameterIndex, queryValue) =>
3233
queryValue.assign match { // this is better readable-wise than map + getOrElse

balta/src/main/scala/za/co/absa/db/balta/classes/DBTable.scala

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,18 @@ package za.co.absa.db.balta.classes
1818

1919
import za.co.absa.db.balta.classes.inner.Params
2020
import za.co.absa.db.balta.classes.inner.Params.NamedParams
21-
import za.co.absa.db.balta.typeclasses.{QueryParamValue, QueryParamType}
21+
import za.co.absa.db.balta.typeclasses.{QueryParamType, QueryParamValue}
2222
import za.co.absa.db.balta.classes.inner.Params.OrderedParams
23+
import za.co.absa.db.mag.core.SqlEntry
24+
import za.co.absa.db.mag.core.SqlEntryComposition._
2325

2426
/**
2527
* This class represents a database table. It allows to perform INSERT, SELECT and COUNT operations on the table easily.
2628
*
2729
* @param tableName The name of the table
2830
*/
2931
case class DBTable(tableName: String) extends DBQuerySupport{
30-
32+
private val table: SqlEntry = SqlEntry(tableName)
3133
/**
3234
* Inserts a new row into the table.
3335
*
@@ -42,7 +44,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
4244
}
4345

4446
val paramStr = values.values.map(_.sqlEntry).mkString(",")
45-
val sql = s"INSERT INTO $tableName $columns VALUES($paramStr) RETURNING *;"
47+
val sql = INSERT INTO table(columns) VALUES(paramStr) RETURNING ALL
4648
runQuery(sql, values.values){_.next()}
4749
}
4850

@@ -78,7 +80,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
7880
* @return - the result of the verify function
7981
*/
8082
def where[R](params: NamedParams)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
81-
composeSelectAndRun(strToOption(paramsToWhereCondition(params)), None, params.values)(verify)
83+
composeSelectAndRun(paramsToWhereCondition(params).toOption, None, params.values)(verify)
8284
}
8385

8486
/**
@@ -91,7 +93,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
9193
* @return - the result of the verify function
9294
*/
9395
def where[R](params: NamedParams, orderBy: String)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
94-
composeSelectAndRun(strToOption(paramsToWhereCondition(params)), strToOption(orderBy), params.values)(verify)
96+
composeSelectAndRun(paramsToWhereCondition(params).toOption, SqlEntry(orderBy).toOption, params.values)(verify)
9597
}
9698

9799
/**
@@ -103,7 +105,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
103105
* @return - the result of the verify function
104106
*/
105107
def where[R](condition: String)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
106-
composeSelectAndRun(strToOption(condition), None)(verify)
108+
composeSelectAndRun(SqlEntry(condition).toOption, None)(verify)
107109
}
108110

109111
/**
@@ -116,7 +118,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
116118
* @return - the result of the verify function
117119
*/
118120
def where[R](condition: String, orderBy: String)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
119-
composeSelectAndRun(strToOption(condition), strToOption(orderBy))(verify)
121+
composeSelectAndRun(SqlEntry(condition).toOption, SqlEntry(orderBy).toOption)(verify)
120122
}
121123

122124
/**
@@ -141,27 +143,27 @@ case class DBTable(tableName: String) extends DBQuerySupport{
141143
* @return - the result of the verify function
142144
*/
143145
def all[R](orderBy: String)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
144-
composeSelectAndRun(None, strToOption(orderBy))(verify)
146+
composeSelectAndRun(None, SqlEntry(orderBy).toOption)(verify)
145147
}
146148

147149
def deleteWithCheck[R](verify: QueryResult => R)(implicit connection: DBConnection): R = {
148150
composeDeleteAndRun(None)(verify)
149151
}
150152

151153
def deleteWithCheck[R](whereParams: NamedParams)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
152-
composeDeleteAndRun(strToOption(paramsToWhereCondition(whereParams)), whereParams.values)(verify)
154+
composeDeleteAndRun(paramsToWhereCondition(whereParams).toOption, whereParams.values)(verify)
153155
}
154156

155157
def deleteWithCheck[R](whereCondition: String)(verify: QueryResult => R)(implicit connection: DBConnection): R = {
156-
composeDeleteAndRun(strToOption(whereCondition))(verify)
158+
composeDeleteAndRun(SqlEntry(whereCondition).toOption)(verify)
157159
}
158160

159161
def delete(whereParams: NamedParams)(implicit connection: DBConnection): Unit = {
160-
composeDeleteAndRun(strToOption(paramsToWhereCondition(whereParams)), whereParams.values)(_ => ())
162+
composeDeleteAndRun(paramsToWhereCondition(whereParams).toOption, whereParams.values)(_ => ())
161163
}
162164

163165
def delete(whereCondition: String = "")(implicit connection: DBConnection): Unit = {
164-
composeDeleteAndRun(strToOption(whereCondition))(_ => ())
166+
composeDeleteAndRun(SqlEntry(whereCondition).toOption)(_ => ())
165167
}
166168
/**
167169
* Counts the rows in the table.
@@ -180,7 +182,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
180182
*/
181183
@deprecated("Use countOnCondition instead", "0.2.0")
182184
def count(params: NamedParams)(implicit connection: DBConnection): Long = {
183-
composeCountAndRun(strToOption(paramsToWhereCondition(params)), params.values)
185+
composeCountAndRun(paramsToWhereCondition(params).toOption, params.values)
184186
}
185187

186188
/**
@@ -191,7 +193,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
191193
*/
192194
@deprecated("Use countOnCondition instead", "0.2.0")
193195
def count(condition: String)(implicit connection: DBConnection): Long = {
194-
composeCountAndRun(strToOption(condition))
196+
composeCountAndRun(SqlEntry(condition).toOption)
195197
}
196198

197199
/**
@@ -201,7 +203,7 @@ case class DBTable(tableName: String) extends DBQuerySupport{
201203
* @return - the number of rows
202204
*/
203205
def countOnCondition(params: NamedParams)(implicit connection: DBConnection): Long = {
204-
composeCountAndRun(strToOption(paramsToWhereCondition(params)), params.values)
206+
composeCountAndRun(paramsToWhereCondition(params).toOption, params.values)
205207
}
206208

207209
/**
@@ -211,47 +213,36 @@ case class DBTable(tableName: String) extends DBQuerySupport{
211213
* @return - the number of rows
212214
*/
213215
def countOnCondition(condition: String)(implicit connection: DBConnection): Long = {
214-
composeCountAndRun(strToOption(condition))
216+
composeCountAndRun(SqlEntry(condition).toOption)
215217
}
216218

217-
private def composeSelectAndRun[R](whereCondition: Option[String], orderByExpr: Option[String], values: Vector[QueryParamValue] = Vector.empty)
219+
private def composeSelectAndRun[R](whereCondition: Option[SqlEntry], orderBy: Option[SqlEntry], values: Vector[QueryParamValue] = Vector.empty)
218220
(verify: QueryResult => R)
219221
(implicit connection: DBConnection): R = {
220-
val where = whereCondition.map("WHERE " + _).getOrElse("")
221-
val orderBy = orderByExpr.map("ORDER BY " + _).getOrElse("")
222-
val sql = s"SELECT * FROM $tableName $where $orderBy;"
222+
val sql = SELECT(ALL) FROM table WHERE whereCondition ORDER BY(orderBy)
223223
runQuery(sql, values)(verify)
224224
}
225225

226-
private def composeDeleteAndRun[R](whereCondition: Option[String], values: Vector[QueryParamValue] = Vector.empty)
226+
private def composeDeleteAndRun[R](whereCondition: Option[SqlEntry], values: Vector[QueryParamValue] = Vector.empty)
227227
(verify: QueryResult => R)
228228
(implicit connection: DBConnection): R = {
229-
val where = whereCondition.map("WHERE " + _).getOrElse("")
230-
val sql = s"DELETE FROM $tableName $where RETURNING *;"
229+
val sql = DELETE FROM table WHERE whereCondition RETURNING ALL
231230
runQuery(sql, values)(verify)
232231
}
233232

234-
private def composeCountAndRun(whereCondition: Option[String], values: Vector[QueryParamValue] = Vector.empty)
233+
private def composeCountAndRun(whereCondition: Option[SqlEntry], values: Vector[QueryParamValue] = Vector.empty)
235234
(implicit connection: DBConnection): Long = {
236-
val where = whereCondition.map("WHERE " + _).getOrElse("")
237-
val sql = s"SELECT count(1) AS cnt FROM $tableName $where;"
235+
val sql = SELECT(COUNT_ALL) FROM table WHERE whereCondition
238236
runQuery(sql, values) {resultSet =>
239237
resultSet.next().getLong("cnt").getOrElse(0)
240238
}
241239
}
242240

243-
private def strToOption(str: String): Option[String] = {
244-
if (str.isEmpty) {
245-
None
246-
} else {
247-
Option(str)
248-
}
249-
}
250-
251-
private def paramsToWhereCondition(params: NamedParams): String = {
252-
params.items.foldRight(List.empty[String]) {case ((columnName, value), acc) =>
253-
val condition = s"${columnName.sqlEntry} ${value.equalityOperator} ${value.sqlEntry}"
241+
private def paramsToWhereCondition(params: NamedParams): SqlEntry = {
242+
val resultList = params.items.foldRight(List.empty[SqlEntry]) {case ((columnName, value), acc) =>
243+
val condition = columnName.sqlEntry + value.equalityOperator + value.sqlEntry
254244
condition :: acc
255-
}.mkString(" AND ")
245+
}
246+
resultList.toSqlEntry(" AND ")
256247
}
257248
}

balta/src/main/scala/za/co/absa/db/balta/typeclasses/QueryParamValue.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,21 @@
1616

1717
package za.co.absa.db.balta.typeclasses
1818

19-
import QueryParamValue.AssignFunc
19+
import QueryParamValue.{AssignFunc, sqlEquals, sqlQuestionMark}
20+
import za.co.absa.db.mag.core.SqlEntry
21+
2022
import java.sql.PreparedStatement
2123

2224
trait QueryParamValue {
2325
def assign: Option[AssignFunc]
24-
def sqlEntry: String = "?"
25-
def equalityOperator: String = "="
26+
def sqlEntry: SqlEntry = sqlQuestionMark
27+
def equalityOperator: SqlEntry = sqlEquals
2628
}
2729

2830
object QueryParamValue {
31+
private val sqlQuestionMark = SqlEntry("?")
32+
private val sqlEquals = SqlEntry("=")
33+
2934
type AssignFunc = (PreparedStatement, Int) => Unit
3035

3136
class ObjectQueryParamValue(obj: Object) extends QueryParamValue {
@@ -41,8 +46,8 @@ object QueryParamValue {
4146

4247
object NullParamValue extends QueryParamValue {
4348
override val assign: Option[AssignFunc] = None
44-
override val sqlEntry: String = "NULL"
45-
override val equalityOperator: String = "IS"
49+
override val sqlEntry: SqlEntry = SqlEntry("NULL")
50+
override val equalityOperator: SqlEntry = SqlEntry("IS")
4651
}
4752

4853
}

balta/src/main/scala/za/co/absa/db/mag/core/ColumnReference.scala

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,21 @@
1616

1717
package za.co.absa.db.mag.core
1818

19-
trait ColumnReference extends SqlItem
20-
21-
abstract class ColumnName extends ColumnReference{
22-
def enteredName: String
23-
def sqlEntry: String
19+
trait ColumnReference extends SqlItem {
2420
override def equals(obj: Any): Boolean = {
2521
obj match {
26-
case that: ColumnName => this.sqlEntry == that.sqlEntry
22+
case that: ColumnReference => this.sqlEntry == that.sqlEntry
2723
case _ => false
2824
}
2925
}
3026
override def hashCode(): Int = sqlEntry.hashCode
3127
}
3228

29+
abstract class ColumnName extends ColumnReference{
30+
def enteredName: String
31+
def sqlEntry: SqlEntry
32+
}
33+
3334
object ColumnReference {
3435
private val regularColumnNamePattern = "^([a-z_][a-z0-9_]*)$".r
3536
private val quotedRegularColumnNamePattern = "^\"([a-z_][a-z0-9_]*)\"$".r
@@ -42,9 +43,9 @@ object ColumnReference {
4243
val trimmedName = name.trim
4344
trimmedName match {
4445
case regularColumnNamePattern(columnName) => ColumnNameSimple(columnName) // column name per SQL standard, no quoting needed
45-
case quotedRegularColumnNamePattern(columnName) => ColumnNameExact(trimmedName, columnName) // quoted but regular name, remove quotes
46+
case quotedRegularColumnNamePattern(columnName) => ColumnNameExact(trimmedName, SqlEntry(columnName)) // quoted but regular name, remove quotes
4647
case quotedColumnNamePattern(_) => ColumnNameSimple(trimmedName) // quoted name, use as is
47-
case _ => ColumnNameExact(trimmedName, quote(escapeQuote(trimmedName))) // needs quoting and perhaps escaping
48+
case _ => ColumnNameExact(trimmedName, SqlEntry(quote(escapeQuote(trimmedName)))) // needs quoting and perhaps escaping
4849
}
4950
}
5051

@@ -55,13 +56,13 @@ object ColumnReference {
5556
def unapply(columnName: ColumnName): String = columnName.enteredName
5657

5758
final case class ColumnNameSimple private(enteredName: String) extends ColumnName {
58-
override def sqlEntry: String = enteredName
59+
override def sqlEntry: SqlEntry = SqlEntry(enteredName)
5960
}
6061

61-
final case class ColumnNameExact private(enteredName: String, sqlEntry: String) extends ColumnName
62+
final case class ColumnNameExact private(enteredName: String, sqlEntry: SqlEntry) extends ColumnName
6263

6364
final case class ColumnIndex private(index: Int) extends ColumnReference {
64-
val sqlEntry: String = index.toString
65+
val sqlEntry: SqlEntry = SqlEntry(index.toString)
6566
}
6667
}
6768

0 commit comments

Comments
 (0)