Skip to content

Commit 90f867b

Browse files
committed
fix: parameter replacement needs to keep order
1 parent 23fae6e commit 90f867b

File tree

1 file changed

+52
-38
lines changed

1 file changed

+52
-38
lines changed

src/main/scala/dev/mongocamp/driver/mongodb/jdbc/statement/MongoPreparedStatement.scala

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import dev.mongocamp.driver.mongodb.exception.SqlCommandNotSupportedException
66
import dev.mongocamp.driver.mongodb.jdbc.{ MongoJdbcCloseable, MongoJdbcConnection }
77
import dev.mongocamp.driver.mongodb.jdbc.resultSet.MongoDbResultSet
88
import dev.mongocamp.driver.mongodb.sql.MongoSqlQueryHolder
9-
import org.mongodb.scala.bson.collection.immutable.Document
9+
import org.joda.time.DateTime
1010

1111
import java.io.{ InputStream, Reader }
1212
import java.net.URL
@@ -41,9 +41,9 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
4141

4242
private var _queryTimeout: Int = 10
4343
private var _sql: String = null
44-
private var _org_sql: String = null
4544
private var _lastResultSet: ResultSet = null
4645
private var _lastUpdateCount: Int = -1
46+
private lazy val parameters = mutable.Map[Int, String]()
4747

