Skip to content

Commit c734710

Browse files
authored
Fix update count handling for multi-statement queries executed via PreparedStatement. (#2722) (#2737)
* Fix update count handling for multi-statement queries executed via PreparedStatement. * Addressed PR review comments. * Remove redundant exception handling
1 parent 44d6010 commit c734710

File tree

3 files changed

+261
-4
lines changed

3 files changed

+261
-4
lines changed

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,17 @@ boolean onRetValue(TDSReader tdsReader) throws SQLServerException {
788788
return false;
789789
}
790790

791+
/**
792+
* Override TDS token processing behavior for PreparedStatement.
793+
* For regular Statement, the execute API for INSERT requires reading an additional explicit
794+
* TDS_DONE token that contains the actual update count returned by the server.
795+
* PreparedStatement does not require this additional token processing.
796+
*/
797+
@Override
798+
protected boolean hasUpdateCountTDSTokenForInsertCmd() {
799+
return false;
800+
}
801+
791802
/**
792803
* Sends the statement parameters by RPC.
793804
*/

src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,8 +1601,8 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException {
16011601
if (null != procedureName)
16021602
return false;
16031603

1604-
// For Insert, we must fetch additional TDS_DONE token that comes with the actual update count
1605-
if ((StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount())
1604+
// For Insert operations, check if additional TDS_DONE token processing is required.
1605+
if (hasUpdateCountTDSTokenForInsertCmd() && (StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount())
16061606
&& EXECUTE == executeMethod) {
16071607
return true;
16081608
}
@@ -1845,6 +1845,19 @@ boolean consumeExecOutParam(TDSReader tdsReader) throws SQLServerException {
18451845
return false;
18461846
}
18471847

1848+
/**
1849+
* Determines whether to continue processing additional TDS_DONE tokens for INSERT statements.
1850+
* For INSERT operations, regular Statement requires reading an additional TDS_DONE token that contains
1851+
* the actual update count. This method can be overridden by subclasses to customize
1852+
* TDS token processing behavior.
1853+
*
1854+
* @return true to continue processing more tokens to get the actual update count for INSERT operations
1855+
*/
1856+
protected boolean hasUpdateCountTDSTokenForInsertCmd() {
1857+
// For Insert, we must fetch additional TDS_DONE token that comes with the actual update count
1858+
return true;
1859+
}
1860+
18481861
// --------------------------JDBC 2.0-----------------------------
18491862

18501863
@Override

src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java

Lines changed: 235 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2961,7 +2961,7 @@ public void testExecuteInsertManyRowsAndSelect() {
29612961
// no more results
29622962
break;
29632963
} else {
2964-
assertEquals(count, 3, "update count should have been 6");
2964+
assertEquals(count, 3, "update count should have been 3");
29652965
}
29662966
} else {
29672967
// process ResultSet
@@ -2998,7 +2998,7 @@ public void testExecuteTwoInsertsRowsAndSelect() {
29982998
// no more results
29992999
break;
30003000
} else {
3001-
assertEquals(count, 1, "update count should have been 2");
3001+
assertEquals(count, 1, "update count should have been 1");
30023002
}
30033003
} else {
30043004
// process ResultSet
@@ -3091,6 +3091,239 @@ public void testExecuteDelAndSelect() {
30913091
}
30923092
}
30933093

3094+
/**
3095+
* Tests multi-statement PreparedStatement with loop to process all results
3096+
*
3097+
* @throws SQLException
3098+
*/
3099+
@Test
3100+
public void testMultiStatementPreparedStatementLoopResults() throws SQLException {
3101+
try (Connection con = getConnection()) {
3102+
try (PreparedStatement ps = con.prepareStatement("DELETE FROM " + tableName + " " +
3103+
"INSERT INTO " + tableName + " (NAME) VALUES (?) " +
3104+
"INSERT INTO " + tableName + " (NAME) VALUES (?) " +
3105+
"UPDATE " + tableName + " SET NAME = 'updated' " +
3106+
"INSERT INTO " + tableName + " (NAME) VALUES (?) " +
3107+
"INSERT INTO " + tableName + " (NAME) VALUES (?) " +
3108+
"SELECT * FROM " + tableName)) {
3109+
3110+
ps.setString(1, "test1");
3111+
ps.setString(2, "test2");
3112+
ps.setString(3, "test3");
3113+
ps.setString(4, "test4");
3114+
3115+
boolean retval = ps.execute();
3116+
do {
3117+
if (!retval) {
3118+
int count = ps.getUpdateCount();
3119+
if (count == -1) {
3120+
// no more results
3121+
break;
3122+
} else {
3123+
assertTrue(count >= 0, "update count should be non-negative: " + count);
3124+
}
3125+
} else {
3126+
// process ResultSet
3127+
try (ResultSet rs = ps.getResultSet()) {
3128+
int rowCount = 0;
3129+
while (rs.next()) {
3130+
String name = rs.getString("NAME");
3131+
assertTrue(name != null, "name should not be null");
3132+
rowCount++;
3133+
}
3134+
assertEquals(4, rowCount, "Expected 4 rows in result set");
3135+
}
3136+
}
3137+
retval = ps.getMoreResults();
3138+
} while (true);
3139+
}
3140+
}
3141+
}
3142+
3143+
/**
3144+
* Tests PreparedStatement execute for Insert followed by select
3145+
*
3146+
* @throws SQLException
3147+
*/
3148+
@Test
3149+
public void testPreparedStatementExecuteInsertAndSelect() throws SQLException {
3150+
try (Connection con = getConnection()) {
3151+
String sql = "INSERT INTO " + tableName + " (NAME) VALUES(?) " +
3152+
"SELECT NAME FROM " + tableName + " WHERE ID = 1";
3153+
try (PreparedStatement ps = con.prepareStatement(sql)) {
3154+
ps.setString(1, "test");
3155+
boolean retval = ps.execute();
3156+
do {
3157+
if (!retval) {
3158+
int count = ps.getUpdateCount();
3159+
if (count == -1) {
3160+
// no more results
3161+
break;
3162+
} else {
3163+
assertEquals(count, 1, "update count should have been 1");
3164+
}
3165+
} else {
3166+
// process ResultSet
3167+
try (ResultSet rs = ps.getResultSet()) {
3168+
if (rs.next()) {
3169+
String val = rs.getString(1);
3170+
assertEquals(val, "test", "read value should have been 'test'");
3171+
}
3172+
}
3173+
}
3174+
retval = ps.getMoreResults();
3175+
} while (true);
3176+
}
3177+
}
3178+
}
3179+
3180+
/**
3181+
* Tests PreparedStatement execute for Merge followed by select
3182+
*
3183+
* @throws SQLException
3184+
*/
3185+
@Test
3186+
public void testPreparedStatementExecuteMergeAndSelect() throws SQLException {
3187+
try (Connection con = getConnection()) {
3188+
String sql = "MERGE INTO " + tableName + " AS target " +
3189+
"USING (VALUES (?)) AS source (name) " +
3190+
"ON target.name = source.name " +
3191+
"WHEN NOT MATCHED THEN INSERT (name) VALUES (?); " +
3192+
"SELECT NAME FROM " + tableName + " WHERE ID = 1";
3193+
try (PreparedStatement ps = con.prepareStatement(sql)) {
3194+
ps.setString(1, "test1");
3195+
ps.setString(2, "test1");
3196+
boolean retval = ps.execute();
3197+
do {
3198+
if (!retval) {
3199+
int count = ps.getUpdateCount();
3200+
if (count == -1) {
3201+
// no more results
3202+
break;
3203+
} else {
3204+
assertEquals(count, 1, "update count should have been 1");
3205+
}
3206+
} else {
3207+
// process ResultSet
3208+
try (ResultSet rs = ps.getResultSet()) {
3209+
if (rs.next()) {
3210+
String val = rs.getString(1);
3211+
assertEquals(val, "test", "read value should have been 'test'");
3212+
}
3213+
}
3214+
}
3215+
retval = ps.getMoreResults();
3216+
} while (true);
3217+
}
3218+
}
3219+
}
3220+
3221+
/**
3222+
* Tests PreparedStatement execute two Inserts followed by select
3223+
*
3224+
* @throws SQLException
3225+
*/
3226+
@Test
3227+
public void testPreparedStatementExecuteTwoInsertsRowsAndSelect() throws SQLException {
3228+
try (Connection con = getConnection()) {
3229+
try (PreparedStatement ps = con.prepareStatement("INSERT INTO " + tableName + " (NAME) VALUES(?) INSERT INTO " + tableName + " (NAME) VALUES(?) SELECT NAME from " + tableName + " WHERE ID = 1")) {
3230+
ps.setString(1, "test");
3231+
ps.setString(2, "test");
3232+
boolean retval = ps.execute();
3233+
do {
3234+
if (!retval) {
3235+
int count = ps.getUpdateCount();
3236+
if (count == -1) {
3237+
// no more results
3238+
break;
3239+
} else {
3240+
assertEquals(count, 1, "update count should have been 1");
3241+
}
3242+
} else {
3243+
// process ResultSet
3244+
try (ResultSet rs = ps.getResultSet()) {
3245+
if (rs.next()) {
3246+
String val = rs.getString(1);
3247+
assertEquals(val, "test", "read value should have been 'test'");
3248+
}
3249+
}
3250+
}
3251+
retval = ps.getMoreResults();
3252+
} while (true);
3253+
}
3254+
}
3255+
}
3256+
3257+
/**
3258+
* Tests PreparedStatement execute for Update followed by select
3259+
*
3260+
* @throws SQLException
3261+
*/
3262+
@Test
3263+
public void testPreparedStatementExecuteUpdAndSelect() throws SQLException {
3264+
try (Connection con = getConnection()) {
3265+
try (PreparedStatement ps = con.prepareStatement("UPDATE " + tableName + " SET NAME = ? SELECT NAME FROM " + tableName + " WHERE ID = 1")) {
3266+
ps.setString(1, "test");
3267+
boolean retval = ps.execute();
3268+
do {
3269+
if (!retval) {
3270+
int count = ps.getUpdateCount();
3271+
if (count == -1) {
3272+
// no more results
3273+
break;
3274+
} else {
3275+
assertEquals(count, 3, "update count should have been 3");
3276+
}
3277+
} else {
3278+
// process ResultSet
3279+
try (ResultSet rs = ps.getResultSet()) {
3280+
if (rs.next()) {
3281+
String val = rs.getString(1);
3282+
assertEquals(val, "test", "read value should have been 'test'");
3283+
}
3284+
}
3285+
}
3286+
retval = ps.getMoreResults();
3287+
} while (true);
3288+
}
3289+
}
3290+
}
3291+
3292+
/**
3293+
* Tests PreparedStatement execute for Delete followed by select
3294+
*
3295+
* @throws SQLException
3296+
*/
3297+
@Test
3298+
public void testPreparedStatementExecuteDelAndSelect() throws SQLException {
3299+
try (Connection con = getConnection()) {
3300+
try (PreparedStatement ps = con.prepareStatement("DELETE FROM " + tableName + " WHERE ID = ? SELECT NAME FROM " + tableName + " WHERE ID = 2")) {
3301+
ps.setInt(1, 1);
3302+
boolean retval = ps.execute();
3303+
do {
3304+
if (!retval) {
3305+
int count = ps.getUpdateCount();
3306+
if (count == -1) {
3307+
// no more results
3308+
break;
3309+
} else {
3310+
assertEquals(count, 1, "update count should have been 1");
3311+
}
3312+
} else {
3313+
// process ResultSet
3314+
try (ResultSet rs = ps.getResultSet()) {
3315+
if (rs.next()) {
3316+
String val = rs.getString(1);
3317+
assertEquals(val, "test", "read value should have been 'test'");
3318+
}
3319+
}
3320+
}
3321+
retval = ps.getMoreResults();
3322+
} while (true);
3323+
}
3324+
}
3325+
}
3326+
30943327
@AfterEach
30953328
public void terminate() {
30963329
try (Connection con = getConnection(); Statement stmt = con.createStatement()) {

0 commit comments

Comments
 (0)