Skip to content
This repository was archived by the owner on Jul 31, 2023. It is now read-only.

Commit 27696c4

Browse files
committed
Update inferred setup.py path to full path.
Change-Id: I50522de73ed523f248b3894022ce3bba7aac611e
1 parent b035336 commit 27696c4

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

tfrutil/beam_pipeline.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,16 @@
3434
from tfrutil import constants
3535

3636

37+
def _get_setup_py_filepath() -> str:
38+
"""Returns the file path to the setup.py file.
39+
40+
The location of the setup.py file is needed to run Dataflow jobs.
41+
"""
42+
43+
return os.path.join(
44+
os.path.dirname(os.path.abspath(__file__)), '..', 'setup.py')
45+
46+
3747
def _get_job_name(job_label: str = None) -> str:
3848
"""Returns Beam runner job name.
3949
@@ -84,7 +94,7 @@ def _get_pipeline_options(
8494
if region:
8595
options_dict['region'] = region
8696
if runner == 'DataFlowRunner':
87-
options_dict['setup_file'] = os.path.join('..', 'setup.py')
97+
options_dict['setup_file'] = _get_setup_py_filepath()
8898
if dataflow_options:
8999
options_dict.update(dataflow_options)
90100

tfrutil/beam_pipeline_test.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
"""Tests for beam_pipeline."""
1818

19+
import os
1920
import unittest
2021
from unittest import mock
2122

@@ -28,7 +29,7 @@
2829
# pylint: disable=protected-access
2930

3031
class BeamPipelineTests(unittest.TestCase):
31-
"""Tests for beam_image.py"""
32+
"""Tests for beam_pipeline.py"""
3233

3334
def test_processing_fn_with_int_label(self):
3435
'Test preprocessing fn with integer label.'
@@ -76,3 +77,9 @@ def test_partition_fn(self):
7677
self.assertEqual(
7778
index, i,
7879
'{} should be index {} but was index {}'.format(part, i, index))
80+
81+
def test_get_setup_py_filepath(self):
82+
"""Tests `_get_setup_py_filepath`."""
83+
filepath = beam_pipeline._get_setup_py_filepath()
84+
self.assertTrue(os.path.isfile(filepath))
85+
self.assertTrue(os.path.isabs(filepath))

0 commit comments

Comments
 (0)