Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit e7e8e54

Browse files
committed
Short-circuiting versions of Enumerable methods.
Fixes #2349 Last<TSource>(this IEnumerable<TSource>, Func<TSource, bool>) LastOrDefault<TSource>(this IEnumerable<TSource>, Func<TSource, bool>) Single<TSource>(this IEnumerable<TSource>, Func<TSource, bool>) SingleOrDefault<TSource>(this IEnumerable<TSource>, Func<TSource, bool>) Min(this IEnumerable<float>) Min(this IEnumerable<float?>) Min(this IEnumerable<double>) Min(this IEnumerable<double?>) Last and LastOrDefault only short-circuit if the source is an IList<T>. Last and LastOrDefault add a check for if the source is an IList<T> comparable to that take by the form that don't take a predicate. Apart from that no tests are added: While we could check for e.g. int.MaxValue being seen by Max(this IEnumerable<int>), and so on, that would add a test for every element, and so penalise sequences that did not contain it. Min/MinOrDefault on double and float already had a similar improvement made in 5eb063a, but with a backwards-compatibility loop added. Here that loop is simply removed. Included are tests for both the state before and after the changes.
1 parent 0a6a0d1 commit e7e8e54

File tree

2 files changed

+71
-158
lines changed

2 files changed

+71
-158
lines changed

src/System.Linq/src/System/Linq/Enumerable.cs

