From 1b34521049e9994935cf3a0a9b761ce473355e8f Mon Sep 17 00:00:00 2001 From: Ricardo Peres Date: Wed, 12 Nov 2014 17:48:46 +0000 Subject: [PATCH] NH-3659 and unit tests --- src/NHibernate.Test/Linq/DeleteTests.cs | 49 +++++++++++ src/NHibernate.Test/NHibernate.Test.csproj | 1 + src/NHibernate/Linq/LinqExtensionMethods.cs | 98 +++++++++++++++++++++ 3 files changed, 148 insertions(+) create mode 100644 src/NHibernate.Test/Linq/DeleteTests.cs diff --git a/src/NHibernate.Test/Linq/DeleteTests.cs b/src/NHibernate.Test/Linq/DeleteTests.cs new file mode 100644 index 00000000000..95f05b0b69a --- /dev/null +++ b/src/NHibernate.Test/Linq/DeleteTests.cs @@ -0,0 +1,49 @@ +using System.Linq; +using NHibernate.DomainModel.Northwind.Entities; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.Linq +{ + [TestFixture] + public class DeleteTests : LinqTestCase + { + [Test] + public void CanDeleteSimpleExpression() + { + //NH-3659 + using (this.session.BeginTransaction()) + { + var beforeDeleteCount = this.session.Query().Count(u => u.Id > 0); + + var deletedCount = this.session.Delete(u => u.Id > 0); + + var afterDeleteCount = this.session.Query().Count(u => u.Id > 0); + + Assert.AreEqual(beforeDeleteCount, deletedCount); + + Assert.AreEqual(0, afterDeleteCount); + } + } + + [Test] + public void CanDeleteComplexExpression() + { + //NH-3659 + using (this.session.BeginTransaction()) + { + var cities = new string[] { "Paris", "Madrid" }; + + var beforeDeleteCount = this.session.Query().Count(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City)); + + var deletedCount = this.session.Delete(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City)); + + var afterDeleteCount = this.session.Query().Count(c => c.Orders.Count() == 0 && cities.Contains(c.Address.City)); + + Assert.AreEqual(beforeDeleteCount, deletedCount); + + Assert.AreEqual(0, afterDeleteCount); + } + } + } +} diff --git a/src/NHibernate.Test/NHibernate.Test.csproj b/src/NHibernate.Test/NHibernate.Test.csproj index 2533d1bd023..e5ad2804735 100644 --- a/src/NHibernate.Test/NHibernate.Test.csproj +++ b/src/NHibernate.Test/NHibernate.Test.csproj @@ -517,6 +517,7 @@ + diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index e9ee33324f0..9cbc2448f8b 100755 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -1,8 +1,16 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using System.Reflection; +using System.Text; +using System.Text.RegularExpressions; +using NHibernate.Engine; +using NHibernate.Exceptions; +using NHibernate.Hql.Ast.ANTLR; using NHibernate.Impl; +using NHibernate.SqlCommand; using Remotion.Linq; using Remotion.Linq.Parsing.ExpressionTreeVisitors; @@ -10,6 +18,96 @@ namespace NHibernate.Linq { public static class LinqExtensionMethods { + public static Int32 Delete(this ISession session, Expression> condition) + { + //these could be cached as static readonly fields + var instanceBindingFlags = BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance; + var staticBindingFlags = BindingFlags.Public | BindingFlags.Static; + var selectMethod = typeof(Queryable).GetMethods(staticBindingFlags).First(x => x.Name == "Select"); + var whereMethod = typeof(Queryable).GetMethods(staticBindingFlags).First(x => x.Name == "Where"); + var translatorFactory = new ASTQueryTranslatorFactory(); + var aliasRegex = new Regex(" from (\\w+) (\\w+) "); + var parameterTokensRegex = new Regex("\\?"); + + var entityType = typeof(T); + var queryable = session.Query(); + var sessionImpl = session.GetSessionImplementation(); + var persister = sessionImpl.GetEntityPersister(entityType.FullName, null); + var idName = persister.IdentifierPropertyName; + var idType = persister.IdentifierType.ReturnedClass; + var idProperty = entityType.GetProperty(idName, instanceBindingFlags); + var idMember = idProperty as MemberInfo; + + if (idProperty == null) + { + var fieldEntityType = entityType; + + //if the property is null, it means the the id is implemented as a field + while ((fieldEntityType != typeof(Object)) && (idMember == null)) + { + //try to find the field recursively + idMember = fieldEntityType.GetField(idName, instanceBindingFlags); + + fieldEntityType = fieldEntityType.BaseType; + } + } + + if (idMember == null) + { + throw new InvalidOperationException(string.Format("Could not find identity property {0} in entity {1}.", idName, entityType.FullName)); + } + + var delegateType = typeof(Func<,>).MakeGenericType(entityType, idType); + var parm = Expression.Parameter(entityType, "x"); + var lambda = Expression.Lambda(delegateType, Expression.MakeMemberAccess(parm, idMember), new ParameterExpression[] { parm }); + var where = Expression.Call(null, whereMethod.MakeGenericMethod(entityType), queryable.Expression, condition); + var call = Expression.Call(null, selectMethod.MakeGenericMethod(entityType, idType), where, lambda); + + var nhLinqExpression = new NhLinqExpression(call, sessionImpl.Factory); + var translator = translatorFactory.CreateQueryTranslators(nhLinqExpression, null, false, sessionImpl.EnabledFilters, sessionImpl.Factory).Single(); + var parameters = nhLinqExpression.ParameterValuesByName.Select(x => x.Value.Item1).ToArray(); + //we need to turn positional parameters into named parameters because of SetParameterList + var count = 0; + var replacedSql = parameterTokensRegex.Replace(translator.SQLString, m => ":p" + count++); + var sql = new StringBuilder(replacedSql); + //find from + var fromIndex = sql.ToString().IndexOf(" from ", StringComparison.InvariantCultureIgnoreCase); + //find alias + var alias = aliasRegex.Match(sql.ToString()).Groups[2].Value; + + //make a string in the form DELETE alias FROM table alias WHERE condition + sql.Remove(0, fromIndex); + sql.Insert(0, string.Concat("delete ", alias, " ")); + + using (var childSession = session.GetSession(session.ActiveEntityMode)) + { + try + { + var query = childSession.CreateSQLQuery(sql.ToString()); + + for (var i = 0; i < parameters.Length; ++i) + { + var parameter = parameters[i]; + + if (!(parameter is IEnumerable) || (parameter is string) || (parameter is byte[])) + { + query.SetParameter(String.Format("p{0}", i), parameter); + } + else + { + query.SetParameterList(String.Format("p{0}", i), parameter as IEnumerable); + } + } + + return query.ExecuteUpdate(); + } + catch (Exception ex) + { + throw ADOExceptionHelper.Convert(sessionImpl.Factory.SQLExceptionConverter, ex, "Error deleting records.", new SqlString(sql.ToString()), parameters, nhLinqExpression.ParameterValuesByName.ToDictionary(x => x.Key, x => new TypedValue(x.Value.Item2, x.Value.Item1, session.ActiveEntityMode))); + } + } + } + public static IQueryable Query(this ISession session) { return new NhQueryable(session.GetSessionImplementation());