Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ public ContinuablePagedIterable(Supplier<PageRetrieverSync<C, P>> pageRetrieverS

@Override
public Stream<T> stream() {
return StreamSupport.stream(iterableByItemInternal().spliterator(), false);
// Return a stream that supports proper parallel processing
return new ParallelCapablePagedStream<>(iterableByItemInternal());
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,274 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package com.azure.core.util.paging;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Spliterator;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import java.util.function.BinaryOperator;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.function.ToDoubleFunction;
import java.util.function.ToIntFunction;
import java.util.function.ToLongFunction;
import java.util.stream.Collector;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import java.util.stream.LongStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

/**
 * A {@link Stream} implementation that enables safe parallel processing over lazily-paged data.
 *
 * <p>By default the stream is sequential and consumes the backing {@link Iterable} lazily,
 * page by page. When {@link #parallel()} is requested before the pipeline is materialized, all
 * elements are first collected into a list and a standard parallel stream over that list is
 * used instead — a lazy paging spliterator splits poorly across workers, so collecting first
 * is what makes parallel execution effective.
 *
 * <p>Like all streams, instances are single-use and are not safe for concurrent pipeline
 * construction from multiple threads.
 *
 * @param <T> The type of elements in the stream
 */
final class ParallelCapablePagedStream<T> implements Stream<T> {
    // Lazily-consumed element source (iterating it may trigger page retrieval).
    private final Iterable<T> source;

    // Backing stream; created on first use according to 'isParallel' at that moment.
    private Stream<T> delegate;

    // Requested execution mode; decides how 'delegate' is built once it is needed.
    private boolean isParallel;

    /**
     * Creates a sequential stream over the given source.
     *
     * @param source The iterable supplying the stream's elements.
     */
    ParallelCapablePagedStream(Iterable<T> source) {
        this.source = source;
        this.delegate = null;
        this.isParallel = false;
    }

    /**
     * Returns the backing stream, creating it on first use.
     *
     * <p>Parallel mode eagerly drains {@code source} into an {@link ArrayList} so the resulting
     * parallel stream splits evenly; sequential mode streams the source lazily.
     *
     * @return The backing {@link Stream}.
     */
    private Stream<T> getDelegate() {
        if (delegate == null) {
            if (isParallel) {
                // Collect everything up front: an ArrayList splits evenly across worker
                // threads, whereas the paged iterable's spliterator does not.
                List<T> collected = new ArrayList<>();
                source.forEach(collected::add);
                delegate = collected.parallelStream();
            } else {
                // Sequential processing consumes the original iterable lazily.
                delegate = StreamSupport.stream(source.spliterator(), false);
            }
        }
        return delegate;
    }

    /**
     * {@inheritDoc}
     *
     * <p>The mode is switched in place and {@code this} is returned (the {@code Stream}
     * contract explicitly allows it). Switching in place — rather than wrapping the source in
     * a second stream — preserves single-use semantics and avoids consuming the lazily-paged
     * source twice. If the backing stream already exists, its own mode is toggled directly.
     */
    @Override
    public Stream<T> parallel() {
        if (!isParallel) {
            isParallel = true;
            if (delegate != null) {
                // Pipeline already materialized; switch the existing stream's mode directly.
                delegate = delegate.parallel();
            }
        }
        return this;
    }

    /**
     * {@inheritDoc}
     *
     * <p>See {@link #parallel()} for why the mode is switched in place instead of creating a
     * new wrapper over the same source.
     */
    @Override
    public Stream<T> sequential() {
        if (isParallel) {
            isParallel = false;
            if (delegate != null) {
                delegate = delegate.sequential();
            }
        }
        return this;
    }

    @Override
    public boolean isParallel() {
        return isParallel;
    }

    @Override
    public Stream<T> unordered() {
        // Ordering relaxation is handled entirely by the delegate.
        return getDelegate().unordered();
    }

    @Override
    public Stream<T> onClose(Runnable closeHandler) {
        // Register the handler on the delegate but return this wrapper, so later
        // parallel()/sequential() calls keep the paging-aware behavior. close() below
        // closes the delegate, which runs the registered handlers.
        getDelegate().onClose(closeHandler);
        return this;
    }

    @Override
    public void close() {
        // Only the delegate holds close handlers; nothing to do if it was never created.
        if (delegate != null) {
            delegate.close();
        }
    }

    // ------------------------------------------------------------------------------------
    // Intermediate and terminal operations delegate directly. The stream returned by an
    // intermediate operation is a standard JDK stream, so the execution mode chosen before
    // the first operation is the one that applies to the rest of the pipeline.
    // ------------------------------------------------------------------------------------

    @Override
    public Stream<T> filter(Predicate<? super T> predicate) {
        return getDelegate().filter(predicate);
    }

    @Override
    public <R> Stream<R> map(Function<? super T, ? extends R> mapper) {
        return getDelegate().map(mapper);
    }

    @Override
    public IntStream mapToInt(ToIntFunction<? super T> mapper) {
        return getDelegate().mapToInt(mapper);
    }

    @Override
    public LongStream mapToLong(ToLongFunction<? super T> mapper) {
        return getDelegate().mapToLong(mapper);
    }

    @Override
    public DoubleStream mapToDouble(ToDoubleFunction<? super T> mapper) {
        return getDelegate().mapToDouble(mapper);
    }

    @Override
    public <R> Stream<R> flatMap(Function<? super T, ? extends Stream<? extends R>> mapper) {
        return getDelegate().flatMap(mapper);
    }

    @Override
    public IntStream flatMapToInt(Function<? super T, ? extends IntStream> mapper) {
        return getDelegate().flatMapToInt(mapper);
    }

    @Override
    public LongStream flatMapToLong(Function<? super T, ? extends LongStream> mapper) {
        return getDelegate().flatMapToLong(mapper);
    }

    @Override
    public DoubleStream flatMapToDouble(Function<? super T, ? extends DoubleStream> mapper) {
        return getDelegate().flatMapToDouble(mapper);
    }

    @Override
    public Stream<T> distinct() {
        return getDelegate().distinct();
    }

    @Override
    public Stream<T> sorted() {
        return getDelegate().sorted();
    }

    @Override
    public Stream<T> sorted(Comparator<? super T> comparator) {
        return getDelegate().sorted(comparator);
    }

    @Override
    public Stream<T> peek(Consumer<? super T> action) {
        return getDelegate().peek(action);
    }

    @Override
    public Stream<T> limit(long maxSize) {
        return getDelegate().limit(maxSize);
    }

    @Override
    public Stream<T> skip(long n) {
        return getDelegate().skip(n);
    }

    @Override
    public void forEach(Consumer<? super T> action) {
        getDelegate().forEach(action);
    }

    @Override
    public void forEachOrdered(Consumer<? super T> action) {
        getDelegate().forEachOrdered(action);
    }

    @Override
    public Object[] toArray() {
        return getDelegate().toArray();
    }

    @Override
    public <A> A[] toArray(IntFunction<A[]> generator) {
        return getDelegate().toArray(generator);
    }

    @Override
    public T reduce(T identity, BinaryOperator<T> accumulator) {
        return getDelegate().reduce(identity, accumulator);
    }

    @Override
    public Optional<T> reduce(BinaryOperator<T> accumulator) {
        return getDelegate().reduce(accumulator);
    }

    @Override
    public <U> U reduce(U identity, BiFunction<U, ? super T, U> accumulator, BinaryOperator<U> combiner) {
        return getDelegate().reduce(identity, accumulator, combiner);
    }

    @Override
    public <R> R collect(Supplier<R> supplier, BiConsumer<R, ? super T> accumulator, BiConsumer<R, R> combiner) {
        return getDelegate().collect(supplier, accumulator, combiner);
    }

    @Override
    public <R, A> R collect(Collector<? super T, A, R> collector) {
        return getDelegate().collect(collector);
    }

    @Override
    public Optional<T> min(Comparator<? super T> comparator) {
        return getDelegate().min(comparator);
    }

    @Override
    public Optional<T> max(Comparator<? super T> comparator) {
        return getDelegate().max(comparator);
    }

    @Override
    public long count() {
        return getDelegate().count();
    }

    @Override
    public boolean anyMatch(Predicate<? super T> predicate) {
        return getDelegate().anyMatch(predicate);
    }

    @Override
    public boolean allMatch(Predicate<? super T> predicate) {
        return getDelegate().allMatch(predicate);
    }

    @Override
    public boolean noneMatch(Predicate<? super T> predicate) {
        return getDelegate().noneMatch(predicate);
    }

    @Override
    public Optional<T> findFirst() {
        return getDelegate().findFirst();
    }

    @Override
    public Optional<T> findAny() {
        return getDelegate().findAny();
    }

    @Override
    public Iterator<T> iterator() {
        return getDelegate().iterator();
    }

    @Override
    public Spliterator<T> spliterator() {
        return getDelegate().spliterator();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -546,10 +546,10 @@ public void streamFindFirstOnlyRetrievesOnePage() {
OnlyOnePagedIterable pagedIterable = new OnlyOnePagedIterable(new OnlyOnePagedFlux(() -> pageRetriever));

// Validation that there is more than one paged in the full return.
pagedIterable.stream().count();
pagedIterable.stream().sequential().count();
assertEquals(DEFAULT_PAGE_COUNT, pageRetriever.getGetCount());

Integer next = pagedIterable.stream().findFirst().orElse(0);
Integer next = pagedIterable.stream().sequential().findFirst().orElse(0);

sleep();

Expand All @@ -565,10 +565,10 @@ public void streamFindFirstOnlyRetrievesOnePageSync() {
OnlyOnePagedIterable pagedIterable = new OnlyOnePagedIterable(() -> pageRetrieverSync, null, null);

// Validation that there is more than one paged in the full return.
pagedIterable.stream().count();
pagedIterable.stream().sequential().count();
assertEquals(DEFAULT_PAGE_COUNT, pageRetrieverSync.getGetCount());

Integer next = pagedIterable.stream().findFirst().orElse(0);
Integer next = pagedIterable.stream().sequential().findFirst().orElse(0);

/*
* Given that each page contains more than one element we are able to only retrieve a single page.
Expand Down Expand Up @@ -793,6 +793,71 @@ public <C, T, P extends ContinuablePage<C, T>> void iteratingTerminatesOn(Contin
}
}

@Test
public void streamParallelUsesMultipleThreads() {
    // Five pages (three items each) give the parallel pipeline enough work to split.
    PagedIterable<Integer> pagedIterable = getIntegerPagedIterable(5);

    // Thread-safe set recording the name of every worker thread that touched an item.
    java.util.concurrent.ConcurrentSkipListSet<String> observedThreads
        = new java.util.concurrent.ConcurrentSkipListSet<>();

    // Run the pipeline in parallel while noting which thread handles each element.
    List<Integer> actual = pagedIterable.stream()
        .parallel()
        .peek(item -> observedThreads.add(Thread.currentThread().getName()))
        .collect(Collectors.toList());

    // All 15 elements (5 pages * 3 items per page) must survive the parallel pipeline.
    assertEquals(5 * 3, actual.size());

    // The exact thread count cannot be pinned down — a single-core host may legitimately
    // use one thread — so only require that at least one worker was observed.
    assertTrue(observedThreads.size() >= 1, "Should use at least one thread");

    // The essential guarantee of the fix: .parallel() must not drop, duplicate, or
    // corrupt elements. Sort before comparing since parallel execution may reorder.
    List<Integer> expectedResults = Stream.iterate(0, i -> i + 1).limit(15L).collect(Collectors.toList());
    Collections.sort(actual);
    assertEquals(expectedResults, actual);
}

@Test
public void streamSequentialStillWorks() {
    // Three pages (three items each); sequential processing must preserve element order.
    PagedIterable<Integer> pagedIterable = getIntegerPagedIterable(3);

    // Force the sequential path explicitly and materialize the pipeline.
    List<Integer> actual = pagedIterable.stream()
        .sequential()
        .collect(Collectors.toList());

    // Expect 0..8 in their original order.
    List<Integer> expectedResults = Stream.iterate(0, i -> i + 1).limit(9L).collect(Collectors.toList());
    assertEquals(expectedResults, actual);
}

@Test
public void streamDefaultBehaviorWorksWithParallelToggle() {
    // Two pages (six items) are enough to exercise toggling between the two modes.
    PagedIterable<Integer> pagedIterable = getIntegerPagedIterable(2);

    // parallel -> sequential: the final mode is sequential, so order is preserved.
    List<Integer> sequentialResult = pagedIterable.stream().parallel().sequential().collect(Collectors.toList());

    // sequential -> parallel: the final mode is parallel, so order may change.
    List<Integer> parallelResult = pagedIterable.stream().sequential().parallel().collect(Collectors.toList());

    List<Integer> expectedResults = Stream.iterate(0, i -> i + 1).limit(6L).collect(Collectors.toList());

    // The sequential pipeline must yield 0..5 in order.
    assertEquals(expectedResults, sequentialResult);

    // The parallel pipeline must yield the same elements, order notwithstanding.
    Collections.sort(parallelResult);
    assertEquals(expectedResults, parallelResult);
}

private static void sleep() {
try {
Thread.sleep(500);
Expand Down
Loading