Skip to content
24 changes: 20 additions & 4 deletions src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ public override bool NextResult()
{
stmt = _stmtEnumerator.Current;

var connectionHandle = _command.Connection!.Handle;
var totalChangesBefore = sqlite3_total_changes(connectionHandle);

var timer = SharedStopwatch.StartNew();

while (IsBusy(rc = sqlite3_step(stmt)))
Expand All @@ -172,7 +175,7 @@ public override bool NextResult()

_totalElapsedTime += timer.Elapsed;

SqliteException.ThrowExceptionForRC(rc, _command.Connection!.Handle);
SqliteException.ThrowExceptionForRC(rc, connectionHandle);

// It's a SELECT statement
if (sqlite3_column_count(stmt) != 0)
Expand All @@ -185,13 +188,26 @@ public override bool NextResult()
while (rc != SQLITE_DONE)
{
rc = sqlite3_step(stmt);
SqliteException.ThrowExceptionForRC(rc, _command.Connection.Handle);
SqliteException.ThrowExceptionForRC(rc, connectionHandle);
}

sqlite3_reset(stmt);

var changes = sqlite3_changes(_command.Connection.Handle);
AddChanges(changes);
// sqlite3_changes() returns the row count from the most recent INSERT, UPDATE, or DELETE
// and incorrectly persists across DDL statements. Use sqlite3_total_changes() before and after
// to calculate the actual delta for this statement, ensuring DDL statements don't add stale counts.
var totalChangesAfter = sqlite3_total_changes(connectionHandle);
var changes = totalChangesAfter - totalChangesBefore;
// sqlite3_total_changes, unfortunately, counts also changes from triggers, etc. which is not what we want.
// So we use it only to detect changes and if so, use sqlite3_changes.
if (changes > 0)
{
AddChanges(sqlite3_changes(connectionHandle));
}
else
{
AddChanges(0);
}
}
catch
{
Expand Down
2 changes: 1 addition & 1 deletion test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ public void CreateFunction_deterministic_param_works()
connection.ExecuteNonQuery("CREATE TABLE Data (Value); INSERT INTO Data VALUES (0);");
connection.CreateFunction("test", (double x) => x, true);

Assert.Equal(1, connection.ExecuteNonQuery("CREATE INDEX InvalidIndex ON Data (Value) WHERE test(Value) = 0;"));
Assert.Equal(0, connection.ExecuteNonQuery("CREATE INDEX InvalidIndex ON Data (Value) WHERE test(Value) = 0;"));
}

[Fact]
Expand Down
39 changes: 39 additions & 0 deletions test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,45 @@ public void RecordsAffected_works_with_returning_multiple()
}
}

[Fact]
public void RecordsAffected_not_affected_by_DDL_statements()
{
using (var connection = new SqliteConnection("Data Source=:memory:"))
{
connection.Open();

using (var reader = connection.ExecuteReader(
@"CREATE TABLE foo(bar TEXT NOT NULL);
CREATE TABLE xyz(aaa TEXT NOT NULL);
INSERT INTO foo(bar) VALUES('baz');
INSERT INTO foo(bar) VALUES('baz2');
DROP TABLE xyz;"))
{
Assert.Equal(2, reader.RecordsAffected);
}
}
}

[Fact]
public void RecordsAffected_not_affected_by_DDL_statements_with_drop_and_create()
{
using (var connection = new SqliteConnection("Data Source=:memory:"))
{
connection.Open();

using (var reader = connection.ExecuteReader(
@"CREATE TABLE foo(bar TEXT NOT NULL);
CREATE TABLE xyz(aaa TEXT NOT NULL);
INSERT INTO foo(bar) VALUES('baz');
INSERT INTO foo(bar) VALUES('baz2');
DROP TABLE xyz;
CREATE TABLE xyz(aaa TEXT NOT NULL);"))
{
Assert.Equal(2, reader.RecordsAffected);
}
}
}

[Fact]
public void GetSchemaTable_works()
{
Expand Down