From 7fb2d23aa1d1d2dcbc4a8c0a282fa0cf20ab0d29 Mon Sep 17 00:00:00 2001 From: maca88 Date: Mon, 3 Aug 2020 23:20:58 +0200 Subject: [PATCH 1/3] Fix parameter detection related to sequence selectors for Linq provider --- .../Async/Linq/ParameterTests.cs | 14 +++ .../Async/NHSpecificTest/GH2465/Fixture.cs | 95 +++++++++++++++++++ src/NHibernate.Test/Linq/ParameterTests.cs | 14 +++ .../NHSpecificTest/GH2465/Entity.cs | 14 +++ .../NHSpecificTest/GH2465/Fixture.cs | 83 ++++++++++++++++ .../NHSpecificTest/GH2465/Mappings.hbm.xml | 12 +++ .../Linq/Visitors/ParameterTypeLocator.cs | 15 +++ 7 files changed, 247 insertions(+) create mode 100644 src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs create mode 100644 src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index b64e2ef99f4..ad1c885dc4f 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -142,6 +142,20 @@ public async Task UsingParameterInEvaluatableExpressionAsync() await (db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToListAsync()); } + [Test] + public async Task UsingParameterOnSelectorsAsync() + { + var user = new User() {Id = 1}; + await (db.Users.Where(o => o == user).ToListAsync()); + await (db.Users.FirstOrDefaultAsync(o => o == user)); + await (db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToListAsync()); + + var users = new[] {new User() {Id = 1}}; + await (db.Users.Where(o => users.Contains(o)).ToListAsync()); + await (db.Users.FirstOrDefaultAsync(o => users.Contains(o))); + await (db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToListAsync()); + } + [Test] public async Task UsingNegateValueTypeParameterTwiceAsync() { diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs new file mode 100644 index 00000000000..cea4d27fcd3 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs @@ -0,0 +1,95 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + using System.Threading.Tasks; + [TestFixture] + public class FixtureAsync : BugTestCase + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var applicant = new Entity {IdentityNames = {"name1", "name2"}}; + session.Save(applicant); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + transaction.Commit(); + } + } + + [Test] + public async Task ContainsInsideValueCollectionAsync() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var identityNames = new[] {"name1", "x"}; + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => identityNames.Contains(n))) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.All(n => identityNames.Contains(n))) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test") + .ToListAsync()); + + await (transaction.CommitAsync()); + } + } + + [Test] + public async Task EqualsInsideValueCollectionAsync() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var value = "test"; + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => (string) n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.All(n => n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test") + .ToListAsync()); + + await (transaction.CommitAsync()); + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index c406ea68929..97da1e0a079 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -130,6 +130,20 @@ public void UsingParameterInEvaluatableExpression() db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToList(); } + [Test] + public void UsingParameterOnSelectors() + { + var user = new User() {Id = 1}; + db.Users.Where(o => o == user).ToList(); + db.Users.FirstOrDefault(o => o == user); + db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToList(); + + var users = new[] {new User() {Id = 1}}; + db.Users.Where(o => users.Contains(o)).ToList(); + db.Users.FirstOrDefault(o => users.Contains(o)); + db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToList(); + } + [Test] public void ValidateMixingTwoParametersCacheKeys() { diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs new file mode 100644 index 00000000000..72dac9bddcf --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + public class Entity + { + private readonly IList _identityNames = new List(); + + public virtual Guid Id { get; set; } + + public virtual IList IdentityNames => _identityNames; + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs new file mode 100644 index 00000000000..bed0a709ff5 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs @@ -0,0 +1,83 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var applicant = new Entity {IdentityNames = {"name1", "name2"}}; + session.Save(applicant); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + transaction.Commit(); + } + } + + [Test] + public void ContainsInsideValueCollection() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var identityNames = new[] {"name1", "x"}; + session + .Query() + .Where(a => a.IdentityNames.Any(n => identityNames.Contains(n))) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.All(n => identityNames.Contains(n))) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test") + .ToList(); + + transaction.Commit(); + } + } + + [Test] + public void EqualsInsideValueCollection() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var value = "test"; + session + .Query() + .Where(a => a.IdentityNames.Any(n => n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.Any(n => (string) n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.All(n => n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test") + .ToList(); + + transaction.Commit(); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml new file mode 100644 index 00000000000..996ab1025da --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 40f3ec0d3d3..af264988925 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -4,6 +4,7 @@ using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Persister.Collection; using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; @@ -110,6 +111,12 @@ internal static void SetParameterTypes( out _, out _)) { + if (visitor.SequenceSelectorExpressions.Contains(memberExpression) && type is IAssociationType associationType) + { + var collection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + type = collection.ElementType; + } + break; } } @@ -137,6 +144,7 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor new Dictionary(); public readonly Dictionary> RelatedExpressions = new Dictionary>(); + public readonly HashSet SequenceSelectorExpressions = new HashSet(); public ConstantTypeLocatorVisitor( bool removeMappedAsCalls, @@ -247,6 +255,13 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && } else { + // In case a parameter is related to a sequence selector we will have to get the underling item type + // (e.g. q.Where(o => o.Users.Any(u => u == user))) + if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase)) + { + SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector); + } + node.QueryModel.TransformExpressions(Visit); } From dc3219d9c333e33df39d040728cdff58049feb0d Mon Sep 17 00:00:00 2001 From: maca88 Date: Tue, 4 Aug 2020 20:08:20 +0200 Subject: [PATCH 2/3] Code review changes --- src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs | 5 +++++ src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs | 5 +++++ src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 4 ++-- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs index cea4d27fcd3..2084a41ac03 100644 --- a/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs @@ -18,6 +18,11 @@ namespace NHibernate.Test.NHSpecificTest.GH2465 [TestFixture] public class FixtureAsync : BugTestCase { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return dialect.SupportsScalarSubSelects; + } + protected override void OnSetUp() { using (var session = OpenSession()) diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs index bed0a709ff5..121b3756270 100644 --- a/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs @@ -6,6 +6,11 @@ namespace NHibernate.Test.NHSpecificTest.GH2465 [TestFixture] public class Fixture : BugTestCase { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return dialect.SupportsScalarSubSelects; + } + protected override void OnSetUp() { using (var session = OpenSession()) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index af264988925..ec7161b58e0 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -111,9 +111,9 @@ internal static void SetParameterTypes( out _, out _)) { - if (visitor.SequenceSelectorExpressions.Contains(memberExpression) && type is IAssociationType associationType) + if (type.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(memberExpression)) { - var collection = (IQueryableCollection) associationType.GetAssociatedJoinable(sessionFactory); + var collection = (IQueryableCollection) ((IAssociationType) type).GetAssociatedJoinable(sessionFactory); type = collection.ElementType; } From 5d576b69550c2a8f644812b982068fdc0f0d12ad Mon Sep 17 00:00:00 2001 From: maca88 Date: Thu, 6 Aug 2020 23:01:15 +0200 Subject: [PATCH 3/3] Update src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Frédéric Delaporte <12201973+fredericDelaporte@users.noreply.github.com> --- src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index ec7161b58e0..41bc2547f31 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -255,7 +255,7 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && } else { - // In case a parameter is related to a sequence selector we will have to get the underling item type + // In case a parameter is related to a sequence selector we will have to get the underlying item type // (e.g. q.Where(o => o.Users.Any(u => u == user))) if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase)) {