Lines changed: 71 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,17 +1242,34 @@ public static TSource Last<TSource>(this IEnumerable<TSource> source, Func<TSour
12421242
{
12431243
if (source == null) throw Error.ArgumentNull("source");
12441244
if (predicate == null) throw Error.ArgumentNull("predicate");
1245-
TSource result = default(TSource);
1246-
bool found = false;
1247-
foreach (TSource element in source)
1245+
IList<TSource> list = source as IList<TSource>;
1246+
if (list != null)
12481247
{
1249-
if (predicate(element))
1248+
for (int i = list.Count - 1; i >= 0; --i)
1249+
{
1250+
TSource result = list[i];
1251+
if (predicate(result)) return result;
1252+
}
1253+
}
1254+
else
1255+
{
1256+
using (IEnumerator<TSource> e = source.GetEnumerator())
12501257
{
1251-
result = element;
1252-
found = true;
1258+
while (e.MoveNext())
1259+
{
1260+
TSource result = e.Current;
1261+
if (predicate(result))
1262+
{
1263+
while (e.MoveNext())
1264+
{
1265+
TSource element = e.Current;
1266+
if (predicate(element)) result = element;
1267+
}
1268+
return result;
1269+
}
1270+
}
12531271
}
12541272
}
1255-
if (found) return result;
12561273
throw Error.NoMatch();
12571274
}
12581275

@@ -1287,15 +1304,28 @@ public static TSource LastOrDefault<TSource>(this IEnumerable<TSource> source, F
12871304
{
12881305
if (source == null) throw Error.ArgumentNull("source");
12891306
if (predicate == null) throw Error.ArgumentNull("predicate");
1290-
TSource result = default(TSource);
1291-
foreach (TSource element in source)
1307+
IList<TSource> list = source as IList<TSource>;
1308+
if (list != null)
12921309
{
1293-
if (predicate(element))
1310+
for (int i = list.Count - 1; i >= 0; --i)
12941311
{
1295-
result = element;
1312+
TSource element = list[i];
1313+
if (predicate(element)) return element;
12961314
}
1315+
return default(TSource);
1316+
}
1317+
else
1318+
{
1319+
TSource result = default(TSource);
1320+
foreach (TSource element in source)
1321+
{
1322+
if (predicate(element))
1323+
{
1324+
result = element;
1325+
}
1326+
}
1327+
return result;
12971328
}
1298-
return result;
12991329
}
13001330

13011331
public static TSource Single<TSource>(this IEnumerable<TSource> source)
@@ -1326,22 +1356,22 @@ public static TSource Single<TSource>(this IEnumerable<TSource> source, Func<TSo
13261356
{
13271357
if (source == null) throw Error.ArgumentNull("source");
13281358
if (predicate == null) throw Error.ArgumentNull("predicate");
1329-
TSource result = default(TSource);
1330-
long count = 0;
1331-
foreach (TSource element in source)
1359+
using (IEnumerator<TSource> e = source.GetEnumerator())
13321360
{
1333-
if (predicate(element))
1361+
while (e.MoveNext())
13341362
{
1335-
result = element;
1336-
checked { count++; }
1363+
TSource result = e.Current;
1364+
if (predicate(result))
1365+
{
1366+
while (e.MoveNext())
1367+
{
1368+
if (predicate(e.Current)) throw Error.MoreThanOneMatch();
1369+
}
1370+
return result;
1371+
}
13371372
}
13381373
}
1339-
switch (count)
1340-
{
1341-
case 0: throw Error.NoMatch();
1342-
case 1: return result;
1343-
}
1344-
throw Error.MoreThanOneMatch();
1374+
throw Error.NoMatch();
13451375
}
13461376

13471377
public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source)
@@ -1372,22 +1402,22 @@ public static TSource SingleOrDefault<TSource>(this IEnumerable<TSource> source,
13721402
{
13731403
if (source == null) throw Error.ArgumentNull("source");
13741404
if (predicate == null) throw Error.ArgumentNull("predicate");
1375-
TSource result = default(TSource);
1376-
long count = 0;
1377-
foreach (TSource element in source)
1405+
using (IEnumerator<TSource> e = source.GetEnumerator())
13781406
{
1379-
if (predicate(element))
1407+
while (e.MoveNext())
13801408
{
1381-
result = element;
1382-
checked { count++; }
1409+
TSource result = e.Current;
1410+
if (predicate(result))
1411+
{
1412+
while (e.MoveNext())
1413+
{
1414+
if (predicate(e.Current)) throw Error.MoreThanOneMatch();
1415+
}
1416+
return result;
1417+
}
13831418
}
13841419
}
1385-
switch (count)
1386-
{
1387-
case 0: return default(TSource);
1388-
case 1: return result;
1389-
}
1390-
throw Error.MoreThanOneMatch();
1420+
return default(TSource);
13911421
}
13921422

13931423
public static TSource ElementAt<TSource>(this IEnumerable<TSource> source, int index)
@@ -1873,13 +1903,8 @@ public static float Min(this IEnumerable<float> source)
18731903
// ordering where NaN is smaller than every value, including
18741904
// negative infinity.
18751905
// Not testing for NaN therefore isn't an option, but since we
1876-
// can't find a smaller value, we can short-circuit. But we consume
1877-
// the rest for backwards-compatibility reasons.
1878-
else if (float.IsNaN(x))
1879-
{
1880-
while (e.MoveNext()) {}
1881-
return x;
1882-
}
1906+
// can't find a smaller value, we can short-circuit.
1907+
else if (float.IsNaN(x)) return x;
18831908
}
18841909
}
18851910
return value;
@@ -1908,11 +1933,7 @@ public static float Min(this IEnumerable<float> source)
19081933
valueVal = x;
19091934
value = cur;
19101935
}
1911-
else if (float.IsNaN(x))
1912-
{
1913-
while (e.MoveNext()) { }
1914-
return cur;
1915-
}
1936+
else if (float.IsNaN(x)) return cur;
19161937
}
19171938
}
19181939
}
@@ -1931,11 +1952,7 @@ public static double Min(this IEnumerable<double> source)
19311952
{
19321953
double x = e.Current;
19331954
if (x < value) value = x;
1934-
else if (double.IsNaN(x))
1935-
{
1936-
while (e.MoveNext()) {}
1937-
return x;
1938-
}
1955+
else if (double.IsNaN(x)) return x;
19391956
}
19401957
}
19411958
return value;
@@ -1964,11 +1981,7 @@ public static double Min(this IEnumerable<double> source)
19641981
valueVal = x;
19651982
value = cur;
19661983
}
1967-
else if (double.IsNaN(x))
1968-
{
1969-
while (e.MoveNext()) {}
1970-
return cur;
1971-
}
1984+
else if (double.IsNaN(x)) return cur;
19721985
}
19731986
}
19741987
}

