Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.cluster.metadata.ProjectId;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
Expand Down Expand Up @@ -130,6 +131,7 @@ record Parsed(Pipeline pipeline, List<IngestDocument> documents, boolean verbose
static final String SIMULATED_PIPELINE_ID = "_simulate_pipeline";

static Parsed parseWithPipelineId(
ProjectId projectId,
String pipelineId,
Map<String, Object> config,
boolean verbose,
Expand All @@ -139,7 +141,7 @@ static Parsed parseWithPipelineId(
if (pipelineId == null) {
throw new IllegalArgumentException("param [pipeline] is null");
}
Pipeline pipeline = ingestService.getPipeline(pipelineId);
Pipeline pipeline = ingestService.getPipeline(projectId, pipelineId);
if (pipeline == null) {
throw new IllegalArgumentException("pipeline [" + pipelineId + "] does not exist");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.node.DiscoveryNodes;
import org.elasticsearch.cluster.project.ProjectResolver;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.util.concurrent.EsExecutors;
Expand Down Expand Up @@ -49,6 +50,7 @@ public class SimulatePipelineTransportAction extends HandledTransportAction<Simu
private final IngestService ingestService;
private final SimulateExecutionService executionService;
private final TransportService transportService;
private final ProjectResolver projectResolver;
private volatile TimeValue ingestNodeTransportActionTimeout;
// ThreadLocal because our unit testing framework does not like sharing Randoms across threads
private final ThreadLocal<Random> random = ThreadLocal.withInitial(Randomness::get);
Expand All @@ -58,7 +60,8 @@ public SimulatePipelineTransportAction(
ThreadPool threadPool,
TransportService transportService,
ActionFilters actionFilters,
IngestService ingestService
IngestService ingestService,
ProjectResolver projectResolver
) {
super(
SimulatePipelineAction.NAME,
Expand All @@ -70,6 +73,7 @@ public SimulatePipelineTransportAction(
this.ingestService = ingestService;
this.executionService = new SimulateExecutionService(threadPool);
this.transportService = transportService;
this.projectResolver = projectResolver;
this.ingestNodeTransportActionTimeout = INGEST_NODE_TRANSPORT_ACTION_TIMEOUT.get(ingestService.getClusterService().getSettings());
ingestService.getClusterService()
.getClusterSettings()
Expand All @@ -96,9 +100,11 @@ protected void doExecute(Task task, SimulatePipelineRequest request, ActionListe
}
try {
if (discoveryNodes.getLocalNode().isIngestNode()) {
final var projectId = projectResolver.getProjectId();
final SimulatePipelineRequest.Parsed simulateRequest;
if (request.getId() != null) {
simulateRequest = SimulatePipelineRequest.parseWithPipelineId(
projectId,
request.getId(),
source,
request.isVerbose(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -61,7 +63,7 @@ public void init() throws IOException {
(factories, tag, description, config) -> processor
);
ingestService = mock(IngestService.class);
when(ingestService.getPipeline(SIMULATED_PIPELINE_ID)).thenReturn(pipeline);
when(ingestService.getPipeline(any(), eq(SIMULATED_PIPELINE_ID))).thenReturn(pipeline);
when(ingestService.getProcessorFactories()).thenReturn(registry);
}

Expand Down Expand Up @@ -89,7 +91,9 @@ public void testParseUsingPipelineStore() throws Exception {
expectedDocs.add(expectedDoc);
}

var projectId = randomProjectIdOrDefault();
SimulatePipelineRequest.Parsed actualRequest = SimulatePipelineRequest.parseWithPipelineId(
projectId,
SIMULATED_PIPELINE_ID,
requestContent,
false,
Expand Down Expand Up @@ -213,24 +217,40 @@ public void testParseWithProvidedPipeline() throws Exception {
}

public void testNullPipelineId() {
var projectId = randomProjectIdOrDefault();
Map<String, Object> requestContent = new HashMap<>();
List<Map<String, Object>> docs = new ArrayList<>();
requestContent.put(Fields.DOCS, docs);
Exception e = expectThrows(
IllegalArgumentException.class,
() -> SimulatePipelineRequest.parseWithPipelineId(null, requestContent, false, ingestService, RestApiVersion.current())
() -> SimulatePipelineRequest.parseWithPipelineId(
projectId,
null,
requestContent,
false,
ingestService,
RestApiVersion.current()
)
);
assertThat(e.getMessage(), equalTo("param [pipeline] is null"));
}

public void testNonExistentPipelineId() {
var projectId = randomProjectIdOrDefault();
String pipelineId = randomAlphaOfLengthBetween(1, 10);
Map<String, Object> requestContent = new HashMap<>();
List<Map<String, Object>> docs = new ArrayList<>();
requestContent.put(Fields.DOCS, docs);
Exception e = expectThrows(
IllegalArgumentException.class,
() -> SimulatePipelineRequest.parseWithPipelineId(pipelineId, requestContent, false, ingestService, RestApiVersion.current())
() -> SimulatePipelineRequest.parseWithPipelineId(
projectId,
pipelineId,
requestContent,
false,
ingestService,
RestApiVersion.current()
)
);
assertThat(e.getMessage(), equalTo("pipeline [" + pipelineId + "] does not exist"));
}
Expand Down