
Commit 234fdc4

docs; formatting
1 parent 874c657 commit 234fdc4

File tree

2 files changed: +94 -83 lines changed


operators/__init__.py

File mode changed: 100644 → 100755.

operators/stripe_to_s3_operator.py

File mode changed: 100644 → 100755.
Lines changed: 94 additions & 83 deletions
@@ -1,17 +1,62 @@
+from tempfile import NamedTemporaryFile
 import logging
 import json
-import collections
+import os
+
 from airflow.hooks.S3_hook import S3Hook
 from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
+
 from stripe_plugin.hooks.stripe_hook import StripeHook
-from tempfile import NamedTemporaryFile


 class StripeToS3Operator(BaseOperator, SkipMixin):
     """
-    Make a query against Stripe and write the resulting data to s3
+    Stripe to S3 Operator
+
+    :param stripe_conn_id: Name of the Airflow connection that has
+                           your Stripe username, password and user_key
+    :type stripe_conn_id: String
+    :param stripe_object: Name of the Stripe object. Currently
+                          supported objects include:
+                            - BalanceTransaction
+                            - Charge
+                            - Coupon
+                            - Customer
+                            - Dispute
+                            - Event
+                            - FileUpload
+                            - Invoice
+                            - InvoiceItem
+                            - Payout
+                            - Order
+                            - OrderReturn
+                            - Plan
+                            - Product
+                            - Refund
+                            - SKU
+                            - Subscription
+    :type stripe_object: String
+    :param stripe_args: *(optional)* Extra Stripe arguments
+    :type stripe_args: Dictionary
+    :param s3_conn_id: Name of the S3 connection id
+    :type s3_conn_id: String
+    :param s3_bucket: Name of the destination S3 bucket
+    :type s3_bucket: String
+    :param s3_key: Name of the destination file in the bucket
+    :type s3_key: String
+    :param fields: *(optional)* List of fields that you want
+                   to get from the object.
+                   If *None*, then this will get all fields
+                   for the object
+    :type fields: List
+    :param replication_key_value: *(optional)* Value of the replication key,
+                                  if needed. The operator will import only
+                                  results with an id greater than the value of
+                                  this param.
+    :type replication_key_value: String
     """
+
     template_field = ('s3_key', )

     @apply_defaults
@@ -27,57 +72,17 @@ def __init__(self,
                  *args,
                  **kwargs
                  ):
-        """
-        Initialize the operator
-        :param stripe_conn_id: name of the Airflow connection that has
-                               your Stripe username, password and user_key
-        :param stripe_object: name of the Stripe object we are
-                              fetching data from
-        :param stripe_args: *(optional)* dictionary with extra Stripe
-                            arguments
-        :param s3_conn_id: name of the Airflow connection that has
-                           your Amazon S3 connection params
-        :param s3_bucket: name of the destination S3 bucket
-        :param s3_key: name of the destination file in the bucket
-        :param fields: *(optional)* list of fields that you want
-                       to get from the object.
-                       If *None*, then this will get all fields
-                       for the object
-        :param replication_key_value: *(optional)* value of the replication key,
-                                      if needed. The operator will import only
-                                      results with an id greater than the value of
-                                      this param.
-        """

         super().__init__(*args, **kwargs)

         self.stripe_conn_id = stripe_conn_id
         self.stripe_object = stripe_object
         self.stripe_args = stripe_args
-
         self.s3_conn_id = s3_conn_id
         self.s3_bucket = s3_bucket
         self.s3_key = s3_key
-
         self.fields = fields
         self.replication_key_value = replication_key_value
-        self._kwargs = kwargs
-
-    def filter_fields(self, result):
-        """
-        Filter the fields from a resulting object.
-
-        This will return an object only with the fields given
-        as a parameter in the constructor.
-
-        All fields are returned when the "fields" param is None.
-        """
-        if not self.fields:
-            return result
-        obj = {}
-        for field in self.fields:
-            obj[field] = result[field]
-        return obj

     def execute(self, context):
         """
@@ -86,47 +91,37 @@ def execute(self, context):
         and write it to a file.
         """
         logging.info("Prepping to gather data from Stripe")
-        hook = StripeHook(
-            conn_id=self.stripe_conn_id
-        )
-
-        # attempt to connect to Stripe
-        # if this process fails, it will raise an error and die right here
-        # we could wrap it
-        hook.get_conn()
-
-        logging.info(
-            "Making request for"
-            " {0} object".format(self.stripe_object)
-        )
-
-        results = hook.run_query(
-            self.stripe_object,
-            self.replication_key_value,
-            **self.stripe_args)
-
-        if len(results) == 0 or results is None:
-            logging.info("No records pulled from Stripe.")
-            downstream_tasks = context['task'].get_flat_relatives(
-                upstream=False)
-            logging.info('Skipping downstream tasks...')
-            logging.debug("Downstream task_ids %s", downstream_tasks)
-
-            if downstream_tasks:
-                self.skip(context['dag_run'],
-                          context['ti'].execution_date,
-                          downstream_tasks)
-            return True
-
-        else:
-            # Write the results to a temporary file and save that file to s3.
-            with NamedTemporaryFile("w") as tmp:
-                for result in results:
-                    filtered_result = self.filter_fields(result)
-                    tmp.write(json.dumps(filtered_result) + '\n')
-
-                tmp.flush()

+        hook = StripeHook(conn_id=self.stripe_conn_id)
+
+        logging.info("Making request for {0} object".format(self.stripe_object))
+
+        results = hook.run_query(self.stripe_object,
+                                 self.replication_key_value,
+                                 **self.stripe_args)
+
+        # Write the results to a temporary file and save that file to s3.
+        with NamedTemporaryFile("w") as tmp:
+            for result in results:
+                filtered_result = self.filter_fields(result)
+                tmp.write(json.dumps(filtered_result) + '\n')
+
+            tmp.flush()
+
+            if os.stat(tmp.name).st_size == 0:
+                logging.info("No records pulled from Stripe.")
+                downstream_tasks = context['task'].get_flat_relatives(
+                    upstream=False)
+                logging.info('Skipping downstream tasks...')
+                logging.debug("Downstream task_ids %s", downstream_tasks)
+
+                if downstream_tasks:
+                    self.skip(context['dag_run'],
+                              context['ti'].execution_date,
+                              downstream_tasks)
+                return True
+
+            else:
                 dest_s3 = S3Hook(s3_conn_id=self.s3_conn_id)
                 dest_s3.load_file(
                     filename=tmp.name,
@@ -137,3 +132,19 @@ def execute(self, context):
                 )
                 dest_s3.connection.close()
                 tmp.close()
+
+    def filter_fields(self, result):
+        """
+        Filter the fields from a resulting object.
+
+        This will return an object only with the fields given
+        as a parameter in the constructor.
+
+        All fields are returned when the "fields" param is None.
+        """
+        if not self.fields:
+            return result
+        obj = {}
+        for field in self.fields:
+            obj[field] = result[field]
+        return obj
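
For reference, below is a minimal sketch of how the operator documented above might be wired into a DAG. It is not part of this commit: the plugin import path, connection ids, bucket, key, and field list are assumptions for illustration only.

from datetime import datetime

from airflow import DAG
# Assumed import path; adjust to however the stripe_plugin operators are exposed in your deployment.
from airflow.operators.stripe_plugin import StripeToS3Operator

dag = DAG(
    dag_id='stripe_charges_to_s3',        # hypothetical DAG id
    start_date=datetime(2018, 1, 1),
    schedule_interval='@daily',
)

stripe_to_s3 = StripeToS3Operator(
    task_id='stripe_charges_to_s3',
    stripe_conn_id='stripe_default',      # Airflow connection holding the Stripe credentials
    stripe_object='Charge',               # one of the supported objects listed in the docstring
    stripe_args={'limit': 100},           # optional extra Stripe arguments
    s3_conn_id='s3_default',              # Airflow connection for S3
    s3_bucket='my-data-lake',             # hypothetical destination bucket
    s3_key='stripe/charges.json',         # hypothetical destination key
    fields=['id', 'amount', 'currency', 'created'],  # optional subset of fields to keep
    dag=dag,
)

With fields set as above, filter_fields() keeps only those four keys, so each line written to S3 is a small JSON object; leaving fields as None writes every field Stripe returns for the object.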

0 commit comments