src/System.Linq/tests/ShortCircuitingTests.cs

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -51,98 +51,6 @@ public Func<T, TResult> Func
5151
}
5252

5353
[Fact]
54-
public void ListLastChecksAll()
55-
{
56-
var source = Enumerable.Range(0, 10).ToList();
57-
var pred = new CountedFunction<int, bool>(i => i < 7);
58-
Assert.Equal(6, source.Last(pred.Func));
59-
Assert.Equal(10, pred.Calls);
60-
}
61-
62-
[Fact]
63-
public void MinDoubleChecksAll()
64-
{
65-
var tracker = new TrackingEnumerable(10);
66-
var source = tracker.Select(i => i == 5 ? double.NaN : (double)i);
67-
Assert.True(double.IsNaN(source.Min()));
68-
Assert.Equal(10, tracker.Moves);
69-
}
70-
71-
[Fact]
72-
void MinNullableDoubleChecksAll()
73-
{
74-
var tracker = new TrackingEnumerable(10);
75-
var source = tracker.Select(i => (double?)(i == 5 ? double.NaN : (double)i));
76-
Assert.True(double.IsNaN(source.Min().GetValueOrDefault()));
77-
Assert.Equal(10, tracker.Moves);
78-
}
79-
80-
[Fact]
81-
public void MinSingleChecksAll()
82-
{
83-
var tracker = new TrackingEnumerable(10);
84-
var source = tracker.Select(i => i == 5 ? float.NaN : (float)i);
85-
Assert.True(float.IsNaN(source.Min()));
86-
Assert.Equal(10, tracker.Moves);
87-
}
88-
89-
[Fact]
90-
void MinNullableSingleChecksAll()
91-
{
92-
var tracker = new TrackingEnumerable(10);
93-
var source = tracker.Select(i => (float?)(i == 5 ? float.NaN : (float)i));
94-
Assert.True(float.IsNaN(source.Min().GetValueOrDefault()));
95-
Assert.Equal(10, tracker.Moves);
96-
}
97-
98-
[Fact]
99-
void SingleWithPredicateChecksAll()
100-
{
101-
var tracker = new TrackingEnumerable(10);
102-
var pred = new CountedFunction<int, bool>(i => i > 2);
103-
Assert.Throws<InvalidOperationException>(() => tracker.Single(pred.Func));
104-
Assert.Equal(10, tracker.Moves);
105-
Assert.Equal(10, pred.Calls);
106-
}
107-
108-
[Fact]
109-
void SingleOrDefaultWithPredicateChecksAll()
110-
{
111-
var tracker = new TrackingEnumerable(10);
112-
var pred = new CountedFunction<int, bool>(i => i > 2);
113-
Assert.Throws<InvalidOperationException>(() => tracker.SingleOrDefault(pred.Func));
114-
Assert.Equal(10, tracker.Moves);
115-
Assert.Equal(10, pred.Calls);
116-
}
117-
118-
[Fact]
119-
void SingleWithPredicateDifferentToWhereFollowedBySingle()
120-
{
121-
var tracker0 = new TrackingEnumerable(10);
122-
var pred0 = new CountedFunction<int, bool>(i => i > 2);
123-
Assert.Throws<InvalidOperationException>(() => tracker0.Single(pred0.Func));
124-
var tracker1 = new TrackingEnumerable(10);
125-
var pred1 = new CountedFunction<int, bool>(i => i > 2);
126-
Assert.Throws<InvalidOperationException>(() => tracker1.Where(pred1.Func).Single());
127-
Assert.NotEqual(tracker0.Moves, tracker1.Moves);
128-
Assert.NotEqual(pred0.Calls, pred1.Calls);
129-
}
130-
131-
[Fact]
132-
void SingleOrDefaultWithPredicateDifferentToWhereFollowedBySingleOrDefault()
133-
{
134-
var tracker0 = new TrackingEnumerable(10);
135-
var pred0 = new CountedFunction<int, bool>(i => i > 2);
136-
Assert.Throws<InvalidOperationException>(() => tracker0.SingleOrDefault(pred0.Func));
137-
var tracker1 = new TrackingEnumerable(10);
138-
var pred1 = new CountedFunction<int, bool>(i => i > 2);
139-
Assert.Throws<InvalidOperationException>(() => tracker1.Where(pred1.Func).SingleOrDefault());
140-
Assert.NotEqual(tracker0.Moves, tracker1.Moves);
141-
Assert.NotEqual(pred0.Calls, pred1.Calls);
142-
}
143-
144-
[Fact]
145-
[ActiveIssue(2349)]
14654
public void ListLastDoesntCheckAll()
14755
{
14856
var source = Enumerable.Range(0, 10).ToList();
@@ -152,7 +60,6 @@ public void ListLastDoesntCheckAll()
15260
}
15361

