Skip to content

Commit c505411

Browse files
authored
Base AWS classes - S3 (apache#47321)
1 parent 91a9347 commit c505411

File tree

7 files changed: 248 additions and 245 deletions.

providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py

Lines changed: 147 additions & 157 deletions
Large diffs are not rendered by default.

providers/amazon/src/airflow/providers/amazon/aws/sensors/s3.py

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import re
2424
from collections.abc import Sequence
2525
from datetime import datetime, timedelta
26-
from functools import cached_property
2726
from typing import TYPE_CHECKING, Any, Callable, cast
2827

2928
from airflow.configuration import conf
@@ -34,11 +33,13 @@
3433

3534
from airflow.exceptions import AirflowException
3635
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
36+
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
3737
from airflow.providers.amazon.aws.triggers.s3 import S3KeysUnchangedTrigger, S3KeyTrigger
38-
from airflow.sensors.base import BaseSensorOperator, poke_mode_only
38+
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
39+
from airflow.sensors.base import poke_mode_only
3940

4041

41-
class S3KeySensor(BaseSensorOperator):
42+
class S3KeySensor(AwsBaseSensor[S3Hook]):
4243
"""
4344
Waits for one or multiple keys (a file-like instance on S3) to be present in an S3 bucket.
4445
@@ -65,27 +66,25 @@ class S3KeySensor(BaseSensorOperator):
6566
6667
def check_fn(files: List, **kwargs) -> bool:
6768
return any(f.get('Size', 0) > 1048576 for f in files)
68-
:param aws_conn_id: a reference to the s3 connection
69-
:param verify: Whether to verify SSL certificates for S3 connection.
70-
By default, SSL certificates are verified.
71-
You can provide the following values:
72-
73-
- ``False``: do not validate SSL certificates. SSL will still be used
74-
(unless use_ssl is False), but SSL certificates will not be
75-
verified.
76-
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
77-
You can specify this argument if you want to use a different
78-
CA cert bundle than the one used by botocore.
7969
:param deferrable: Run sensor in the deferrable mode
8070
:param use_regex: whether to use regex to check bucket
8171
:param metadata_keys: List of head_object attributes to gather and send to ``check_fn``.
8272
Acceptable values: Any top level attribute returned by s3.head_object. Specify * to return
8373
all available attributes.
8474
Default value: "Size".
8575
If the requested attribute is not found, the key is still included and the value is None.
76+
:param aws_conn_id: The Airflow connection used for AWS credentials.
77+
If this is ``None`` or empty then the default boto3 behaviour is used. If
78+
running Airflow in a distributed manner and aws_conn_id is None or
79+
empty, then default boto3 configuration would be used (and must be
80+
maintained on each worker node).
81+
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
82+
:param verify: Whether or not to verify SSL certificates. See:
83+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
8684
"""
8785

88-
template_fields: Sequence[str] = ("bucket_key", "bucket_name")
86+
template_fields: Sequence[str] = aws_template_fields("bucket_key", "bucket_name")
87+
aws_hook_class = S3Hook
8988

9089
def __init__(
9190
self,
@@ -94,7 +93,6 @@ def __init__(
9493
bucket_name: str | None = None,
9594
wildcard_match: bool = False,
9695
check_fn: Callable[..., bool] | None = None,
97-
aws_conn_id: str | None = "aws_default",
9896
verify: str | bool | None = None,
9997
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
10098
use_regex: bool = False,
@@ -106,14 +104,13 @@ def __init__(
106104
self.bucket_key = bucket_key
107105
self.wildcard_match = wildcard_match
108106
self.check_fn = check_fn
109-
self.aws_conn_id = aws_conn_id
110107
self.verify = verify
111108
self.deferrable = deferrable
112109
self.use_regex = use_regex
113110
self.metadata_keys = metadata_keys if metadata_keys else ["Size"]
114111

115112
def _check_key(self, key, context: Context):
116-
bucket_name, key = S3Hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
113+
bucket_name, key = self.hook.get_s3_bucket_key(self.bucket_name, key, "bucket_name", "bucket_key")
117114
self.log.info("Poking for key : s3://%s/%s", bucket_name, key)
118115

119116
"""
@@ -199,7 +196,9 @@ def _defer(self) -> None:
199196
bucket_key=self.bucket_key,
200197
wildcard_match=self.wildcard_match,
201198
aws_conn_id=self.aws_conn_id,
199+
region_name=self.region_name,
202200
verify=self.verify,
201+
botocore_config=self.botocore_config,
203202
poke_interval=self.poke_interval,
204203
should_check_fn=bool(self.check_fn),
205204
use_regex=self.use_regex,
@@ -220,13 +219,9 @@ def execute_complete(self, context: Context, event: dict[str, Any]) -> None:
220219
elif event["status"] == "error":
221220
raise AirflowException(event["message"])
222221

223-
@cached_property
224-
def hook(self) -> S3Hook:
225-
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
226-
227222

228223
@poke_mode_only
229-
class S3KeysUnchangedSensor(BaseSensorOperator):
224+
class S3KeysUnchangedSensor(AwsBaseSensor[S3Hook]):
230225
"""
231226
Return True if inactivity_period has passed with no increase in the number of objects matching prefix.
232227
@@ -239,17 +234,7 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
239234
240235
:param bucket_name: Name of the S3 bucket
241236
:param prefix: The prefix being waited on. Relative path from bucket root level.
242-
:param aws_conn_id: a reference to the s3 connection
243-
:param verify: Whether or not to verify SSL certificates for S3 connection.
244-
By default SSL certificates are verified.
245-
You can provide the following values:
246-
247-
- ``False``: do not validate SSL certificates. SSL will still be used
248-
(unless use_ssl is False), but SSL certificates will not be
249-
verified.
250-
- ``path/to/cert/bundle.pem``: A filename of the CA cert bundle to uses.
251-
You can specify this argument if you want to use a different
252-
CA cert bundle than the one used by botocore.
237+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
253238
:param inactivity_period: The total seconds of inactivity to designate
254239
keys unchanged. Note, this mechanism is not real time and
255240
this operator may not return until a poke_interval after this period
@@ -261,16 +246,24 @@ class S3KeysUnchangedSensor(BaseSensorOperator):
261246
between pokes valid behavior. If true a warning message will be logged
262247
when this happens. If false an error will be raised.
263248
:param deferrable: Run sensor in the deferrable mode
249+
:param aws_conn_id: The Airflow connection used for AWS credentials.
250+
If this is ``None`` or empty then the default boto3 behaviour is used. If
251+
running Airflow in a distributed manner and aws_conn_id is None or
252+
empty, then default boto3 configuration would be used (and must be
253+
maintained on each worker node).
254+
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
255+
:param verify: Whether or not to verify SSL certificates. See:
256+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
264257
"""
265258

266-
template_fields: Sequence[str] = ("bucket_name", "prefix")
259+
template_fields: Sequence[str] = aws_template_fields("bucket_name", "prefix")
260+
aws_hook_class = S3Hook
267261

268262
def __init__(
269263
self,
270264
*,
271265
bucket_name: str,
272266
prefix: str,
273-
aws_conn_id: str | None = "aws_default",
274267
verify: bool | str | None = None,
275268
inactivity_period: float = 60 * 60,
276269
min_objects: int = 1,
@@ -291,15 +284,9 @@ def __init__(
291284
self.inactivity_seconds = 0
292285
self.allow_delete = allow_delete
293286
self.deferrable = deferrable
294-
self.aws_conn_id = aws_conn_id
295287
self.verify = verify
296288
self.last_activity_time: datetime | None = None
297289

298-
@cached_property
299-
def hook(self):
300-
"""Returns S3Hook."""
301-
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
302-
303290
def is_keys_unchanged(self, current_objects: set[str]) -> bool:
304291
"""
305292
Check for new objects after the inactivity_period and update the sensor state accordingly.
@@ -382,7 +369,9 @@ def execute(self, context: Context) -> None:
382369
inactivity_seconds=self.inactivity_seconds,
383370
allow_delete=self.allow_delete,
384371
aws_conn_id=self.aws_conn_id,
372+
region_name=self.region_name,
385373
verify=self.verify,
374+
botocore_config=self.botocore_config,
386375
last_activity_time=self.last_activity_time,
387376
),
388377
method_name="execute_complete",

providers/amazon/src/airflow/providers/amazon/aws/triggers/s3.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def __init__(
5353
poke_interval: float = 5.0,
5454
should_check_fn: bool = False,
5555
use_regex: bool = False,
56+
region_name: str | None = None,
57+
verify: bool | str | None = None,
58+
botocore_config: dict | None = None,
5659
**hook_params: Any,
5760
):
5861
super().__init__()
@@ -64,6 +67,9 @@ def __init__(
6467
self.poke_interval = poke_interval
6568
self.should_check_fn = should_check_fn
6669
self.use_regex = use_regex
70+
self.region_name = region_name
71+
self.verify = verify
72+
self.botocore_config = botocore_config
6773

6874
def serialize(self) -> tuple[str, dict[str, Any]]:
6975
"""Serialize S3KeyTrigger arguments and classpath."""
@@ -78,12 +84,20 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
7884
"poke_interval": self.poke_interval,
7985
"should_check_fn": self.should_check_fn,
8086
"use_regex": self.use_regex,
87+
"region_name": self.region_name,
88+
"verify": self.verify,
89+
"botocore_config": self.botocore_config,
8190
},
8291
)
8392

8493
@cached_property
8594
def hook(self) -> S3Hook:
86-
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
95+
return S3Hook(
96+
aws_conn_id=self.aws_conn_id,
97+
region_name=self.region_name,
98+
verify=self.verify,
99+
config=self.botocore_config,
100+
)
87101

88102
async def run(self) -> AsyncIterator[TriggerEvent]:
89103
"""Make an asynchronous connection using S3Hook."""
@@ -143,7 +157,9 @@ def __init__(
143157
allow_delete: bool = True,
144158
aws_conn_id: str | None = "aws_default",
145159
last_activity_time: datetime | None = None,
160+
region_name: str | None = None,
146161
verify: bool | str | None = None,
162+
botocore_config: dict | None = None,
147163
**hook_params: Any,
148164
):
149165
super().__init__()
@@ -160,8 +176,10 @@ def __init__(
160176
self.allow_delete = allow_delete
161177
self.aws_conn_id = aws_conn_id
162178
self.last_activity_time = last_activity_time
163-
self.verify = verify
164179
self.polling_period_seconds = 0
180+
self.region_name = region_name
181+
self.verify = verify
182+
self.botocore_config = botocore_config
165183
self.hook_params = hook_params
166184

167185
def serialize(self) -> tuple[str, dict[str, Any]]:
@@ -179,14 +197,21 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
179197
"aws_conn_id": self.aws_conn_id,
180198
"last_activity_time": self.last_activity_time,
181199
"hook_params": self.hook_params,
182-
"verify": self.verify,
183200
"polling_period_seconds": self.polling_period_seconds,
201+
"region_name": self.region_name,
202+
"verify": self.verify,
203+
"botocore_config": self.botocore_config,
184204
},
185205
)
186206

187207
@cached_property
188208
def hook(self) -> S3Hook:
189-
return S3Hook(aws_conn_id=self.aws_conn_id, verify=self.hook_params.get("verify"))
209+
return S3Hook(
210+
aws_conn_id=self.aws_conn_id,
211+
region_name=self.region_name,
212+
verify=self.verify,
213+
config=self.botocore_config,
214+
)
190215

191216
async def run(self) -> AsyncIterator[TriggerEvent]:
192217
"""Make an asynchronous connection using S3Hook."""

providers/amazon/tests/unit/amazon/aws/operators/test_s3.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -415,20 +415,19 @@ def test_template_fields(self):
415415

416416

417417
class TestS3ListOperator:
418-
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
419-
def test_execute(self, mock_hook):
420-
mock_hook.return_value.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
421-
418+
def test_execute(self):
422419
operator = S3ListOperator(
423420
task_id="test-s3-list-operator",
424421
bucket=BUCKET_NAME,
425422
prefix="TEST",
426423
delimiter=".csv",
427424
)
425+
operator.hook = mock.MagicMock()
426+
operator.hook.list_keys.return_value = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
428427

429428
files = operator.execute(None)
430429

431-
mock_hook.return_value.list_keys.assert_called_once_with(
430+
operator.hook.list_keys.assert_called_once_with(
432431
bucket_name=BUCKET_NAME,
433432
prefix="TEST",
434433
delimiter=".csv",
@@ -447,17 +446,16 @@ def test_template_fields(self):
447446

448447

449448
class TestS3ListPrefixesOperator:
450-
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
451-
def test_execute(self, mock_hook):
452-
mock_hook.return_value.list_prefixes.return_value = ["test/"]
453-
449+
def test_execute(self):
454450
operator = S3ListPrefixesOperator(
455451
task_id="test-s3-list-prefixes-operator", bucket=BUCKET_NAME, prefix="test/", delimiter="/"
456452
)
453+
operator.hook = mock.MagicMock()
454+
operator.hook.list_prefixes.return_value = ["test/"]
457455

458456
subfolders = operator.execute(None)
459457

460-
mock_hook.return_value.list_prefixes.assert_called_once_with(
458+
operator.hook.list_prefixes.assert_called_once_with(
461459
bucket_name=BUCKET_NAME, prefix="test/", delimiter="/"
462460
)
463461
assert subfolders == ["test/"]
@@ -870,8 +868,7 @@ def test_validate_keys_and_prefix_in_execute(self, keys, prefix, from_datetime,
870868
assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test
871869

872870
@pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"]))
873-
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
874-
def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys):
871+
def test_get_openlineage_facets_on_complete_single_object(self, keys):
875872
bucket = "testbucket"
876873
expected_input = Dataset(
877874
namespace=f"s3://{bucket}",
@@ -888,14 +885,14 @@ def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys)
888885
)
889886

890887
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
888+
op.hook = mock.MagicMock()
891889
op.execute(None)
892890

893891
lineage = op.get_openlineage_facets_on_complete(None)
894892
assert len(lineage.inputs) == 1
895893
assert lineage.inputs[0] == expected_input
896894

897-
@mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
898-
def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
895+
def test_get_openlineage_facets_on_complete_multiple_objects(self):
899896
bucket = "testbucket"
900897
keys = ["path/data1.txt", "path/data2.txt"]
901898
expected_inputs = [
@@ -928,6 +925,7 @@ def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
928925
]
929926

930927
op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
928+
op.hook = mock.MagicMock()
931929
op.execute(None)
932930

933931
lineage = op.get_openlineage_facets_on_complete(None)

providers/amazon/tests/unit/amazon/aws/sensors/test_s3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,10 +538,10 @@ def test_key_changes(self, current_objects, expected_returns, inactivity_periods
538538
assert self.sensor.inactivity_seconds == period
539539
time_machine.coordinates.shift(10)
540540

541-
@mock.patch("airflow.providers.amazon.aws.sensors.s3.S3Hook")
542-
def test_poke_succeeds_on_upload_complete(self, mock_hook, time_machine):
541+
def test_poke_succeeds_on_upload_complete(self, time_machine):
543542
time_machine.move_to(DEFAULT_DATE)
544-
mock_hook.return_value.list_keys.return_value = {"a"}
543+
self.sensor.hook = mock.MagicMock()
544+
self.sensor.hook.list_keys.return_value = {"a"}
545545
assert not self.sensor.poke(dict())
546546
time_machine.coordinates.shift(10)
547547
assert not self.sensor.poke(dict())

providers/amazon/tests/unit/amazon/aws/triggers/test_s3.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ def test_serialization(self):
4646
"poke_interval": 5.0,
4747
"should_check_fn": False,
4848
"use_regex": False,
49+
"verify": None,
50+
"region_name": None,
51+
"botocore_config": None,
4952
}
5053

5154
@pytest.mark.asyncio
@@ -106,6 +109,8 @@ def test_serialization(self):
106109
"last_activity_time": None,
107110
"hook_params": {},
108111
"verify": None,
112+
"region_name": None,
113+
"botocore_config": None,
109114
"polling_period_seconds": 0,
110115
}
111116

0 commit comments

Comments
 (0)