diff --git a/src/libraries/System.Linq/src/System.Linq.csproj b/src/libraries/System.Linq/src/System.Linq.csproj index 2f93e584d18658..f032db1a303ee1 100644 --- a/src/libraries/System.Linq/src/System.Linq.csproj +++ b/src/libraries/System.Linq/src/System.Linq.csproj @@ -62,6 +62,7 @@ + @@ -70,12 +71,10 @@ - - @@ -83,6 +82,7 @@ + diff --git a/src/libraries/System.Linq/src/System/Linq/Count.cs b/src/libraries/System.Linq/src/System/Linq/Count.cs index 6a12a11cbe163d..cd175cf6b9b3cc 100644 --- a/src/libraries/System.Linq/src/System/Linq/Count.cs +++ b/src/libraries/System.Linq/src/System/Linq/Count.cs @@ -20,7 +20,7 @@ public static int Count(this IEnumerable source) return collectionoft.Count; } - if (!IsSizeOptimized && source is Iterator iterator) + if (source is Iterator iterator) { return iterator.GetCount(onlyIfCheap: false); } @@ -113,7 +113,7 @@ public static bool TryGetNonEnumeratedCount(this IEnumerable s return true; } - if (!IsSizeOptimized && source is Iterator iterator) + if (source is Iterator iterator) { int c = iterator.GetCount(onlyIfCheap: true); if (c >= 0) diff --git a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs index 26c69366fa9f3b..f4dc1f23a16c80 100644 --- a/src/libraries/System.Linq/src/System/Linq/ElementAt.cs +++ b/src/libraries/System.Linq/src/System/Linq/ElementAt.cs @@ -23,7 +23,7 @@ public static TSource ElementAt(this IEnumerable source, int i bool found; TSource? element = - !IsSizeOptimized && source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : + source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : TryGetElementAtNonIterator(source, index, out found); if (!found) @@ -121,7 +121,7 @@ public static TSource ElementAt(this IEnumerable source, Index } return - !IsSizeOptimized && source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : + source is Iterator iterator ? iterator.TryGetElementAt(index, out found) : TryGetElementAtNonIterator(source, index, out found); } diff --git a/src/libraries/System.Linq/src/System/Linq/Last.cs b/src/libraries/System.Linq/src/System/Linq/Last.cs index ca48475259d8e5..c38604e397592f 100644 --- a/src/libraries/System.Linq/src/System/Linq/Last.cs +++ b/src/libraries/System.Linq/src/System/Linq/Last.cs @@ -69,7 +69,7 @@ public static TSource LastOrDefault(this IEnumerable source, F } return - !IsSizeOptimized && source is Iterator iterator ? iterator.TryGetLast(out found) : + source is Iterator iterator ? iterator.TryGetLast(out found) : TryGetLastNonIterator(source, out found); } diff --git a/src/libraries/System.Linq/src/System/Linq/OfType.SpeedOpt.cs b/src/libraries/System.Linq/src/System/Linq/OfType.SpeedOpt.cs index 4420801ee4abcf..f2649e706c8a0e 100644 --- a/src/libraries/System.Linq/src/System/Linq/OfType.SpeedOpt.cs +++ b/src/libraries/System.Linq/src/System/Linq/OfType.SpeedOpt.cs @@ -167,7 +167,7 @@ public override IEnumerable Select(Func s // they're covariant. It's not worthwhile checking for List to use the ListWhereSelectIterator // because List<> is not covariant. Func isTResult = static o => o is TResult; - return objectSource is object[] array ? + return !IsSizeOptimized && objectSource is object[] array ? new ArrayWhereSelectIterator(array, isTResult, localSelector) : new IEnumerableWhereSelectIterator(objectSource, isTResult, localSelector); } @@ -177,7 +177,11 @@ public override IEnumerable Select(Func s public override bool Contains(TResult value) { - if (!typeof(TResult).IsValueType && // don't box TResult + // Avoid checking for IList when size-optimized because it keeps IList + // implementations which may otherwise be trimmed. Since List implements + // IList and List is popular, this could potentially be a lot of code. + if (!IsSizeOptimized && + !typeof(TResult).IsValueType && // don't box TResult _source is IList list) { return list.Contains(value); diff --git a/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs new file mode 100644 index 00000000000000..bf476355511abd --- /dev/null +++ b/src/libraries/System.Linq/src/System/Linq/Select.SizeOpt.cs @@ -0,0 +1,100 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections; +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + private sealed class SizeOptIListSelectIterator(IList _source, Func _selector) + : Iterator + { + private IEnumerator? _enumerator; + + public override int GetCount(bool onlyIfCheap) + { + // In case someone uses Count() to force evaluation of + // the selector, run it provided `onlyIfCheap` is false. + + if (onlyIfCheap) + { + return -1; + } + + int count = 0; + + foreach (TSource item in _source) + { + _selector(item); + checked + { + count++; + } + } + + return count; + } + + public override Iterator Skip(int count) + { + Debug.Assert(count > 0); + return new IListSkipTakeSelectIterator(_source, _selector, count, int.MaxValue); + } + + public override Iterator Take(int count) + { + Debug.Assert(count > 0); + return new IListSkipTakeSelectIterator(_source, _selector, 0, count - 1); + } + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _enumerator = _source.GetEnumerator(); + _state = 2; + goto case 2; + case 2: + Debug.Assert(_enumerator is not null); + if (_enumerator.MoveNext()) + { + _current = _selector(_enumerator.Current); + return true; + } + + Dispose(); + break; + } + + return false; + } + + public override TResult[] ToArray() + { + TResult[] array = new TResult[_source.Count]; + for (int i = 0; i < array.Length; i++) + { + array[i] = _selector(_source[i]); + } + return array; + } + + public override List ToList() + { + List list = new List(_source.Count); + for (int i = 0; i < list.Count; i++) + { + list.Add(_selector(_source[i])); + } + return list; + } + + private protected override Iterator Clone() + => new SizeOptIListSelectIterator(_source, _selector); + } + } +} diff --git a/src/libraries/System.Linq/src/System/Linq/Select.cs b/src/libraries/System.Linq/src/System/Linq/Select.cs index ac27c8dda22eeb..d5fcb864384440 100644 --- a/src/libraries/System.Linq/src/System/Linq/Select.cs +++ b/src/libraries/System.Linq/src/System/Linq/Select.cs @@ -32,7 +32,9 @@ public static IEnumerable Select( // don't need more code, just more data structures describing the new types). if (IsSizeOptimized && typeof(TResult).IsValueType) { - return new IEnumerableSelectIterator(iterator, selector); + return source is IList il + ? new SizeOptIListSelectIterator(il, selector) + : new IEnumerableSelectIterator(iterator, selector); } else { @@ -42,6 +44,11 @@ public static IEnumerable Select( if (source is IList ilist) { + if (IsSizeOptimized) + { + return new SizeOptIListSelectIterator(ilist, selector); + } + if (source is TSource[] array) { if (array.Length == 0) diff --git a/src/libraries/System.Linq/src/System/Linq/Skip.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Skip.SizeOpt.cs deleted file mode 100644 index 13e6642ee1fc02..00000000000000 --- a/src/libraries/System.Linq/src/System/Linq/Skip.SizeOpt.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; - -namespace System.Linq -{ - public static partial class Enumerable - { - private static IEnumerable SizeOptimizedSkipIterator(IEnumerable source, int count) - { - using IEnumerator e = source.GetEnumerator(); - while (count > 0 && e.MoveNext()) count--; - if (count <= 0) - { - while (e.MoveNext()) yield return e.Current; - } - } - } -} diff --git a/src/libraries/System.Linq/src/System/Linq/Skip.cs b/src/libraries/System.Linq/src/System/Linq/Skip.cs index ac8252a07c6d0c..4e76f1afba9794 100644 --- a/src/libraries/System.Linq/src/System/Linq/Skip.cs +++ b/src/libraries/System.Linq/src/System/Linq/Skip.cs @@ -30,12 +30,12 @@ public static IEnumerable Skip(this IEnumerable sourc count = 0; } - else if (!IsSizeOptimized && source is Iterator iterator) + else if (source is Iterator iterator) { return iterator.Skip(count) ?? Empty(); } - return IsSizeOptimized ? SizeOptimizedSkipIterator(source, count) : SpeedOptimizedSkipIterator(source, count); + return SpeedOptimizedSkipIterator(source, count); } public static IEnumerable SkipWhile(this IEnumerable source, Func predicate) diff --git a/src/libraries/System.Linq/src/System/Linq/Take.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Take.SizeOpt.cs deleted file mode 100644 index 6f2bd0d9b0fa6b..00000000000000 --- a/src/libraries/System.Linq/src/System/Linq/Take.SizeOpt.cs +++ /dev/null @@ -1,47 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Collections.Generic; -using System.Diagnostics; - -namespace System.Linq -{ - public static partial class Enumerable - { - private static IEnumerable SizeOptimizedTakeIterator(IEnumerable source, int count) - { - Debug.Assert(count > 0); - - foreach (TSource element in source) - { - yield return element; - if (--count == 0) break; - } - } - - private static IEnumerable SizeOptimizedTakeRangeIterator(IEnumerable source, int startIndex, int endIndex) - { - Debug.Assert(source is not null); - Debug.Assert(startIndex >= 0 && startIndex < endIndex); - - using IEnumerator e = source.GetEnumerator(); - - int index = 0; - while (index < startIndex && e.MoveNext()) - { - ++index; - } - - if (index < startIndex) - { - yield break; - } - - while (index < endIndex && e.MoveNext()) - { - yield return e.Current; - ++index; - } - } - } -} diff --git a/src/libraries/System.Linq/src/System/Linq/Take.cs b/src/libraries/System.Linq/src/System/Linq/Take.cs index 9df5fbc8a2bec8..8d6ad9acb3998d 100644 --- a/src/libraries/System.Linq/src/System/Linq/Take.cs +++ b/src/libraries/System.Linq/src/System/Linq/Take.cs @@ -20,7 +20,7 @@ public static IEnumerable Take(this IEnumerable sourc return []; } - return IsSizeOptimized ? SizeOptimizedTakeIterator(source, count) : SpeedOptimizedTakeIterator(source, count); + return SpeedOptimizedTakeIterator(source, count); } /// Returns a specified range of contiguous elements from a sequence. @@ -68,7 +68,7 @@ public static IEnumerable Take(this IEnumerable sourc return []; } - return IsSizeOptimized ? SizeOptimizedTakeRangeIterator(source, startIndex, endIndex) : SpeedOptimizedTakeRangeIterator(source, startIndex, endIndex); + return SpeedOptimizedTakeRangeIterator(source, startIndex, endIndex); } return TakeRangeFromEndIterator(source, isStartIndexFromEnd, startIndex, isEndIndexFromEnd, endIndex); @@ -94,9 +94,7 @@ private static IEnumerable TakeRangeFromEndIterator(IEnumerabl if (startIndex < endIndex) { - IEnumerable rangeIterator = IsSizeOptimized - ? SizeOptimizedTakeRangeIterator(source, startIndex, endIndex) - : SpeedOptimizedTakeRangeIterator(source, startIndex, endIndex); + IEnumerable rangeIterator = SpeedOptimizedTakeRangeIterator(source, startIndex, endIndex); foreach (TSource element in rangeIterator) { yield return element; diff --git a/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs b/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs new file mode 100644 index 00000000000000..8e49701a11efcb --- /dev/null +++ b/src/libraries/System.Linq/src/System/Linq/Where.SizeOpt.cs @@ -0,0 +1,114 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Collections.Generic; +using System.Diagnostics; + +namespace System.Linq +{ + public static partial class Enumerable + { + private sealed partial class SizeOptIListWhereIterator : Iterator + { + private readonly IList _source; + private readonly Func _predicate; + private IEnumerator? _enumerator; + + public SizeOptIListWhereIterator(IList source, Func predicate) + { + Debug.Assert(source is not null && source.Count > 0); + Debug.Assert(predicate is not null); + _source = source; + _predicate = predicate; + } + + private protected override Iterator Clone() => + new SizeOptIListWhereIterator(_source, _predicate); + + public override bool MoveNext() + { + switch (_state) + { + case 1: + _enumerator = _source.GetEnumerator(); + _state = 2; + goto case 2; + case 2: + while (_enumerator!.MoveNext()) + { + TSource item = _enumerator.Current; + if (_predicate(item)) + { + _current = item; + return true; + } + } + + Dispose(); + break; + } + + return false; + } + + public override IEnumerable Where(Func predicate) => + new SizeOptIListWhereIterator(_source, Utilities.CombinePredicates(_predicate, predicate)); + + public override TSource[] ToArray() + { + SegmentedArrayBuilder.ScratchBuffer scratch = default; + SegmentedArrayBuilder builder = new(scratch); + + foreach (TSource item in _source) + { + if (_predicate(item)) + { + builder.Add(item); + } + } + + TSource[] result = builder.ToArray(); + builder.Dispose(); + + return result; + } + + public override List ToList() + { + SegmentedArrayBuilder.ScratchBuffer scratch = default; + SegmentedArrayBuilder builder = new(scratch); + + foreach (TSource item in _source) + { + if (_predicate(item)) + { + builder.Add(item); + } + } + + List result = builder.ToList(); + builder.Dispose(); + + return result; + } + + public override int GetCount(bool onlyIfCheap) + { + if (onlyIfCheap) + { + return -1; + } + + int count = 0; + foreach (TSource item in _source) + { + if (_predicate(item)) + { + checked { count++; } + } + } + return count; + } + } + } +} diff --git a/src/libraries/System.Linq/src/System/Linq/Where.cs b/src/libraries/System.Linq/src/System/Linq/Where.cs index 4371af8299fb2e..c0028a1c004a22 100644 --- a/src/libraries/System.Linq/src/System/Linq/Where.cs +++ b/src/libraries/System.Linq/src/System/Linq/Where.cs @@ -26,6 +26,12 @@ public static IEnumerable Where(this IEnumerable sour return iterator.Where(predicate); } + // Only use IList when size-optimizing (no array or List specializations). + if (IsSizeOptimized && source is IList sourceList) + { + return new SizeOptIListWhereIterator(sourceList, predicate); + } + if (source is TSource[] array) { if (array.Length == 0) @@ -143,7 +149,7 @@ public override IEnumerable Select(Func sele new IEnumerableWhereSelectIterator(_source, _predicate, selector); public override IEnumerable Where(Func predicate) => - new IEnumerableWhereIterator(_source, CombinePredicates(_predicate, predicate)); + new IEnumerableWhereIterator(_source, Utilities.CombinePredicates(_predicate, predicate)); } /// diff --git a/src/libraries/System.Linq/tests/CountTests.cs b/src/libraries/System.Linq/tests/CountTests.cs index ddf96d4cf4b59a..27b18f77870218 100644 --- a/src/libraries/System.Linq/tests/CountTests.cs +++ b/src/libraries/System.Linq/tests/CountTests.cs @@ -3,10 +3,11 @@ using System.Collections.Generic; using Xunit; +using Xunit.Abstractions; namespace System.Linq.Tests { - public class CountTests : EnumerableTests + public class CountTests(ITestOutputHelper output) : EnumerableTests { [Fact] public void SameResultsRepeatCallsIntQuery() @@ -151,6 +152,7 @@ public void NonEnumeratedCount_SupportedEnumerables_ShouldReturnExpectedCount [MemberData(nameof(NonEnumeratedCount_UnsupportedEnumerables))] public void NonEnumeratedCount_UnsupportedEnumerables_ShouldReturnFalse(IEnumerable source) { + output.WriteLine(source.GetType().FullName); Assert.False(source.TryGetNonEnumeratedCount(out int actualCount)); Assert.Equal(0, actualCount); } @@ -180,15 +182,15 @@ public static IEnumerable NonEnumeratedCount_SupportedEnumerables() yield return WrapArgs(100, Enumerable.Range(1, 100)); yield return WrapArgs(80, Enumerable.Repeat(1, 80)); + yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse()); + yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x)); + yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10))); if (PlatformDetection.IsLinqSpeedOptimized) { yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1)); yield return WrapArgs(4, new int[] { 1, 2, 3, 4 }.Select(x => x + 1)); yield return WrapArgs(50, Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1)); - yield return WrapArgs(20, Enumerable.Range(1, 20).Reverse()); - yield return WrapArgs(20, Enumerable.Range(1, 20).OrderBy(x => -x)); - yield return WrapArgs(20, Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10))); } static object[] WrapArgs(int expectedCount, IEnumerable source) => [expectedCount, source]; @@ -204,11 +206,8 @@ public static IEnumerable NonEnumeratedCount_UnsupportedEnumerables() if (!PlatformDetection.IsLinqSpeedOptimized) { yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1)); - yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1)); + yield return WrapArgs(new int[] { 1, 2, 3, 4 }.Select(x => x + 1)); yield return WrapArgs(Enumerable.Range(1, 50).Select(x => x + 1).Select(x => x - 1)); - yield return WrapArgs(Enumerable.Range(1, 20).Reverse()); - yield return WrapArgs(Enumerable.Range(1, 20).OrderBy(x => -x)); - yield return WrapArgs(Enumerable.Range(1, 10).Concat(Enumerable.Range(11, 10))); } static object[] WrapArgs(IEnumerable source) => [source]; diff --git a/src/libraries/System.Linq/tests/OrderedSubsetting.cs b/src/libraries/System.Linq/tests/OrderedSubsetting.cs index 5804ac1d4229e7..7826fb338dbce5 100644 --- a/src/libraries/System.Linq/tests/OrderedSubsetting.cs +++ b/src/libraries/System.Linq/tests/OrderedSubsetting.cs @@ -224,7 +224,7 @@ public void TakeAndSkip() Assert.Equal(Enumerable.Range(10, 1), ordered.Take(11).Skip(10)); } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsLinqSpeedOptimized))] + [Fact] public void TakeAndSkip_DoesntIterateRangeUnlessNecessary() { Assert.Empty(Enumerable.Range(0, int.MaxValue).Take(int.MaxValue).OrderBy(i => i).Skip(int.MaxValue - 4).Skip(15)); diff --git a/src/libraries/System.Linq/tests/RangeTests.cs b/src/libraries/System.Linq/tests/RangeTests.cs index 476d4804fefeef..0bb6acca8bc685 100644 --- a/src/libraries/System.Linq/tests/RangeTests.cs +++ b/src/libraries/System.Linq/tests/RangeTests.cs @@ -236,7 +236,7 @@ public void LastOrDefault() Assert.Equal(int.MaxValue - 101, GetRange(-100, int.MaxValue).LastOrDefault()); } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsLinqSpeedOptimized))] + [Fact] public void IListImplementationIsValid() { Validate(GetRange(42, 10), [42, 43, 44, 45, 46, 47, 48, 49, 50, 51]); diff --git a/src/libraries/System.Linq/tests/SelectTests.cs b/src/libraries/System.Linq/tests/SelectTests.cs index c7c84565485c73..5fdae70d06df38 100644 --- a/src/libraries/System.Linq/tests/SelectTests.cs +++ b/src/libraries/System.Linq/tests/SelectTests.cs @@ -12,6 +12,32 @@ namespace System.Linq.Tests { public class SelectTests : EnumerableTests { + [Fact] + public void SelectSideEffectsExecutedOnCount() + { + int i = 0; + // If we made no promises about side effects, i could be 0, but in practice users have + // taken a dependency on side effects executing on Count. + var count = Enumerable.Range(1, 10).Select(x => i++).Count(); + Assert.Equal(10, count); + Assert.Equal(10, i); + + i = 0; + count = Enumerable.Range(1, 10).Skip(5).Select(x => i++).Count(); + Assert.Equal(5, count); + Assert.Equal(5, i); + + i = 0; + count = Enumerable.Range(1, 10).Take(5).Select(x => i++).Count(); + Assert.Equal(5, count); + Assert.Equal(5, i); + + i = 0; + count = Enumerable.Range(1, 10).Skip(2).Take(3).Select(x => i++).Count(); + Assert.Equal(3, count); + Assert.Equal(3, i); + } + [Fact] public void SameResultsRepeatCallsStringQuery() { diff --git a/src/libraries/System.Linq/tests/TakeTests.cs b/src/libraries/System.Linq/tests/TakeTests.cs index bcdaa42df0a9b5..310a942e507feb 100644 --- a/src/libraries/System.Linq/tests/TakeTests.cs +++ b/src/libraries/System.Linq/tests/TakeTests.cs @@ -669,7 +669,7 @@ public void RepeatEnumerating() } } - [ConditionalTheory(typeof(PlatformDetection), nameof(PlatformDetection.IsLinqSpeedOptimized))] + [Theory] [InlineData(1000)] [InlineData(1000000)] [InlineData(int.MaxValue)] @@ -1623,7 +1623,7 @@ public void EmptySource_DoNotThrowException_EnumerablePartition() Assert.Empty(EnumerablePartitionOrEmpty(source).Take(^6..^7)); } - [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsLinqSpeedOptimized))] + [Fact] public void SkipTakeOnIListIsIList() { IList list = new ReadOnlyCollection(Enumerable.Range(0, 100).ToList()); diff --git a/src/libraries/tests.proj b/src/libraries/tests.proj index 64c690245b51c5..0549cf599f6091 100644 --- a/src/libraries/tests.proj +++ b/src/libraries/tests.proj @@ -577,6 +577,7 @@ +