4848
override def execute(sql: String): Boolean = {
4949
checkClosed()
@@ -78,8 +78,7 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
7878
var response = queryHolder.run(connection.getDatabaseProvider).results(getQueryTimeout)
7979
if (response.isEmpty && queryHolder.hasFunctionCallInSelect) {
8080
val emptyDocument = mutable.Map[String, Any]()
81-
queryHolder.getKeysForEmptyDocument.foreach(
82-
key => emptyDocument.put(key, null))
81+
queryHolder.getKeysForEmptyDocument.foreach(key => emptyDocument.put(key, null))
8382
val doc = Converter.toDocument(emptyDocument.toMap)
8483
response = Seq(doc)
8584
}
@@ -102,11 +101,11 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
102101

103102
override def executeQuery(): ResultSet = {
104103
checkClosed()
105-
executeQuery(_sql)
104+
executeQuery(replaceParameters(_sql))
106105
}
107106

108107
override def executeUpdate(): Int = {
109-
executeUpdate(_sql)
108+
executeUpdate(replaceParameters(_sql))
110109
}
111110

112111
override def setNull(parameterIndex: Int, sqlType: Int): Unit = {
@@ -165,6 +164,7 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
165164

166165
override def setBytes(parameterIndex: Int, x: Array[Byte]): Unit = {
167166
checkClosed()
167+
setObject(parameterIndex, x)
168168
}
169169

170170
override def setDate(parameterIndex: Int, x: Date): Unit = {
@@ -194,40 +194,54 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
194194
checkClosed()
195195
}
196196

197-
override def clearParameters(): Unit = {
198-
checkClosed()
199-
_sql = _org_sql
200-
}
201-
202-
override def setObject(parameterIndex: Int, x: Any, targetSqlType: Int): Unit = {
203-
setObject(parameterIndex, x)
204-
}
205-
206-
override def setObject(parameterIndex: Int, x: Any): Unit = {
207-
checkClosed()
197+
private def replaceParameters(sql: String): String = {
208198
var newSql = ""
209-
var paramCount = 0
210-
_org_sql = _sql
211-
_sql.foreach(c => {
199+
var paramCount = 1
200+
sql.foreach(c => {
212201
var replace = false
213202
if (c == '?') {
214-
if (paramCount == parameterIndex) {
203+
if (parameters.contains(paramCount)) {
204+
newSql += parameters(paramCount)
215205
replace = true
216206
}
217207
paramCount += 1
218208
}
219-
if (replace) {
220-
newSql += x.toString
221-
}
222-
else {
209+
if (!replace) {
223210
newSql += c
224211
}
225212
})
226-
_sql = newSql
213+
newSql
214+
}
215+
216+
override def clearParameters(): Unit = {
217+
checkClosed()
218+
parameters.clear()
219+
}
220+
221+
override def setObject(parameterIndex: Int, x: Any, targetSqlType: Int): Unit = {
222+
setObject(parameterIndex, x)
223+
}
224+
225+
override def setObject(parameterIndex: Int, x: Any): Unit = {
226+
checkClosed()
227+
x match {
228+
case d: Date =>
229+
parameters.put(parameterIndex, s"'${d.toInstant.toString}'")
230+
case d: DateTime =>
231+
parameters.put(parameterIndex, s"'${d.toInstant.toString}'")
232+
case t: Time =>
233+
parameters.put(parameterIndex, s"'${t.toInstant.toString}'")
234+
case a: Array[Byte] =>
235+
parameters.put(parameterIndex, a.mkString("[", ",", "]"))
236+
case a: Iterable[_] =>
237+
parameters.put(parameterIndex, a.mkString("[", ",", "]"))
238+
case _ =>
239+
parameters.put(parameterIndex, x.toString)
240+
}
227241
}
228242

229243
override def execute(): Boolean = {
230-
execute(_sql)
244+
execute(replaceParameters(_sql))
231245
}
232246

233247
override def addBatch(): Unit = {
@@ -378,7 +392,7 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
378392
}
379393

380394
override def setMaxRows(max: Int): Unit = {
381-
sqlFeatureNotSupported()
395+
checkClosed()
382396
}
383397

384398
override def setEscapeProcessing(enable: Boolean): Unit = {
@@ -531,23 +545,23 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
531545

532546
override def wasNull(): Boolean = ???
533547

534-
override def getString(parameterIndex: Int): String = ???
548+
override def getString(parameterIndex: Int): String = parameters.get(parameterIndex).orNull
535549

536-
override def getBoolean(parameterIndex: Int): Boolean = ???
550+
override def getBoolean(parameterIndex: Int): Boolean = parameters.get(parameterIndex).flatMap(_.toBooleanOption).getOrElse(false)
537551

538-
override def getByte(parameterIndex: Int): Byte = ???
552+
override def getByte(parameterIndex: Int): Byte = parameters.get(parameterIndex).flatMap(_.toByteOption).getOrElse(0)
539553

540-
override def getShort(parameterIndex: Int): Short = ???
554+
override def getShort(parameterIndex: Int): Short = parameters.get(parameterIndex).flatMap(_.toShortOption).getOrElse(0)
541555

542-
override def getInt(parameterIndex: Int): Int = ???
556+
override def getInt(parameterIndex: Int): Int = parameters.get(parameterIndex).flatMap(_.toIntOption).getOrElse(0)
543557

544-
override def getLong(parameterIndex: Int): Long = ???
558+
override def getLong(parameterIndex: Int): Long = parameters.get(parameterIndex).flatMap(_.toLongOption).getOrElse(0)
545559

546-
override def getFloat(parameterIndex: Int): Float = ???
560+
override def getFloat(parameterIndex: Int): Float = parameters.get(parameterIndex).flatMap(_.toFloatOption).getOrElse(0.0.toFloat)
547561

548-
override def getDouble(parameterIndex: Int): Double = ???
562+
override def getDouble(parameterIndex: Int): Double = parameters.get(parameterIndex).flatMap(_.toDoubleOption).getOrElse(0.0)
549563

550-
override def getBigDecimal(parameterIndex: Int, scale: Int): java.math.BigDecimal = ???
564+
override def getBigDecimal(parameterIndex: Int, scale: Int): java.math.BigDecimal = getBigDecimal(parameterIndex)
551565

552566
override def getBytes(parameterIndex: Int): Array[Byte] = ???
553567

@@ -559,7 +573,7 @@ case class MongoPreparedStatement(connection: MongoJdbcConnection) extends Calla
559573

560574
override def getObject(parameterIndex: Int): AnyRef = ???
561575

562-
override def getBigDecimal(parameterIndex: Int): java.math.BigDecimal = ???
576+
override def getBigDecimal(parameterIndex: Int): java.math.BigDecimal = parameters.get(parameterIndex).flatMap(_.toDoubleOption).map(new java.math.BigDecimal(_)).orNull
563577

564578
override def getObject(parameterIndex: Int, map: util.Map[String, Class[_]]): AnyRef = ???
565579

0 commit comments

Comments
 (0)