|
22 | 22 | import logging |
23 | 23 | import unittest |
24 | 24 |
|
| 25 | +import mock |
| 26 | + |
| 27 | +from apache_beam.options import pipeline_options |
25 | 28 | from apache_beam.portability.api import beam_job_api_pb2 |
26 | 29 | from apache_beam.portability.api import beam_runner_api_pb2 |
27 | 30 | from apache_beam.runners.portability import local_job_service |
28 | 31 | from apache_beam.runners.portability.portable_runner import JobServiceHandle |
| 32 | +from apache_beam.runners.portability.portable_runner import PortableRunner |
29 | 33 |
|
30 | 34 |
|
31 | 35 | class TestJobServicePlan(JobServiceHandle): |
32 | 36 | def __init__(self, job_service): |
33 | 37 | self.job_service = job_service |
34 | 38 | self.options = None |
35 | 39 | self.timeout = None |
| 40 | + self.artifact_endpoint = None |
36 | 41 |
|
37 | 42 | def get_pipeline_options(self): |
38 | 43 | return None |
@@ -102,6 +107,32 @@ def test_error_messages_after_pipeline_failure(self): |
102 | 107 | for m in message_results), |
103 | 108 | messages_again) |
104 | 109 |
|
| 110 | + def test_artifact_service_override(self): |
| 111 | + job_service = local_job_service.LocalJobServicer() |
| 112 | + port = job_service.start_grpc_server() |
| 113 | + |
| 114 | + test_artifact_endpoint = 'testartifactendpoint:4242' |
| 115 | + |
| 116 | + options = pipeline_options.PipelineOptions([ |
| 117 | + '--job_endpoint', |
| 118 | + 'localhost:%d' % port, |
| 119 | + '--artifact_endpoint', |
| 120 | + test_artifact_endpoint, |
| 121 | + ]) |
| 122 | + runner = PortableRunner() |
| 123 | + job_service_handle = runner.create_job_service(options) |
| 124 | + |
| 125 | + with mock.patch.object(job_service_handle, 'stage') as mocked_stage: |
| 126 | + job_service_handle.submit(beam_runner_api_pb2.Pipeline()) |
| 127 | + mocked_stage.assert_called_once_with( |
| 128 | + mock.ANY, test_artifact_endpoint, mock.ANY) |
| 129 | + |
| 130 | + # Confirm the artifact_endpoint is in the options protobuf |
| 131 | + options_proto = job_service_handle.get_pipeline_options() |
| 132 | + self.assertEqual( |
| 133 | + options_proto['beam:option:artifact_endpoint:v1'], |
| 134 | + test_artifact_endpoint) |
| 135 | + |
105 | 136 |
|
106 | 137 | if __name__ == '__main__': |
107 | 138 | logging.getLogger().setLevel(logging.INFO) |
|
0 commit comments