Skip to content
This repository was archived by the owner on Jan 23, 2023. It is now read-only.

Commit baed3e7

Browse files
committed
Merge pull request #2120 from stephentoub/select_toarray
Optimize array/list.Select(...).ToArray()
2 parents 29f95e7 + 872bb70 commit baed3e7

File tree

4 files changed

+142
-9
lines changed

4 files changed

+142
-9
lines changed

src/System.Linq/src/System/Linq/Enumerable.cs

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,11 @@ public IEnumerable<TResult> Select<TResult>(Func<TSource, TResult> selector)
144144

145145
public abstract IEnumerable<TSource> Where(Func<TSource, bool> predicate);
146146

147+
public virtual TSource[] ToArray()
148+
{
149+
return new Buffer<TSource>(this, queryInterfaces: false).ToArray();
150+
}
151+
147152
object IEnumerator.Current
148153
{
149154
get { return Current; }
@@ -431,6 +436,21 @@ public override IEnumerable<TResult> Where(Func<TResult, bool> predicate)
431436
{
432437
return new WhereEnumerableIterator<TResult>(this, predicate);
433438
}
439+
440+
public override TResult[] ToArray()
441+
{
442+
if (_predicate != null)
443+
{
444+
return base.ToArray();
445+
}
446+
447+
var results = new TResult[_source.Length];
448+
for (int i = 0; i < results.Length; i++)
449+
{
450+
results[i] = _selector(_source[i]);
451+
}
452+
return results;
453+
}
434454
}
435455

436456
internal class WhereSelectListIterator<TSource, TResult> : Iterator<TResult>
@@ -486,6 +506,21 @@ public override IEnumerable<TResult> Where(Func<TResult, bool> predicate)
486506
{
487507
return new WhereEnumerableIterator<TResult>(this, predicate);
488508
}
509+
510+
public override TResult[] ToArray()
511+
{
512+
if (_predicate != null)
513+
{
514+
return base.ToArray();
515+
}
516+
517+
var results = new TResult[_source.Count];
518+
for (int i = 0; i < results.Length; i++)
519+
{
520+
results[i] = _selector(_source[i]);
521+
}
522+
return results;
523+
}
489524
}
490525

491526
//public static IEnumerable<TSource> Where<TSource>(this IEnumerable<TSource> source, Func<TSource, bool> predicate) {
@@ -3079,21 +3114,35 @@ internal struct Buffer<TElement>
30793114
internal TElement[] items;
30803115
internal int count;
30813116

3082-
internal Buffer(IEnumerable<TElement> source)
3117+
internal Buffer(IEnumerable<TElement> source, bool queryInterfaces = true)
30833118
{
30843119
TElement[] items = null;
30853120
int count = 0;
3086-
ICollection<TElement> collection = source as ICollection<TElement>;
3087-
if (collection != null)
3121+
3122+
if (queryInterfaces)
30883123
{
3089-
count = collection.Count;
3090-
if (count > 0)
3124+
Enumerable.Iterator<TElement> iterator = source as Enumerable.Iterator<TElement>;
3125+
if (iterator != null)
30913126
{
3092-
items = new TElement[count];
3093-
collection.CopyTo(items, 0);
3127+
items = iterator.ToArray();
3128+
count = items.Length;
3129+
}
3130+
else
3131+
{
3132+
ICollection<TElement> collection = source as ICollection<TElement>;
3133+
if (collection != null)
3134+
{
3135+
count = collection.Count;
3136+
if (count > 0)
3137+
{
3138+
items = new TElement[count];
3139+
collection.CopyTo(items, 0);
3140+
}
3141+
}
30943142
}
30953143
}
3096-
else
3144+
3145+
if (items == null)
30973146
{
30983147
foreach (TElement item in source)
30993148
{
@@ -3109,6 +3158,7 @@ internal Buffer(IEnumerable<TElement> source)
31093158
count++;
31103159
}
31113160
}
3161+
31123162
this.items = items;
31133163
this.count = count;
31143164
}

src/System.Linq/tests/ReverseTests.cs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using System.Linq.Tests.Helpers;
2+
using Xunit;
3+
4+
namespace System.Linq.Tests
5+
{
6+
public class ReverseTests
7+
{
8+
[Fact]
9+
public void InvalidArguments()
10+
{
11+
Assert.Throws<ArgumentNullException>(() => Enumerable.Reverse<string>(null));
12+
}
13+
14+
[Theory]
15+
[InlineData(new int[] { })]
16+
[InlineData(new int[] { 1 })]
17+
[InlineData(new int[] { 1, 3, 5 })]
18+
[InlineData(new int[] { 2, 4, 6, 8 })]
19+
public void ReverseMatches(int[] input)
20+
{
21+
int[] expectedResults = new int[input.Length];
22+
for (int i = 0; i < input.Length; i++)
23+
{
24+
expectedResults[i] = input[input.Length - 1 - i];
25+
}
26+
27+
Assert.NotSame(input, Enumerable.Reverse(input));
28+
29+
Assert.Equal(expectedResults, input.Reverse());
30+
Assert.Equal(expectedResults, new TestCollection<int>(input).Reverse());
31+
Assert.Equal(expectedResults, new TestEnumerable<int>(input).Reverse());
32+
Assert.Equal(expectedResults, new TestReadOnlyCollection<int>(input).Reverse());
33+
34+
Assert.Equal(expectedResults.Select(i => i * 2), input.Select(i => i * 2).Reverse());
35+
Assert.Equal(expectedResults.Where(i => true).Select(i => i * 2), input.Where(i => true).Select(i => i * 2).Reverse());
36+
Assert.Equal(expectedResults.Where(i => false).Select(i => i * 2), input.Where(i => false).Select(i => i * 2).Reverse());
37+
}
38+
}
39+
}

