Skip to content

Commit 4a88304

Browse files
authored
Add the pub/sub source (#15)
* added the pub/sub source * updated the readme * updated the shards --------- Co-authored-by: xqhu <xqhu@google.com>
1 parent bdc9cb2 commit 4a88304

File tree

6 files changed

+70
-23
lines changed

6 files changed

+70
-23
lines changed

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ ifeq ($(MODEL_ENV), "TORCH")
119119
--dataflow_service_option $(SERVICE_OPTIONS) \
120120
--number_of_worker_harness_threads 1 \
121121
--experiments=disable_worker_container_image_prepull \
122+
--experiments=use_pubsub_streaming \
122123
--sdk_container_image $(CUSTOM_CONTAINER_IMAGE) \
123124
--sdk_location container \
124125
--input $(INPUT_DATA) \
@@ -140,6 +141,7 @@ else
140141
--dataflow_service_option $(SERVICE_OPTIONS) \
141142
--number_of_worker_harness_threads 1 \
142143
--experiments=disable_worker_container_image_prepull \
144+
--experiments=use_pubsub_streaming \
143145
--sdk_container_image $(CUSTOM_CONTAINER_IMAGE) \
144146
--sdk_location container \
145147
--input $(INPUT_DATA) \

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,12 @@ When using resnet101 to score 50k images, the job took ~1h and cost ~0.5$ with r
268268
For `mobilenet_v2`, it cost 0.05$ with ~1h.
269269
Note the cost and time depends on your job settings and the regions.
270270

271+
### Run the Beam pipeline with the Pub/Sub source
272+
When `INPUT_DATA` from the `.env` file defines a valid Pub/Sub topic (e.g., `projects/apache-beam-testing/topics/Imagenet_openimage_50k_benchmark`),
273+
the Beam pipeline is created using the Pub/Sub source with `FixedWindows` and switches to `beam.io.fileio.WriteToFiles` that supports the streaming pipeline.
274+
We use `shards=0` here because it is the recommended setting: Dataflow then decides how many output files to write.
275+
Note that the streaming job will run forever until it is canceled or drained.
276+
271277
## FAQ
272278

273279
### Permission error when using any GCP command

src/config.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
# limitations under the License.
1414

1515
# standard libraries
16+
import re
1617
from enum import Enum
1718

1819
# third party libraries
19-
from pydantic import BaseModel, Field, root_validator
20+
from pydantic import BaseModel, Field, root_validator, validator
2021

2122

2223
class ModelName(str, Enum):
@@ -51,13 +52,23 @@ def validate_fields(cls, values):
5152
return values
5253

5354

55+
def _validate_topic_path(topic_path):
56+
pattern = r"projects/.+/topics/.+"
57+
return bool(re.match(pattern, topic_path))
58+
59+
5460
class SourceConfig(BaseModel):
    """Pipeline source configuration.

    ``input`` is either a path to a text file of image names or a Pub/Sub
    topic path; ``streaming`` is derived automatically from ``input`` and
    should not be set by callers.
    """

    input: str = Field(..., description="the input path to a text file or a Pub/Sub topic")
    images_dir: str = Field(
        None,
        description="Path to the directory where images are stored."
        "Not required if image names in the input file have absolute path.",
    )
    # True when `input` is a Pub/Sub topic path (projects/<p>/topics/<t>)
    streaming: bool = False

    @validator("streaming", pre=True, always=True)
    def set_streaming(cls, v, values):
        # Use .get(): when `input` is missing or failed its own validation,
        # pydantic omits it from `values`, and values["input"] would raise a
        # KeyError here instead of reporting the real validation error.
        return _validate_topic_path(values.get("input", ""))
6172

6273

6374
class SinkConfig(BaseModel):

src/pipeline.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# standard libraries
1818
import io
1919
import os
20-
from typing import Iterable, Iterator, Optional, Tuple
20+
from typing import Iterable, Iterator, Optional, Tuple, Union
2121

2222
# third party libraries
2323
import apache_beam as beam
@@ -45,7 +45,9 @@ def get_model_class(model_name: ModelName) -> nn.Module:
4545
return model_class
4646

4747

48-
def read_image(image_file_name: str, path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
48+
def read_image(image_file_name: Union[str, bytes], path_to_dir: Optional[str] = None) -> Tuple[str, Image.Image]:
49+
if isinstance(image_file_name, bytes):
50+
image_file_name = image_file_name.decode()
4951
if path_to_dir is not None:
5052
image_file_name = os.path.join(path_to_dir, image_file_name)
5153
with FileSystems().open(image_file_name, "r") as file:
@@ -126,14 +128,24 @@ def build_pipeline(pipeline, source_config: SourceConfig, sink_config: SinkConfi
126128
else:
127129
raise ValueError("Only support PytorchModelHandler and TFModelHandlerTensor!")
128130

129-
# read the text file and create the pair of input data with the file name and its image
130-
filename_value_pair = (
131-
pipeline
132-
| "ReadImageNames" >> beam.io.ReadFromText(source_config.input)
133-
| "FilterEmptyLines" >> beam.ParDo(filter_empty_lines)
134-
| "ReadImageData"
135-
>> beam.Map(lambda image_name: read_image(image_file_name=image_name, path_to_dir=source_config.images_dir))
136-
)
131+
if source_config.streaming:
132+
# read the text file path from Pub/Sub and use FixedWindows to group these images
133+
# and then run the model inference and store the results into GCS
134+
filename_value_pair = (
135+
pipeline
136+
| "ReadImageNamesFromPubSub" >> beam.io.ReadFromPubSub(topic=source_config.input)
137+
| "Window into fixed intervals" >> beam.WindowInto(beam.window.FixedWindows(60 * 5))
138+
| "ReadImageData" >> beam.Map(lambda image_name: read_image(image_file_name=image_name))
139+
)
140+
else:
141+
# read the text file and create the pair of input data with the file name and its image
142+
filename_value_pair = (
143+
pipeline
144+
| "ReadImageNames" >> beam.io.ReadFromText(source_config.input)
145+
| "FilterEmptyLines" >> beam.ParDo(filter_empty_lines)
146+
| "ReadImageData"
147+
>> beam.Map(lambda image_name: read_image(image_file_name=image_name, path_to_dir=source_config.images_dir))
148+
)
137149

138150
if model_config.model_state_dict_path:
139151
filename_value_pair = filename_value_pair | "PreprocessImages" >> beam.MapTuple(
@@ -151,7 +163,15 @@ def build_pipeline(pipeline, source_config: SourceConfig, sink_config: SinkConfi
151163
| "ProcessOutput" >> beam.ParDo(PostProcessor())
152164
)
153165

154-
# save the predictions to a text file
155-
predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned
156-
sink_config.output, shard_name_template="", append_trailing_newlines=True
157-
)
166+
# combine all the window results into one text for GCS
167+
if source_config.streaming:
168+
(
169+
predictions
170+
| "WriteOutputToGCS"
171+
>> beam.io.fileio.WriteToFiles(sink_config.output, shards=0) # pylint: disable=expression-not-assigned
172+
)
173+
else:
174+
# save the predictions to a text file
175+
predictions | "WriteOutputToGCS" >> beam.io.WriteToText( # pylint: disable=expression-not-assigned
176+
sink_config.output, shard_name_template="", append_trailing_newlines=True
177+
)

src/run.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,13 +64,6 @@ def run(argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult
6464
"""
6565
known_args, pipeline_args = parse_known_args(argv)
6666

67-
pipeline_options = PipelineOptions(pipeline_args)
68-
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
69-
70-
pipeline = test_pipeline
71-
if not test_pipeline:
72-
pipeline = beam.Pipeline(options=pipeline_options)
73-
7467
# setup configs
7568
model_config = ModelConfig(
7669
model_state_dict_path=known_args.model_state_dict_path,
@@ -83,6 +76,14 @@ def run(argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult
8376
source_config = SourceConfig(input=known_args.input)
8477
sink_config = SinkConfig(output=known_args.output)
8578

79+
# setup pipeline
80+
pipeline_options = PipelineOptions(pipeline_args, streaming=source_config.streaming)
81+
pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
82+
83+
pipeline = test_pipeline
84+
if not test_pipeline:
85+
pipeline = beam.Pipeline(options=pipeline_options)
86+
8687
# build the pipeline using configs
8788
build_pipeline(pipeline, source_config=source_config, sink_config=sink_config, model_config=model_config)
8889

tests/test_pipeline.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ def test_build_pipeline_with_tf():
4848

4949
p = beam.Pipeline()
5050
build_pipeline(p, source_config=source_config, sink_config=sink_config, model_config=model_config)
51+
52+
53+
def test_source_config_streaming():
    """streaming must be derived from the shape of the `input` value."""
    # a plain text-file path must not enable streaming
    file_config = SourceConfig(input=str(DATA_FILE_PATH / "openimage_10.txt"))
    assert file_config.streaming is False
    # a Pub/Sub topic path must enable streaming
    topic_config = SourceConfig(input="projects/apache-beam-testing/topics/Imagenet_openimage_50k_benchmark")
    assert topic_config.streaming is True

0 commit comments

Comments
 (0)