15462
[Fact]
155-
[ActiveIssue(2349)]
15663
public void MinDoubleDoesntCheckAll()
15764
{
15865
var tracker = new TrackingEnumerable(10);
@@ -162,7 +69,6 @@ public void MinDoubleDoesntCheckAll()
16269
}
16370

16471
[Fact]
165-
[ActiveIssue(2349)]
16672
void MinNullableDoubleDoesntCheckAll()
16773
{
16874
var tracker = new TrackingEnumerable(10);
@@ -172,7 +78,6 @@ void MinNullableDoubleDoesntCheckAll()
17278
}
17379

17480
[Fact]
175-
[ActiveIssue(2349)]
17681
public void MinSingleDoesntCheckAll()
17782
{
17883
var tracker = new TrackingEnumerable(10);
@@ -182,7 +87,6 @@ public void MinSingleDoesntCheckAll()
18287
}
18388

18489
[Fact]
185-
[ActiveIssue(2349)]
18690
void MinNullableSingleDoesntCheckAll()
18791
{
18892
var tracker = new TrackingEnumerable(10);
@@ -192,7 +96,6 @@ void MinNullableSingleDoesntCheckAll()
19296
}
19397

19498
[Fact]
195-
[ActiveIssue(2349)]
19699
void SingleWithPredicateDoesntCheckAll()
197100
{
198101
var tracker = new TrackingEnumerable(10);
@@ -203,7 +106,6 @@ void SingleWithPredicateDoesntCheckAll()
203106
}
204107

205108
[Fact]
206-
[ActiveIssue(2349)]
207109
void SingleOrDefaultWithPredicateDoesntCheckAll()
208110
{
209111
var tracker = new TrackingEnumerable(10);
@@ -214,7 +116,6 @@ void SingleOrDefaultWithPredicateDoesntCheckAll()
214116
}
215117

216118
[Fact]
217-
[ActiveIssue(2349)]
218119
void SingleWithPredicateWorksLikeWhereFollowedBySingle()
219120
{
220121
var tracker0 = new TrackingEnumerable(10);
@@ -228,7 +129,6 @@ void SingleWithPredicateWorksLikeWhereFollowedBySingle()
228129
}
229130

230131
[Fact]
231-
[ActiveIssue(2349)]
232132
void SingleOrDefaultWithPredicateWorksLikeWhereFollowedBySingleOrDefault()
233133
{
234134
var tracker0 = new TrackingEnumerable(10);

0 commit comments

Comments
 (0)