diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index bf34fb2d8..4632e6fc4 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -788,6 +788,17 @@ boolean onRetValue(TDSReader tdsReader) throws SQLServerException { return false; } + /** + * Override TDS token processing behavior for PreparedStatement. + * For regular Statement, the execute API for INSERT requires reading an additional explicit + * TDS_DONE token that contains the actual update count returned by the server. + * PreparedStatement does not require this additional token processing. + */ + @Override + protected boolean hasUpdateCountTDSTokenForInsertCmd() { + return false; + } + /** * Sends the statement parameters by RPC. */ diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java index 2b4d9865a..12997df49 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java @@ -1601,8 +1601,8 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException { if (null != procedureName) return false; - // For Insert, we must fetch additional TDS_DONE token that comes with the actual update count - if ((StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount()) + // For Insert operations, check if additional TDS_DONE token processing is required. + if (hasUpdateCountTDSTokenForInsertCmd() && (StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount()) && EXECUTE == executeMethod) { return true; } @@ -1845,6 +1845,19 @@ boolean consumeExecOutParam(TDSReader tdsReader) throws SQLServerException { return false; } + /** + * Determines whether to continue processing additional TDS_DONE tokens for INSERT statements. + * For INSERT operations, regular Statement requires reading an additional TDS_DONE token that contains + * the actual update count. This method can be overridden by subclasses to customize + * TDS token processing behavior. + * + * @return true to continue processing more tokens to get the actual update count for INSERT operations + */ + protected boolean hasUpdateCountTDSTokenForInsertCmd() { + // For Insert, we must fetch additional TDS_DONE token that comes with the actual update count + return true; + } + // --------------------------JDBC 2.0----------------------------- @Override diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java index fea7f08ac..5ac5116a1 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java @@ -2961,7 +2961,7 @@ public void testExecuteInsertManyRowsAndSelect() { // no more results break; } else { - assertEquals(count, 3, "update count should have been 6"); + assertEquals(count, 3, "update count should have been 3"); } } else { // process ResultSet @@ -2998,7 +2998,7 @@ public void testExecuteTwoInsertsRowsAndSelect() { // no more results break; } else { - assertEquals(count, 1, "update count should have been 2"); + assertEquals(count, 1, "update count should have been 1"); } } else { // process ResultSet @@ -3091,6 +3091,239 @@ public void testExecuteDelAndSelect() { } } + /** + * Tests multi-statement PreparedStatement with loop to process all results + * + * @throws SQLException + */ + @Test + public void testMultiStatementPreparedStatementLoopResults() throws SQLException { + try (Connection con = getConnection()) { + try (PreparedStatement ps = con.prepareStatement("DELETE FROM " + tableName + " " + + "INSERT INTO " + tableName + " (NAME) VALUES (?) " + + "INSERT INTO " + tableName + " (NAME) VALUES (?) " + + "UPDATE " + tableName + " SET NAME = 'updated' " + + "INSERT INTO " + tableName + " (NAME) VALUES (?) " + + "INSERT INTO " + tableName + " (NAME) VALUES (?) " + + "SELECT * FROM " + tableName)) { + + ps.setString(1, "test1"); + ps.setString(2, "test2"); + ps.setString(3, "test3"); + ps.setString(4, "test4"); + + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertTrue(count >= 0, "update count should be non-negative: " + count); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + int rowCount = 0; + while (rs.next()) { + String name = rs.getString("NAME"); + assertTrue(name != null, "name should not be null"); + rowCount++; + } + assertEquals(4, rowCount, "Expected 4 rows in result set"); + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + + /** + * Tests PreparedStatement execute for Insert followed by select + * + * @throws SQLException + */ + @Test + public void testPreparedStatementExecuteInsertAndSelect() throws SQLException { + try (Connection con = getConnection()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES(?) " + + "SELECT NAME FROM " + tableName + " WHERE ID = 1"; + try (PreparedStatement ps = con.prepareStatement(sql)) { + ps.setString(1, "test"); + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + + /** + * Tests PreparedStatement execute for Merge followed by select + * + * @throws SQLException + */ + @Test + public void testPreparedStatementExecuteMergeAndSelect() throws SQLException { + try (Connection con = getConnection()) { + String sql = "MERGE INTO " + tableName + " AS target " + + "USING (VALUES (?)) AS source (name) " + + "ON target.name = source.name " + + "WHEN NOT MATCHED THEN INSERT (name) VALUES (?); " + + "SELECT NAME FROM " + tableName + " WHERE ID = 1"; + try (PreparedStatement ps = con.prepareStatement(sql)) { + ps.setString(1, "test1"); + ps.setString(2, "test1"); + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + + /** + * Tests PreparedStatement execute two Inserts followed by select + * + * @throws SQLException + */ + @Test + public void testPreparedStatementExecuteTwoInsertsRowsAndSelect() throws SQLException { + try (Connection con = getConnection()) { + try (PreparedStatement ps = con.prepareStatement("INSERT INTO " + tableName + " (NAME) VALUES(?) INSERT INTO " + tableName + " (NAME) VALUES(?) SELECT NAME from " + tableName + " WHERE ID = 1")) { + ps.setString(1, "test"); + ps.setString(2, "test"); + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + + /** + * Tests PreparedStatement execute for Update followed by select + * + * @throws SQLException + */ + @Test + public void testPreparedStatementExecuteUpdAndSelect() throws SQLException { + try (Connection con = getConnection()) { + try (PreparedStatement ps = con.prepareStatement("UPDATE " + tableName + " SET NAME = ? SELECT NAME FROM " + tableName + " WHERE ID = 1")) { + ps.setString(1, "test"); + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 3, "update count should have been 3"); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + + /** + * Tests PreparedStatement execute for Delete followed by select + * + * @throws SQLException + */ + @Test + public void testPreparedStatementExecuteDelAndSelect() throws SQLException { + try (Connection con = getConnection()) { + try (PreparedStatement ps = con.prepareStatement("DELETE FROM " + tableName + " WHERE ID = ? SELECT NAME FROM " + tableName + " WHERE ID = 2")) { + ps.setInt(1, 1); + boolean retval = ps.execute(); + do { + if (!retval) { + int count = ps.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = ps.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = ps.getMoreResults(); + } while (true); + } + } + } + @AfterEach public void terminate() { try (Connection con = getConnection(); Statement stmt = con.createStatement()) {