diff --git a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs index 86802cdb345..fd3468c6980 100644 --- a/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Async/Linq/ByMethod/JoinTests.cs @@ -13,10 +13,10 @@ using System.Reflection; using NHibernate.Cfg; using NHibernate.Engine.Query; +using NHibernate.Linq; using NHibernate.Util; using NSubstitute; using NUnit.Framework; -using NHibernate.Linq; namespace NHibernate.Test.Linq.ByMethod { @@ -27,15 +27,91 @@ public class JoinTestsAsync : LinqTestCase [Test] public async Task MultipleLinqJoinsWithSameProjectionNamesAsync() { - var orders = await (db.Orders + using (var sqlSpy = new SqlLogSpy()) + { + var orders = await (db.Orders .Join(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) .Select(x => new { First = x.order, Second = x.order1 }) .Join(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) .Select(x => new { FirstId = x.order.First.OrderId, SecondId = x.order.Second.OrderId, ThirdId = x.order1.OrderId }) .ToListAsync()); - Assert.That(orders.Count, Is.EqualTo(828)); - Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(828)); + Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } + + [Test] + public async Task MultipleLinqJoinsWithSameProjectionNamesWithLeftJoinAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = await (db.Orders + .GroupJoin(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) + .SelectMany(x => x.order1.DefaultIfEmpty(), (x, order1) => new { First = x.order, Second = order1 }) + .GroupJoin(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) + .SelectMany(x => x.order1.DefaultIfEmpty(), (x, order1) => new + { + FirstId = x.order.First.OrderId, + SecondId = (int?) x.order.Second.OrderId, + ThirdId = (int?) order1.OrderId + }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.Where(x => x.SecondId.HasValue && x.ThirdId.HasValue) + .All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } + + [Test] + public async Task MultipleLinqJoinsWithSameProjectionNamesWithLeftJoinExtensionMethodAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = await (db.Orders + .LeftJoin(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) + .Select(x => new { First = x.order, Second = x.order1 }) + .LeftJoin(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) + .Select(x => new + { + FirstId = x.order.First.OrderId, + SecondId = (int?) x.order.Second.OrderId, + ThirdId = (int?) x.order1.OrderId + }) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.Where(x => x.SecondId.HasValue && x.ThirdId.HasValue) + .All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } + + [Test] + public async Task LeftJoinExtensionMethodWithMultipleKeyPropertiesAsync() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = await (db.Orders + .LeftJoin( + db.Orders, + x => new {x.OrderId, x.Customer.CustomerId}, + x => new {x.OrderId, x.Customer.CustomerId}, + (order, order1) => new {order, order1}) + .Select(x => new {FirstId = x.order.OrderId, SecondId = x.order1.OrderId}) + .ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [TestCase(false)] diff --git a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs index a332a24f6ce..9b8d04f2c6c 100644 --- a/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Async/Linq/LinqQuerySamples.cs @@ -768,6 +768,26 @@ from o in c.Orders } } + [Category("JOIN")] + [Test(Description = "This sample uses foreign key navigation in the " + + "from clause to select all orders for customers in London.")] + public async Task DLinqJoin1LeftJoinAsync() + { + IQueryable q = + from c in db.Customers + from o in c.Orders.DefaultIfEmpty() + where c.Address.City == "London" + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample shows how to construct a join where one side is nullable and the other isn't.")] public async Task DLinqJoin10Async() @@ -974,6 +994,26 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [Test(Description = "This sample explictly joins two tables and projects results from both tables.")] + public async Task DLinqJoin5aLeftJoinAsync() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders + from o in orders.DefaultIfEmpty() + where o != null + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins two tables and projects results from both tables using a group join.")] public async Task DLinqJoin5bAsync() @@ -1032,6 +1072,21 @@ join o in db.Orders on } } + [Category("JOIN")] + [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5dLeftJoinAsync() + { + var q = + from c in db.Customers + join o in db.Orders on + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } equals + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } into orders + from o in orders.DefaultIfEmpty() + select new { c.ContactName, o.OrderId }; + + Assert.ThrowsAsync(() => ObjectDumper.WriteAsync(q)); + } + [Category("JOIN")] [Test(Description = "This sample joins two tables and projects results from the first table.")] public async Task DLinqJoin5eAsync() @@ -1051,6 +1106,26 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public async Task DLinqJoin5eLeftJoinAsync() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders + from o in orders.DefaultIfEmpty() + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] public async Task DLinqJoin5fAsync() @@ -1072,6 +1147,28 @@ join c in db.Customers on } } + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public async Task DLinqJoin5fLeftJoinAsync() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } into customers + from c in customers.DefaultIfEmpty() + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(0)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public async Task DLinqJoin6Async() @@ -1094,6 +1191,28 @@ join e in db.Employees on c.Address.City equals e.Address.City into emps } } + [Category("JOIN")] + [Test( + Description = + "This sample shows how to get LEFT OUTER JOIN by using DefaultIfEmpty(). The DefaultIfEmpty() method returns null when there is no Order for the Employee." + )] + public async Task DLinqJoin7Async() + { + var q = + from e in db.Employees + join o in db.Orders on e equals o.Employee into ords + from o in ords.DefaultIfEmpty() + select new {e.FirstName, e.LastName, Order = o}; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample projects a 'let' expression resulting from a join.")] public async Task DLinqJoin8Async() @@ -1156,6 +1275,50 @@ from d in details } } + [Category("JOIN")] + [TestCase(true, Description = "This sample shows a group left join with a composite key.")] + [TestCase(false, Description = "This sample shows a group left join with a composite key.")] + public async Task DLinqJoin9LeftJoinAsync(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + // into details + // from d in details.DefaultIfEmpty() + // where d != null && d.UnitPrice > 50 + // select new { o.OrderId, p.ProductId, d.UnitPrice }).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + await ((from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + into details + from d in details.DefaultIfEmpty() + where d != null && d.UnitPrice > 50 + select new {o.OrderId, p.ProductId, d.UnitPrice}).ToListAsync()); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(163)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample shows a join which is then grouped")] public async Task DLinqJoin9bAsync() @@ -1186,5 +1349,26 @@ join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); } } + + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables using a left join.")] + public async Task DLinqJoin10aLeftJoinAsync() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId into sup + from s in sup.DefaultIfEmpty() + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId into sup2 + from s2 in sup2.DefaultIfEmpty() + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + await (ObjectDumper.WriteAsync(q)); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } } } diff --git a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs index a013014cae8..d1dc10bec83 100644 --- a/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs +++ b/src/NHibernate.Test/Linq/ByMethod/JoinTests.cs @@ -3,6 +3,7 @@ using System.Reflection; using NHibernate.Cfg; using NHibernate.Engine.Query; +using NHibernate.Linq; using NHibernate.Util; using NSubstitute; using NUnit.Framework; @@ -15,15 +16,91 @@ public class JoinTests : LinqTestCase [Test] public void MultipleLinqJoinsWithSameProjectionNames() { - var orders = db.Orders + using (var sqlSpy = new SqlLogSpy()) + { + var orders = db.Orders .Join(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) .Select(x => new { First = x.order, Second = x.order1 }) .Join(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) .Select(x => new { FirstId = x.order.First.OrderId, SecondId = x.order.Second.OrderId, ThirdId = x.order1.OrderId }) .ToList(); - Assert.That(orders.Count, Is.EqualTo(828)); - Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(828)); + Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(2)); + } + } + + [Test] + public void MultipleLinqJoinsWithSameProjectionNamesWithLeftJoin() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = db.Orders + .GroupJoin(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) + .SelectMany(x => x.order1.DefaultIfEmpty(), (x, order1) => new { First = x.order, Second = order1 }) + .GroupJoin(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) + .SelectMany(x => x.order1.DefaultIfEmpty(), (x, order1) => new + { + FirstId = x.order.First.OrderId, + SecondId = (int?) x.order.Second.OrderId, + ThirdId = (int?) order1.OrderId + }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.Where(x => x.SecondId.HasValue && x.ThirdId.HasValue) + .All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } + + [Test] + public void MultipleLinqJoinsWithSameProjectionNamesWithLeftJoinExtensionMethod() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = db.Orders + .LeftJoin(db.Orders, x => x.OrderId, x => x.OrderId - 1, (order, order1) => new { order, order1 }) + .Select(x => new { First = x.order, Second = x.order1 }) + .LeftJoin(db.Orders, x => x.First.OrderId, x => x.OrderId - 2, (order, order1) => new { order, order1 }) + .Select(x => new + { + FirstId = x.order.First.OrderId, + SecondId = (int?) x.order.Second.OrderId, + ThirdId = (int?) x.order1.OrderId + }) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.Where(x => x.SecondId.HasValue && x.ThirdId.HasValue) + .All(x => x.FirstId == x.SecondId - 1 && x.SecondId == x.ThirdId - 1)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } + + [Test] + public void LeftJoinExtensionMethodWithMultipleKeyProperties() + { + using (var sqlSpy = new SqlLogSpy()) + { + var orders = db.Orders + .LeftJoin( + db.Orders, + x => new {x.OrderId, x.Customer.CustomerId}, + x => new {x.OrderId, x.Customer.CustomerId}, + (order, order1) => new {order, order1}) + .Select(x => new {FirstId = x.order.OrderId, SecondId = x.order1.OrderId}) + .ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(orders.Count, Is.EqualTo(830)); + Assert.IsTrue(orders.All(x => x.FirstId == x.SecondId)); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } } [TestCase(false)] diff --git a/src/NHibernate.Test/Linq/LinqQuerySamples.cs b/src/NHibernate.Test/Linq/LinqQuerySamples.cs index 193c313130c..393d55ccf0a 100755 --- a/src/NHibernate.Test/Linq/LinqQuerySamples.cs +++ b/src/NHibernate.Test/Linq/LinqQuerySamples.cs @@ -1312,6 +1312,26 @@ from o in c.Orders } } + [Category("JOIN")] + [Test(Description = "This sample uses foreign key navigation in the " + + "from clause to select all orders for customers in London.")] + public void DLinqJoin1LeftJoin() + { + IQueryable q = + from c in db.Customers + from o in c.Orders.DefaultIfEmpty() + where c.Address.City == "London" + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample shows how to construct a join where one side is nullable and the other isn't.")] public void DLinqJoin10() @@ -1518,6 +1538,26 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [Test(Description = "This sample explictly joins two tables and projects results from both tables.")] + public void DLinqJoin5aLeftJoin() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders + from o in orders.DefaultIfEmpty() + where o != null + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins two tables and projects results from both tables using a group join.")] public void DLinqJoin5b() @@ -1576,6 +1616,21 @@ join o in db.Orders on } } + [Category("JOIN")] + [Test(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5dLeftJoin() + { + var q = + from c in db.Customers + join o in db.Orders on + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } equals + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } into orders + from o in orders.DefaultIfEmpty() + select new { c.ContactName, o.OrderId }; + + Assert.Throws(() => ObjectDumper.Write(q)); + } + [Category("JOIN")] [Test(Description = "This sample joins two tables and projects results from the first table.")] public void DLinqJoin5e() @@ -1595,6 +1650,26 @@ join o in db.Orders on c.CustomerId equals o.Customer.CustomerId } } + [Category("JOIN")] + [Test(Description = "This sample joins two tables and projects results from the first table.")] + public void DLinqJoin5eLeftJoin() + { + var q = + from c in db.Customers + join o in db.Orders on c.CustomerId equals o.Customer.CustomerId into orders + from o in orders.DefaultIfEmpty() + where c.ContactName != null + select o; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] public void DLinqJoin5f() @@ -1616,6 +1691,28 @@ join c in db.Customers on } } + [Category("JOIN")] + [TestCase(Description = "This sample explictly joins two tables with a composite key and projects results from both tables.")] + public void DLinqJoin5fLeftJoin() + { + var q = + from o in db.Orders + join c in db.Customers on + new { o.Customer.CustomerId, HasContractTitle = o.Customer.ContactTitle != null } equals + new { c.CustomerId, HasContractTitle = c.ContactTitle != null } into customers + from c in customers.DefaultIfEmpty() + select new { c.ContactName, o.OrderId }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + Assert.That(GetTotalOccurrences(sql, "inner join"), Is.EqualTo(0)); + } + } + [Category("JOIN")] [Test(Description = "This sample explictly joins three tables and projects results from each of them.")] public void DLinqJoin6() @@ -1643,7 +1740,6 @@ join e in db.Employees on c.Address.City equals e.Address.City into emps Description = "This sample shows how to get LEFT OUTER JOIN by using DefaultIfEmpty(). The DefaultIfEmpty() method returns null when there is no Order for the Employee." )] - [Ignore("TODO left outer join")] public void DLinqJoin7() { var q = @@ -1723,6 +1819,50 @@ from d in details } } + [Category("JOIN")] + [TestCase(true, Description = "This sample shows a group left join with a composite key.")] + [TestCase(false, Description = "This sample shows a group left join with a composite key.")] + public void DLinqJoin9LeftJoin(bool useCrossJoin) + { + if (useCrossJoin && !Dialect.SupportsCrossJoin) + { + Assert.Ignore("Dialect does not support cross join."); + } + + // The expected collection can be obtained from the below Linq to Objects query. + //var expected = + // (from o in db.Orders.ToList() + // from p in db.Products.ToList() + // join d in db.OrderLines.ToList() + // on new { o.OrderId, p.ProductId } equals new { d.Order.OrderId, d.Product.ProductId } + // into details + // from d in details.DefaultIfEmpty() + // where d != null && d.UnitPrice > 50 + // select new { o.OrderId, p.ProductId, d.UnitPrice }).ToList(); + + using (var substitute = SubstituteDialect()) + using (var sqlSpy = new SqlLogSpy()) + { + ClearQueryPlanCache(); + substitute.Value.SupportsCrossJoin.Returns(useCrossJoin); + + var actual = + (from o in db.Orders + from p in db.Products + join d in db.OrderLines + on new {o.OrderId, p.ProductId} equals new {d.Order.OrderId, d.Product.ProductId} + into details + from d in details.DefaultIfEmpty() + where d != null && d.UnitPrice > 50 + select new {o.OrderId, p.ProductId, d.UnitPrice}).ToList(); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(actual.Count, Is.EqualTo(163)); + Assert.That(sql, Does.Contain(useCrossJoin ? "cross join" : "inner join")); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(1)); + } + } + [Category("JOIN")] [Test(Description = "This sample shows a join which is then grouped")] public void DLinqJoin9b() @@ -1754,6 +1894,27 @@ join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId } } + [Category("JOIN")] + [Test(Description = "This sample shows how to join multiple tables using a left join.")] + public void DLinqJoin10aLeftJoin() + { + var q = + from e in db.Employees + join s in db.Employees on e.Superior.EmployeeId equals s.EmployeeId into sup + from s in sup.DefaultIfEmpty() + join s2 in db.Employees on s.Superior.EmployeeId equals s2.EmployeeId into sup2 + from s2 in sup2.DefaultIfEmpty() + select new { e.FirstName, SuperiorName = s.FirstName, Superior2Name = s2.FirstName }; + + using (var sqlSpy = new SqlLogSpy()) + { + ObjectDumper.Write(q); + + var sql = sqlSpy.GetWholeLog(); + Assert.That(GetTotalOccurrences(sql, "left outer join"), Is.EqualTo(2)); + } + } + [Category("WHERE")] [Test(Description = "This sample uses WHERE to filter for orders with shipping date equals to null.")] public void DLinq2B() diff --git a/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs b/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs new file mode 100644 index 00000000000..6955afd936b --- /dev/null +++ b/src/NHibernate/Linq/Clauses/NhOuterJoinClause.cs @@ -0,0 +1,46 @@ +using System; +using System.Linq.Expressions; +using Remotion.Linq; +using Remotion.Linq.Clauses; + +namespace NHibernate.Linq.Clauses +{ + /// + /// A wrapper for that is used to mark it as an outer join. + /// + public class NhOuterJoinClause : NhClauseBase, IBodyClause, IClause, IQuerySource + { + public NhOuterJoinClause(JoinClause joinClause) + { + JoinClause = joinClause; + } + + public JoinClause JoinClause { get; } + + public string ItemName => JoinClause.ItemName; + + public System.Type ItemType => JoinClause.ItemType; + + public void TransformExpressions(Func transformation) + { + JoinClause.TransformExpressions(transformation); + } + + public IBodyClause Clone(CloneContext cloneContext) + { + return new NhOuterJoinClause(JoinClause.Clone(cloneContext)); + } + + protected override void Accept(INhQueryModelVisitor visitor, QueryModel queryModel, int index) + { + if (visitor is INhQueryModelVisitorExtended queryModelVisitorExtended) + { + queryModelVisitorExtended.VisitNhOuterJoinClause(this, queryModel, index); + } + else + { + visitor.VisitJoinClause(JoinClause, queryModel, index); + } + } + } +} diff --git a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs index 78ee78c35eb..20f2331bf9c 100644 --- a/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs +++ b/src/NHibernate/Linq/GroupJoin/GroupJoinAggregateDetectionVisitor.cs @@ -79,6 +79,12 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr } } } + // In order to detect a left join (e.g. from a in A join b in B on a.Id equals b.Id into c from b in c.DefaultIfEmpty()) + // we have to visit the subquery in order to find the group join + else if (fromClause.FromExpression is SubQueryExpression subQuery) + { + VisitSubQuery(subQuery); + } return base.VisitQuerySourceReference(expression); } diff --git a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs index ae1f8f79273..1ccaebab3bf 100644 --- a/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs +++ b/src/NHibernate/Linq/GroupJoin/NonAggregatingGroupJoinRewriter.cs @@ -2,10 +2,12 @@ using System.Collections.Generic; using System.Linq; using System.Linq.Expressions; +using NHibernate.Linq.Clauses; using NHibernate.Linq.GroupJoin; using Remotion.Linq; using Remotion.Linq.Clauses; using Remotion.Linq.Clauses.Expressions; +using Remotion.Linq.Clauses.ResultOperators; using Remotion.Linq.Parsing; namespace NHibernate.Linq.Visitors @@ -14,7 +16,6 @@ public class NonAggregatingGroupJoinRewriter { private readonly QueryModel _model; private readonly IEnumerable _groupJoinClauses; - private QuerySourceUsageLocator _locator; private NonAggregatingGroupJoinRewriter(QueryModel model, IEnumerable groupJoinClauses) { @@ -67,19 +68,19 @@ private void ReWrite() // This is used to repesent an outer join, and again the "from" is removing the hierarchy. So // simply change the group join to an outer join - _locator = new QuerySourceUsageLocator(nonAggregatingJoin); + var locator = new QuerySourceUsageLocator(nonAggregatingJoin); foreach (var bodyClause in _model.BodyClauses) { - _locator.Search(bodyClause); + locator.Search(bodyClause); } - if (IsHierarchicalJoin(nonAggregatingJoin)) + if (IsHierarchicalJoin(nonAggregatingJoin, locator)) { } - else if (IsFlattenedJoin(nonAggregatingJoin)) + else if (IsFlattenedJoin(nonAggregatingJoin, locator)) { - ProcessFlattenedJoin(nonAggregatingJoin); + ProcessFlattenedJoin(nonAggregatingJoin, locator); } else if (IsOuterJoin(nonAggregatingJoin)) { @@ -92,19 +93,30 @@ private void ReWrite() } } - private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin) + private void ProcessFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) { + var nhJoin = locator.LeftJoin + ? new NhOuterJoinClause(nonAggregatingJoin.JoinClause) + : (IQuerySource) nonAggregatingJoin.JoinClause; + // Need to: // 1. Remove the group join and replace it with a join // 2. Remove the corresponding "from" clause (the thing that was doing the flattening) - // 3. Rewrite the selector to reference the "join" rather than the "from" clause - SwapClause(nonAggregatingJoin, nonAggregatingJoin.JoinClause); + // 3. Rewrite the query model to reference the "join" rather than the "from" clause + SwapClause(nonAggregatingJoin, (IBodyClause) nhJoin); - // TODO - don't like use of _locator here; would rather we got this passed in. Ditto on next line (esp. the cast) - _model.BodyClauses.Remove(_locator.Clauses[0]); + _model.BodyClauses.Remove((IBodyClause) locator.Usages[0]); + + SwapQuerySourceVisitor querySourceSwapper; + if (locator.LeftJoin) + { + // As we wrapped the join clause we have to update all references to the wrapped clause + querySourceSwapper = new SwapQuerySourceVisitor(nonAggregatingJoin.JoinClause, nhJoin); + _model.TransformExpressions(querySourceSwapper.Swap); + } - var querySourceSwapper = new SwapQuerySourceVisitor((IQuerySource)_locator.Clauses[0], nonAggregatingJoin.JoinClause); - _model.SelectClause.TransformExpressions(querySourceSwapper.Swap); + querySourceSwapper = new SwapQuerySourceVisitor(locator.Usages[0], nhJoin); + _model.TransformExpressions(querySourceSwapper.Swap); } // TODO - store the indexes of the join clauses when we find them, then can remove this loop @@ -125,11 +137,11 @@ private bool IsOuterJoin(GroupJoinClause nonAggregatingJoin) return false; } - private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin) + private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) { - if (_locator.Clauses.Count == 1) + if (locator.Usages.Count == 1) { - var from = _locator.Clauses[0] as AdditionalFromClause; + var from = locator.Usages[0] as AdditionalFromClause; if (from != null) { @@ -140,9 +152,9 @@ private bool IsFlattenedJoin(GroupJoinClause nonAggregatingJoin) return false; } - private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin) + private bool IsHierarchicalJoin(GroupJoinClause nonAggregatingJoin, QuerySourceUsageLocator locator) { - return _locator.Clauses.Count == 0; + return locator.Usages.Count == 0; } // TODO - rename this and share with the AggregatingGroupJoinRewriter @@ -156,27 +168,30 @@ internal class QuerySourceUsageLocator : RelinqExpressionVisitor { private readonly IQuerySource _querySource; private bool _references; - private readonly List _clauses = new List(); public QuerySourceUsageLocator(IQuerySource querySource) { _querySource = querySource; } - public IList Clauses - { - get { return _clauses.AsReadOnly(); } - } + internal bool LeftJoin { get; private set; } + + public IList Usages { get; } = new List(); public void Search(IBodyClause clause) { + if (!(clause is IQuerySource querySource)) + { + return; + } + _references = false; clause.TransformExpressions(ExpressionSearcher); if (_references) { - _clauses.Add(clause); + Usages.Add(querySource); } } @@ -195,5 +210,22 @@ protected override Expression VisitQuerySourceReference(QuerySourceReferenceExpr return expression; } + + protected override Expression VisitSubQuery(SubQueryExpression expression) + { + if (IsLeftJoin(expression.QueryModel)) + { + LeftJoin = true; + expression.QueryModel.MainFromClause.TransformExpressions(ExpressionSearcher); + } + + return expression; + } + + private static bool IsLeftJoin(QueryModel subQueryModel) + { + return subQueryModel.ResultOperators.Count == 1 && + subQueryModel.ResultOperators[0] is DefaultIfEmptyResultOperator; + } } } diff --git a/src/NHibernate/Linq/INhQueryModelVisitor.cs b/src/NHibernate/Linq/INhQueryModelVisitor.cs index 3d4ce54fe34..42899774213 100644 --- a/src/NHibernate/Linq/INhQueryModelVisitor.cs +++ b/src/NHibernate/Linq/INhQueryModelVisitor.cs @@ -11,4 +11,10 @@ public interface INhQueryModelVisitor: IQueryModelVisitor void VisitNhHavingClause(NhHavingClause nhWhereClause, QueryModel queryModel, int index); } + + // TODO 6.0: Move members into INhQueryModelVisitor + internal interface INhQueryModelVisitorExtended : INhQueryModelVisitor + { + void VisitNhOuterJoinClause(NhOuterJoinClause nhOuterJoinClause, QueryModel queryModel, int index); + } } diff --git a/src/NHibernate/Linq/LinqExtensionMethods.cs b/src/NHibernate/Linq/LinqExtensionMethods.cs index dc085c3bde5..12b84e6e18c 100644 --- a/src/NHibernate/Linq/LinqExtensionMethods.cs +++ b/src/NHibernate/Linq/LinqExtensionMethods.cs @@ -10,6 +10,7 @@ using System.Threading; using System.Threading.Tasks; using NHibernate.Engine; +using static NHibernate.Util.ReflectionCache.QueryableMethods; namespace NHibernate.Linq { @@ -2466,6 +2467,75 @@ public static IFutureValue ToFutureValue(this IQuerya #pragma warning restore CS0618 // Type or member is obsolete } + #region LeftJoin + + // Code based on: https://stackoverflow.com/a/18782867 + /// + /// Correlates the elements of two sequences based on matching keys. The default equality comparer is used to compare keys. + /// + /// The first sequence to join. + /// The sequence to join to the first sequence. + /// A dynamic function to extract the join key from each element of the first sequence. + /// A dynamic function to extract the join key from each element of the second sequence. + /// A dynamic function to create a result element from two matching elements. + /// An obtained by performing a left join on two sequences. + public static IQueryable LeftJoin( + this IQueryable outer, + IQueryable inner, + Expression> outerKeySelector, + Expression> innerKeySelector, + Expression> resultSelector) + { + outer = outer ?? throw new ArgumentNullException(nameof(outer)); + inner = inner ?? throw new ArgumentNullException(nameof(inner)); + outerKeySelector = outerKeySelector ?? throw new ArgumentNullException(nameof(outerKeySelector)); + innerKeySelector = innerKeySelector ?? throw new ArgumentNullException(nameof(innerKeySelector)); + resultSelector = resultSelector ?? throw new ArgumentNullException(nameof(resultSelector)); + + Expression, LeftJoinIntermediate>> groupJoinResultSelector = + (oneOuter, manyInners) => new LeftJoinIntermediate + { + OneOuter = oneOuter, + ManyInners = manyInners + }; + var groupJoin = GroupJoinDefinition.MakeGenericMethod( + typeof(TOuter), + typeof(TInner), + typeof(TKey), + typeof(LeftJoinIntermediate)); + var selectMany = SelectManyDefinition.MakeGenericMethod( + typeof(LeftJoinIntermediate), + typeof(TInner), + typeof(TResult)); + var exprGroupJoin = Expression.Call( + groupJoin, + outer.Expression, + inner.Expression, + outerKeySelector, + innerKeySelector, + groupJoinResultSelector); + var selectManyCollectionSelector = (Expression, IEnumerable>>) + (t => t.ManyInners.DefaultIfEmpty()); + var outerParameter = resultSelector.Parameters[0]; + var paramNew = Expression.Parameter(typeof(LeftJoinIntermediate)); + var outerProperty = Expression.Property(paramNew, nameof(LeftJoinIntermediate.OneOuter)); + var selectManyResultSelector = Expression.Lambda( + ReplacingExpressionVisitor.Replace(outerParameter, outerProperty, resultSelector.Body), + paramNew, + resultSelector.Parameters[1]); + + return outer.Provider.CreateQuery( + Expression.Call(selectMany, exprGroupJoin, selectManyCollectionSelector, selectManyResultSelector)); + } + + private class LeftJoinIntermediate + { + public TOuter OneOuter { get; set; } + public IEnumerable ManyInners { get; set; } + } + + #endregion + /// /// Allows to set NHibernate query options. /// diff --git a/src/NHibernate/Linq/QuerySourceNamer.cs b/src/NHibernate/Linq/QuerySourceNamer.cs index 1fa9963a7d0..a989b83c113 100644 --- a/src/NHibernate/Linq/QuerySourceNamer.cs +++ b/src/NHibernate/Linq/QuerySourceNamer.cs @@ -22,6 +22,14 @@ public void Add(IQuerySource querySource) _map.Add(querySource, CreateUniqueName(querySource.ItemName)); } + internal void Add(IQuerySource querySource, string name) + { + if (_map.ContainsKey(querySource)) + return; + + _map.Add(querySource, name); + } + public string GetName(IQuerySource querySource) { string result; diff --git a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs index d022e1ffc88..e1a9a9eecc0 100644 --- a/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs +++ b/src/NHibernate/Linq/ReWriters/AddJoinsReWriter.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; using NHibernate.Engine; @@ -15,12 +16,13 @@ internal interface IIsEntityDecider bool IsIdentifier(System.Type type, string propertyName); } - public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider + public class AddJoinsReWriter : NhQueryModelVisitorBase, IIsEntityDecider, INhQueryModelVisitorExtended { private readonly ISessionFactoryImplementor _sessionFactory; private readonly MemberExpressionJoinDetector _memberExpressionJoinDetector; private readonly WhereJoinDetector _whereJoinDetector; private JoinClause _currentJoin; + private bool? _innerJoin; private AddJoinsReWriter(ISessionFactoryImplementor sessionFactory, QueryModel queryModel) { @@ -61,7 +63,17 @@ public override void VisitNhHavingClause(NhHavingClause havingClause, QueryModel _whereJoinDetector.Transform(havingClause); } + public void VisitNhOuterJoinClause(NhOuterJoinClause nhOuterJoinClause, QueryModel queryModel, int index) + { + VisitJoinClause(nhOuterJoinClause.JoinClause, false); + } + public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + VisitJoinClause(joinClause, true); + } + + private void VisitJoinClause(JoinClause joinClause, bool innerJoin) { joinClause.InnerSequence = _whereJoinDetector.Transform(joinClause.InnerSequence); @@ -73,8 +85,10 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode // support them). // Link newly created joins with the current join clause in order to later detect which join type to use. _currentJoin = joinClause; + _innerJoin = innerJoin; joinClause.InnerKeySelector = _whereJoinDetector.Transform(joinClause.InnerKeySelector); _currentJoin = null; + _innerJoin = null; } public bool IsEntity(System.Type type) @@ -91,7 +105,7 @@ public bool IsIdentifier(System.Type type, string propertyName) private void AddJoin(QueryModel queryModel, NhJoinClause joinClause) { joinClause.ParentJoinClause = _currentJoin; - if (_currentJoin != null) + if (_innerJoin == true) { // Match the parent join type joinClause.MakeInner(); diff --git a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs index a487a1281ce..040e9b38932 100644 --- a/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs +++ b/src/NHibernate/Linq/Visitors/QueryModelVisitor.cs @@ -23,7 +23,7 @@ namespace NHibernate.Linq.Visitors { - public class QueryModelVisitor : NhQueryModelVisitorBase, INhQueryModelVisitor + public class QueryModelVisitor : NhQueryModelVisitorBase, INhQueryModelVisitor, INhQueryModelVisitorExtended { private readonly QueryMode _queryMode; @@ -512,6 +512,16 @@ public override void VisitOrderByClause(OrderByClause orderByClause, QueryModel } public override void VisitJoinClause(JoinClause joinClause, QueryModel queryModel, int index) + { + AddJoin(joinClause, queryModel, true); + } + + public void VisitNhOuterJoinClause(NhOuterJoinClause outerJoinClause, QueryModel queryModel, int index) + { + AddJoin(outerJoinClause.JoinClause, queryModel, false); + } + + private void AddJoin(JoinClause joinClause, QueryModel queryModel, bool innerJoin) { var equalityVisitor = new EqualityHqlGenerator(VisitorParameters); var withClause = equalityVisitor.Visit(joinClause.InnerKeySelector, joinClause.OuterKeySelector); @@ -522,12 +532,19 @@ public override void VisitJoinClause(JoinClause joinClause, QueryModel queryMode // join and add the condition in the where statement. if (queryModel.BodyClauses.OfType().Any(o => o.ParentJoinClause == joinClause)) { + if (!innerJoin) + { + throw new NotSupportedException("Left joins that have association properties in the inner key selector are not supported."); + } + _hqlTree.AddWhereClause(withClause); join = CreateCrossJoin(joinExpression, alias); } else { - join = _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias); + join = innerJoin + ? _hqlTree.TreeBuilder.InnerJoin(joinExpression.AsExpression(), alias) + : (HqlTreeNode) _hqlTree.TreeBuilder.LeftJoin(joinExpression.AsExpression(), alias); join.AddChild(_hqlTree.TreeBuilder.With(withClause)); } diff --git a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs index 475a13050c8..30bc1cde345 100644 --- a/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs +++ b/src/NHibernate/Linq/Visitors/QuerySourceIdentifier.cs @@ -14,7 +14,7 @@ namespace NHibernate.Linq.Visitors /// the HQL expression tree) means a query source may be referenced by a QuerySourceReference /// before it has been identified - and named. /// - public class QuerySourceIdentifier : NhQueryModelVisitorBase + public class QuerySourceIdentifier : NhQueryModelVisitorBase, INhQueryModelVisitorExtended { private readonly QuerySourceNamer _namer; @@ -58,6 +58,12 @@ public override void VisitNhJoinClause(NhJoinClause joinClause, QueryModel query _namer.Add(joinClause); } + public void VisitNhOuterJoinClause(NhOuterJoinClause outerJoinClause, QueryModel queryModel, int index) + { + _namer.Add(outerJoinClause); + _namer.Add(outerJoinClause.JoinClause, _namer.GetName(outerJoinClause)); + } + public override void VisitResultOperator(ResultOperatorBase resultOperator, QueryModel queryModel, int index) { var groupBy = resultOperator as GroupResultOperator; @@ -73,4 +79,4 @@ public override void VisitSelectClause(SelectClause selectClause, QueryModel que public QuerySourceNamer Namer { get { return _namer; } } } -} \ No newline at end of file +} diff --git a/src/NHibernate/Util/ReflectHelper.cs b/src/NHibernate/Util/ReflectHelper.cs index 1174a2e6ade..06352101745 100644 --- a/src/NHibernate/Util/ReflectHelper.cs +++ b/src/NHibernate/Util/ReflectHelper.cs @@ -180,6 +180,19 @@ internal static MethodInfo FastGetMethodDefinition(Syst return method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; } + /// Get a from a method group + /// A method group + /// A dummy parameter + /// A dummy parameter + /// A dummy parameter + /// A dummy parameter + /// A dummy parameter + internal static MethodInfo FastGetMethodDefinition(System.Func func, T1 a1, T2 a2, T3 a3, T4 a4, T5 a5) + { + var method = func.Method; + return method.IsGenericMethod ? method.GetGenericMethodDefinition() : method; + } + /// /// Get the for a public overload of a given method if the method does not match /// given parameter types, otherwise directly yield the given method. diff --git a/src/NHibernate/Util/ReflectionCache.cs b/src/NHibernate/Util/ReflectionCache.cs index 8e791fc2533..c40a395f98d 100644 --- a/src/NHibernate/Util/ReflectionCache.cs +++ b/src/NHibernate/Util/ReflectionCache.cs @@ -68,6 +68,21 @@ internal static class QueryableMethods { internal static readonly MethodInfo SelectDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Select, default(IQueryable), default(Expression>)); + internal static readonly MethodInfo SelectManyDefinition = + ReflectHelper.FastGetMethodDefinition( + Queryable.SelectMany, + default(IQueryable), + default(Expression>>), + default(Expression>)); + + internal static readonly MethodInfo GroupJoinDefinition = + ReflectHelper.FastGetMethodDefinition( + Queryable.GroupJoin, + default(IQueryable), + default(IEnumerable), + default(Expression>), + default(Expression>), + default(Expression, object>>)); internal static readonly MethodInfo CountDefinition = ReflectHelper.FastGetMethodDefinition(Queryable.Count, default(IQueryable));