|
1 | 1 | using System;
|
2 | 2 | using System.Collections.Generic;
|
| 3 | +using System.Linq; |
3 | 4 | using System.Linq.Expressions;
|
4 | 5 | using FluentNHibernate.Testing.Values;
|
5 | 6 | using FluentNHibernate.Utils;
|
@@ -81,6 +82,19 @@ public static PersistenceSpecification<T> CheckReference<T>(this PersistenceSpec
|
81 | 82 | return spec.RegisterCheckedProperty(new ReferenceProperty<T, object>(property, propertyValue), propertyComparer);
|
82 | 83 | }
|
83 | 84 |
|
| 85 | + public static PersistenceSpecification<T> CheckReference<T, TReference>(this PersistenceSpecification<T> spec, |
| 86 | + Expression<Func<T, object>> expression, |
| 87 | + TReference propertyValue, |
| 88 | + params Func<TReference, object>[] propertiesToCompare) |
| 89 | + { |
| 90 | + // Because of the params keyword, the compiler will select this overload |
| 91 | + // instead of the one above, even when no funcs are supplied in the method call. |
| 92 | + if (propertiesToCompare == null || propertiesToCompare.Length == 0) |
| 93 | + return spec.CheckReference(expression, propertyValue, (IEqualityComparer)null); |
| 94 | + |
| 95 | + return spec.CheckReference(expression, propertyValue, new FuncEqualityComparer<TReference>(propertiesToCompare)); |
| 96 | + } |
| 97 | + |
84 | 98 | public static PersistenceSpecification<T> CheckReference<T, TProperty>(this PersistenceSpecification<T> spec,
|
85 | 99 | Expression<Func<T, TProperty>> expression,
|
86 | 100 | TProperty propertyValue,
|
@@ -254,5 +268,25 @@ public static PersistenceSpecification<T> CheckEnumerable<T, TItem>(this Persist
|
254 | 268 | {
|
255 | 269 | return spec.CheckList(expression, itemsToAdd, addAction);
|
256 | 270 | }
|
| 271 | + |
| 272 | + private class FuncEqualityComparer<T> : EqualityComparer<T> |
| 273 | + { |
| 274 | + readonly IEnumerable<Func<T, object>> comparisons; |
| 275 | + |
| 276 | + public FuncEqualityComparer(IEnumerable<Func<T, object>> comparisons) |
| 277 | + { |
| 278 | + this.comparisons = comparisons; |
| 279 | + } |
| 280 | + |
| 281 | + public override bool Equals(T x, T y) |
| 282 | + { |
| 283 | + return comparisons.All(func => object.Equals(func(x), func(y))); |
| 284 | + } |
| 285 | + |
| 286 | + public override int GetHashCode(T obj) |
| 287 | + { |
| 288 | + throw new NotSupportedException(); |
| 289 | + } |
| 290 | + } |
257 | 291 | }
|
258 | 292 | }
|
0 commit comments