src/System.Linq/tests/System.Linq.Tests.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
<Compile Include="RangeTests.cs" />
3636
<Compile Include="RepeatTests.cs" />
3737
<Compile Include="SumTests.cs" />
38+
<Compile Include="ReverseTests.cs" />
3839
<Compile Include="WhereTests.cs" />
3940
<Compile Include="ToArrayTests.cs" />
4041
<Compile Include="ToDictionaryTests.cs" />
@@ -49,4 +50,4 @@
4950
</ProjectReference>
5051
</ItemGroup>
5152
<Import Project="$([MSBuild]::GetDirectoryNameOfFileAbove($(MSBuildThisFileDirectory), dir.targets))\dir.targets" />
52-
</Project>
53+
</Project>

src/System.Linq/tests/ToArrayTests.cs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ public void ToArray_AlwaysCreateACopy()
5252

5353
Assert.NotSame(sourceArray, resultArray);
5454
Assert.Equal(sourceArray, resultArray);
55+
56+
int[] emptySourceArray = Array.Empty<int>();
57+
Assert.NotSame(emptySourceArray.ToArray(), emptySourceArray.ToArray());
58+
Assert.NotSame(emptySourceArray.Select(i => i).ToArray(), emptySourceArray.Select(i => i).ToArray());
5559
}
5660

5761

@@ -153,5 +157,44 @@ public void ToArray_FailOnExtremelyLargeCollection()
153157
var thrownException = Assert.ThrowsAny<Exception>(() => { largeSeq.ToArray(); });
154158
Assert.True(thrownException.GetType() == typeof(OverflowException) || thrownException.GetType() == typeof(OutOfMemoryException));
155159
}
160+
161+
[Theory]
162+
[InlineData(new int[] { }, new string[] { })]
163+
[InlineData(new int[] { 1 }, new string[] { "1" })]
164+
[InlineData(new int[] { 1, 2, 3 }, new string[] { "1", "2", "3" })]
165+
public void ToArray_ArrayWhereSelect(int[] sourceIntegers, string[] convertedStrings)
166+
{
167+
Assert.Equal(convertedStrings, sourceIntegers.Select(i => i.ToString()).ToArray());
168+
169+
Assert.Equal(sourceIntegers, sourceIntegers.Where(i => true).ToArray());
170+
Assert.Equal(Array.Empty<int>(), sourceIntegers.Where(i => false).ToArray());
171+
172+
Assert.Equal(convertedStrings, sourceIntegers.Where(i => true).Select(i => i.ToString()).ToArray());
173+
Assert.Equal(Array.Empty<string>(), sourceIntegers.Where(i => false).Select(i => i.ToString()).ToArray());
174+
175+
Assert.Equal(convertedStrings, sourceIntegers.Select(i => i.ToString()).Where(s => s != null).ToArray());
176+
Assert.Equal(Array.Empty<string>(), sourceIntegers.Select(i => i.ToString()).Where(s => s == null).ToArray());
177+
}
178+
179+
[Theory]
180+
[InlineData(new int[] { }, new string[] { })]
181+
[InlineData(new int[] { 1 }, new string[] { "1" })]
182+
[InlineData(new int[] { 1, 2, 3 }, new string[] { "1", "2", "3" })]
183+
public void ToArray_ListWhereSelect(int[] sourceIntegers, string[] convertedStrings)
184+
{
185+
var sourceList = new List<int>(sourceIntegers);
186+
187+
Assert.Equal(convertedStrings, sourceList.Select(i => i.ToString()).ToArray());
188+
189+
Assert.Equal(sourceList, sourceList.Where(i => true).ToArray());
190+
Assert.Equal(Array.Empty<int>(), sourceList.Where(i => false).ToArray());
191+
192+
Assert.Equal(convertedStrings, sourceList.Where(i => true).Select(i => i.ToString()).ToArray());
193+
Assert.Equal(Array.Empty<string>(), sourceList.Where(i => false).Select(i => i.ToString()).ToArray());
194+
195+
Assert.Equal(convertedStrings, sourceList.Select(i => i.ToString()).Where(s => s != null).ToArray());
196+
Assert.Equal(Array.Empty<string>(), sourceList.Select(i => i.ToString()).Where(s => s == null).ToArray());
197+
}
198+
156199
}
157200
}

0 commit comments

Comments
 (0)