diff --git a/src/NHibernate.Test/Async/Linq/ParameterTests.cs b/src/NHibernate.Test/Async/Linq/ParameterTests.cs index b64e2ef99f4..ad1c885dc4f 100644 --- a/src/NHibernate.Test/Async/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Async/Linq/ParameterTests.cs @@ -142,6 +142,20 @@ public async Task UsingParameterInEvaluatableExpressionAsync() await (db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToListAsync()); } + [Test] + public async Task UsingParameterOnSelectorsAsync() + { + var user = new User() {Id = 1}; + await (db.Users.Where(o => o == user).ToListAsync()); + await (db.Users.FirstOrDefaultAsync(o => o == user)); + await (db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToListAsync()); + + var users = new[] {new User() {Id = 1}}; + await (db.Users.Where(o => users.Contains(o)).ToListAsync()); + await (db.Users.FirstOrDefaultAsync(o => users.Contains(o))); + await (db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToListAsync()); + } + [Test] public async Task UsingNegateValueTypeParameterTwiceAsync() { diff --git a/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs new file mode 100644 index 00000000000..2084a41ac03 --- /dev/null +++ b/src/NHibernate.Test/Async/NHSpecificTest/GH2465/Fixture.cs @@ -0,0 +1,100 @@ +//------------------------------------------------------------------------------ +// +// This code was generated by AsyncGenerator. +// +// Changes to this file may cause incorrect behavior and will be lost if +// the code is regenerated. +// +//------------------------------------------------------------------------------ + + +using System.Linq; +using NUnit.Framework; +using NHibernate.Linq; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + using System.Threading.Tasks; + [TestFixture] + public class FixtureAsync : BugTestCase + { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return dialect.SupportsScalarSubSelects; + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var applicant = new Entity {IdentityNames = {"name1", "name2"}}; + session.Save(applicant); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + transaction.Commit(); + } + } + + [Test] + public async Task ContainsInsideValueCollectionAsync() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var identityNames = new[] {"name1", "x"}; + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => identityNames.Contains(n))) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.All(n => identityNames.Contains(n))) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test") + .ToListAsync()); + + await (transaction.CommitAsync()); + } + } + + [Test] + public async Task EqualsInsideValueCollectionAsync() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var value = "test"; + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.Any(n => (string) n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.All(n => n == value)) + .ToListAsync()); + await (session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test") + .ToListAsync()); + + await (transaction.CommitAsync()); + } + } + } +} diff --git a/src/NHibernate.Test/Linq/ParameterTests.cs b/src/NHibernate.Test/Linq/ParameterTests.cs index c406ea68929..97da1e0a079 100644 --- a/src/NHibernate.Test/Linq/ParameterTests.cs +++ b/src/NHibernate.Test/Linq/ParameterTests.cs @@ -130,6 +130,20 @@ public void UsingParameterInEvaluatableExpression() db.Users.Where(x => names.Length == 0 || names.Contains(x.Name)).ToList(); } + [Test] + public void UsingParameterOnSelectors() + { + var user = new User() {Id = 1}; + db.Users.Where(o => o == user).ToList(); + db.Users.FirstOrDefault(o => o == user); + db.Timesheets.Where(o => o.Users.Any(u => u == user)).ToList(); + + var users = new[] {new User() {Id = 1}}; + db.Users.Where(o => users.Contains(o)).ToList(); + db.Users.FirstOrDefault(o => users.Contains(o)); + db.Timesheets.Where(o => o.Users.Any(u => users.Contains(u))).ToList(); + } + [Test] public void ValidateMixingTwoParametersCacheKeys() { diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs b/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs new file mode 100644 index 00000000000..72dac9bddcf --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Entity.cs @@ -0,0 +1,14 @@ +using System; +using System.Collections.Generic; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + public class Entity + { + private readonly IList _identityNames = new List(); + + public virtual Guid Id { get; set; } + + public virtual IList IdentityNames => _identityNames; + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs new file mode 100644 index 00000000000..121b3756270 --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Fixture.cs @@ -0,0 +1,88 @@ +using System.Linq; +using NUnit.Framework; + +namespace NHibernate.Test.NHSpecificTest.GH2465 +{ + [TestFixture] + public class Fixture : BugTestCase + { + protected override bool AppliesTo(Dialect.Dialect dialect) + { + return dialect.SupportsScalarSubSelects; + } + + protected override void OnSetUp() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var applicant = new Entity {IdentityNames = {"name1", "name2"}}; + session.Save(applicant); + + transaction.Commit(); + } + } + + protected override void OnTearDown() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + session.Delete("from System.Object"); + transaction.Commit(); + } + } + + [Test] + public void ContainsInsideValueCollection() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var identityNames = new[] {"name1", "x"}; + session + .Query() + .Where(a => a.IdentityNames.Any(n => identityNames.Contains(n))) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.All(n => identityNames.Contains(n))) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => identityNames.Contains(n)) == "test") + .ToList(); + + transaction.Commit(); + } + } + + [Test] + public void EqualsInsideValueCollection() + { + using (var session = OpenSession()) + using (var transaction = session.BeginTransaction()) + { + var value = "test"; + session + .Query() + .Where(a => a.IdentityNames.Any(n => n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.Any(n => (string) n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.All(n => n == value)) + .ToList(); + session + .Query() + .Where(a => a.IdentityNames.FirstOrDefault(n => n == "test") == "test") + .ToList(); + + transaction.Commit(); + } + } + } +} diff --git a/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml b/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml new file mode 100644 index 00000000000..996ab1025da --- /dev/null +++ b/src/NHibernate.Test/NHSpecificTest/GH2465/Mappings.hbm.xml @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs index 40f3ec0d3d3..41bc2547f31 100644 --- a/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs +++ b/src/NHibernate/Linq/Visitors/ParameterTypeLocator.cs @@ -4,6 +4,7 @@ using System.Linq.Expressions; using NHibernate.Engine; using NHibernate.Param; +using NHibernate.Persister.Collection; using NHibernate.Type; using NHibernate.Util; using Remotion.Linq; @@ -110,6 +111,12 @@ internal static void SetParameterTypes( out _, out _)) { + if (type.IsAssociationType && visitor.SequenceSelectorExpressions.Contains(memberExpression)) + { + var collection = (IQueryableCollection) ((IAssociationType) type).GetAssociatedJoinable(sessionFactory); + type = collection.ElementType; + } + break; } } @@ -137,6 +144,7 @@ private class ConstantTypeLocatorVisitor : RelinqExpressionVisitor new Dictionary(); public readonly Dictionary> RelatedExpressions = new Dictionary>(); + public readonly HashSet SequenceSelectorExpressions = new HashSet(); public ConstantTypeLocatorVisitor( bool removeMappedAsCalls, @@ -247,6 +255,13 @@ querySourceReference.ReferencedQuerySource is MainFromClause mainFromClause && } else { + // In case a parameter is related to a sequence selector we will have to get the underlying item type + // (e.g. q.Where(o => o.Users.Any(u => u == user))) + if (node.QueryModel.ResultOperators.Any(o => o is ValueFromSequenceResultOperatorBase)) + { + SequenceSelectorExpressions.Add(node.QueryModel.SelectClause.Selector); + } + node.QueryModel.TransformExpressions(Visit); }