Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.search.asynchronous.context.state.AsynchronousSearchState;
import org.opensearch.search.asynchronous.id.AsynchronousSearchId;
import org.opensearch.search.asynchronous.listener.AsynchronousSearchProgressListener;
import org.opensearch.search.asynchronous.response.AsynchronousSearchProgress;
import org.opensearch.search.asynchronous.response.AsynchronousSearchResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -65,18 +66,27 @@ public AsynchronousSearchContextId getContextId() {

public abstract @Nullable User getUser();

public @Nullable AsynchronousSearchProgress getProgress() {
return null;
}

public boolean isExpired() {
return getExpirationTimeMillis() < currentTimeSupplier.getAsLong();
}

public AsynchronousSearchResponse getAsynchronousSearchResponse() {
AsynchronousSearchProgress progress = getProgress();
if (progress == null && asynchronousSearchProgressListener != null) {
progress = asynchronousSearchProgressListener.progress();
}
return new AsynchronousSearchResponse(
getAsynchronousSearchId(),
getAsynchronousSearchState(),
getStartTimeMillis(),
getExpirationTimeMillis(),
getSearchResponse(),
getSearchError()
getSearchError(),
progress
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.search.asynchronous.id.AsynchronousSearchId;
import org.opensearch.search.asynchronous.id.AsynchronousSearchIdConverter;
import org.opensearch.search.asynchronous.listener.AsynchronousSearchProgressListener;
import org.opensearch.search.asynchronous.response.AsynchronousSearchProgress;
import org.opensearch.common.SetOnce;
import org.opensearch.core.action.ActionListener;
import org.opensearch.action.search.SearchProgressActionListener;
Expand All @@ -30,6 +31,7 @@
import java.io.Closeable;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.LongSupplier;
import java.util.function.Supplier;

Expand All @@ -56,6 +58,7 @@ public class AsynchronousSearchActiveContext extends AsynchronousSearchContext i
private final Supplier<Boolean> persistSearchFailureSupplier;
private final AsynchronousSearchContextPermits asynchronousSearchContextPermits;
private final Supplier<SearchResponse> partialResponseSupplier;
private final AtomicReference<AsynchronousSearchProgress> progress;
@Nullable
private final User user;

Expand Down Expand Up @@ -87,6 +90,7 @@ public AsynchronousSearchActiveContext(
: new NoopAsynchronousSearchContextPermits(asynchronousSearchContextId);
this.user = user;
this.persistSearchFailureSupplier = persistSearchFailureSupplier;
this.progress = new AtomicReference<>();
}

public void setTask(SearchTask searchTask) {
Expand All @@ -110,6 +114,7 @@ public void processSearchFailure(Exception e) {
e.getCause().setStackTrace(new StackTraceElement[0]);
}
this.error.set(e);
progress.compareAndSet(null, asynchronousSearchProgressListener.progress());
} finally {
boolean result = completed.compareAndSet(false, true);
assert result : "Process search failure already complete";
Expand All @@ -127,6 +132,7 @@ public void processSearchResponse(SearchResponse response) {
}
}
this.searchResponse.set(response);
progress.compareAndSet(null, asynchronousSearchProgressListener.progress());
} finally {
boolean result = completed.compareAndSet(false, true);
assert result : "Process search response already complete";
Expand All @@ -138,6 +144,11 @@ public SearchResponse getSearchResponse() {
return completed.get() ? searchResponse.get() : partialResponseSupplier.get();
}

@Override
public AsynchronousSearchProgress getProgress() {
return progress.get();
}

@Override
public String getAsynchronousSearchId() {
return asynchronousSearchId.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.search.asynchronous.response.AsynchronousSearchProgress;

import java.io.IOException;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -141,6 +142,35 @@ public User getUser() {
return asynchronousSearchPersistenceModel.getUser();
}

@Override
public AsynchronousSearchProgress getProgress() {
if (asynchronousSearchPersistenceModel.getProgress() == null) {
return null;
}
BytesReference bytesReference = BytesReference.fromByteBuffer(
ByteBuffer.wrap(Base64.getUrlDecoder().decode(asynchronousSearchPersistenceModel.getProgress()))
);
try (
NamedWriteableAwareStreamInput wrapperStreamInput = new NamedWriteableAwareStreamInput(
bytesReference.streamInput(),
namedWriteableRegistry
)
) {
wrapperStreamInput.setVersion(wrapperStreamInput.readVersion());
return new AsynchronousSearchProgress(wrapperStreamInput);
} catch (IOException e) {
logger.error(
() -> new ParameterizedMessage(
"Failed to parse search progress for asynchronous search [{}] Progress : [{}] ",
asynchronousSearchId,
asynchronousSearchPersistenceModel.getProgress()
),
e
);
return null;
}
}

@Override
public AsynchronousSearchState getAsynchronousSearchState() {
return AsynchronousSearchState.STORE_RESIDENT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.search.asynchronous.response.AsynchronousSearchProgress;

import java.io.IOException;
import java.util.Base64;
Expand All @@ -27,13 +28,22 @@ public class AsynchronousSearchPersistenceModel {
private final long startTimeMillis;
private final String response;
private final String error;
private final String progress;
private final User user;

public AsynchronousSearchPersistenceModel(long startTimeMillis, long expirationTimeMillis, String response, String error, User user) {
public AsynchronousSearchPersistenceModel(
long startTimeMillis,
long expirationTimeMillis,
String response,
String error,
String progress,
User user
) {
this.startTimeMillis = startTimeMillis;
this.expirationTimeMillis = expirationTimeMillis;
this.response = response;
this.error = error;
this.progress = progress;
this.user = user;
}

Expand All @@ -42,12 +52,14 @@ public AsynchronousSearchPersistenceModel(
long expirationTimeMillis,
SearchResponse response,
Exception error,
AsynchronousSearchProgress progress,
User user
) throws IOException {
this.startTimeMillis = startTimeMillis;
this.expirationTimeMillis = expirationTimeMillis;
this.response = serializeResponse(response);
this.error = serializeError(error);
this.progress = serializeProgress(progress);
this.user = user;
}

Expand All @@ -63,6 +75,18 @@ private String serializeResponse(SearchResponse response) throws IOException {
}
}

private String serializeProgress(AsynchronousSearchProgress progress) throws IOException {
if (progress == null) {
return null;
}
try (BytesStreamOutput out = new BytesStreamOutput()) {
out.writeVersion(Version.CURRENT);
progress.writeTo(out);
byte[] bytes = BytesReference.toBytes(out.bytes());
return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes);
}
}

/**
* Serializes exception in string format
*
Expand Down Expand Up @@ -100,6 +124,10 @@ public String getError() {
return error;
}

public String getProgress() {
return progress;
}

public long getExpirationTimeMillis() {
return expirationTimeMillis;
}
Expand All @@ -122,6 +150,8 @@ public boolean equals(Object o) {
&& ((response == null && other.response == null)
|| (response != null && other.response != null && response.equals(other.response)))
&& ((error == null && other.error == null) || (error != null && other.error != null && error.equals(other.error)))
&& ((progress == null && other.progress == null)
|| (progress != null && other.progress != null && progress.equals(other.progress)))
&& ((user == null && other.user == null) || (user != null && other.user != null && user.equals(other.user)));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public AsynchronousSearchPersistenceModel getAsynchronousSearchPersistenceModel(
asynchronousSearchContext.getExpirationTimeMillis(),
asynchronousSearchContext.getSearchResponse(),
asynchronousSearchContext.getSearchError(),
asynchronousSearchContext.getProgress(),
asynchronousSearchContext.getUser()
);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@
*/
package org.opensearch.search.asynchronous.listener;

import org.opensearch.search.asynchronous.response.AsynchronousSearchResponse;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.SetOnce;
import org.opensearch.action.search.SearchProgressActionListener;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchShard;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.search.asynchronous.response.AsynchronousSearchProgress;
import org.opensearch.search.asynchronous.response.AsynchronousSearchResponse;
import org.opensearch.search.SearchHits;
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.aggregations.InternalAggregation;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -76,6 +79,8 @@ protected void onListShards(
boolean fetchPhase
) {
partialResultsHolder.hasFetchPhase.set(fetchPhase);
partialResultsHolder.queryShards.set(Collections.unmodifiableList(shards));
partialResultsHolder.initShardProgress(shards.size());
partialResultsHolder.totalShards.set(shards.size() + skippedShards.size());
partialResultsHolder.skippedShards.set(skippedShards.size());
partialResultsHolder.successfulShards.set(skippedShards.size());
Expand Down Expand Up @@ -131,13 +136,25 @@ protected void onQueryResult(int shardIndex) {
onShardResult(shardIndex);
}

// Not annotated with @Override to remain compatible with older SearchProgressActionListener versions.
protected void onQueryResult(int shardIndex, long maxDocIdProcessed, long maxDoc) {
assert shardIndex < partialResultsHolder.totalShards.get();
onShardProgress(shardIndex, maxDocIdProcessed, maxDoc);
onShardResult(shardIndex);
}

private synchronized void onShardResult(int shardIndex) {
if (partialResultsHolder.successfulShardIds.contains(shardIndex) == false) {
partialResultsHolder.successfulShardIds.add(shardIndex);
partialResultsHolder.successfulShards.incrementAndGet();
}
}

private synchronized void onShardProgress(int shardIndex, long maxDocIdProcessed, long maxDoc) {
partialResultsHolder.maxDocIdProcessedByShard[shardIndex] = maxDocIdProcessed;
partialResultsHolder.maxDocByShard[shardIndex] = maxDoc;
}

private synchronized void onSearchFailure(int shardIndex, SearchShardTarget shardTarget, Exception e) {
// It's hard to build partial search failures since the elasticsearch doesn't consider shard not available exceptions as failures
// while internally it has exceptions from all shards of a particular shard group, it exposes only the exception on the
Expand All @@ -154,6 +171,10 @@ public CompositeSearchProgressActionListener<AsynchronousSearchResponse> searchP
return searchProgressActionListener;
}

public AsynchronousSearchProgress progress() {
return partialResultsHolder == null ? null : partialResultsHolder.progress();
}

@Override
public void onResponse(SearchResponse searchResponse) {
executor.execute(() -> {
Expand Down Expand Up @@ -198,12 +219,15 @@ static class PartialResultsHolder {
final SetOnce<Integer> totalShards;
final SetOnce<Integer> skippedShards;
final SetOnce<SearchResponse.Clusters> clusters;
final SetOnce<List<SearchShard>> queryShards;
final Set<Integer> successfulShardIds;
final SetOnce<Boolean> hasFetchPhase;
final AtomicInteger successfulShards;
final AtomicReference<TotalHits> totalHits;
final AtomicReference<InternalAggregations> internalAggregations;
final AtomicReference<InternalAggregations> partialInternalAggregations;
volatile long[] maxDocIdProcessedByShard;
volatile long[] maxDocByShard;
final long relativeStartMillis;
final LongSupplier relativeTimeSupplier;
final Supplier<InternalAggregation.ReduceContextBuilder> reduceContextBuilder;
Expand All @@ -222,13 +246,41 @@ static class PartialResultsHolder {
this.hasFetchPhase = new SetOnce<>();
this.totalHits = new AtomicReference<>();
this.clusters = new SetOnce<>();
this.queryShards = new SetOnce<>();
this.partialInternalAggregations = new AtomicReference<>();
this.relativeStartMillis = relativeStartMillis;
this.successfulShardIds = new HashSet<>(1);
this.relativeTimeSupplier = relativeTimeSupplier;
this.reduceContextBuilder = reduceContextBuilder;
}

void initShardProgress(int shardCount) {
maxDocIdProcessedByShard = new long[shardCount];
maxDocByShard = new long[shardCount];
Arrays.fill(maxDocIdProcessedByShard, -1L);
Arrays.fill(maxDocByShard, -1L);
}

AsynchronousSearchProgress progress() {
if (isInitialized == false || queryShards.get() == null) {
return null;
}
List<AsynchronousSearchProgress.ShardProgress> shardProgress = new ArrayList<>();
List<SearchShard> shards = queryShards.get();
for (int i = 0; i < shards.size(); i++) {
if (maxDocByShard[i] >= 0) {
shardProgress.add(
AsynchronousSearchProgress.ShardProgress.fromSearchShard(
shards.get(i),
maxDocIdProcessedByShard[i],
maxDocByShard[i]
)
);
}
}
return new AsynchronousSearchProgress(shardProgress);
}

public SearchResponse partialResponse() {
if (isInitialized) {
SearchHits searchHits = new SearchHits(SearchHits.EMPTY, totalHits.get(), Float.NaN);
Expand Down
Loading