1+ from tempfile import NamedTemporaryFile
12import logging
23import json
3- import collections
4+ import os
5+
46from airflow .hooks .S3_hook import S3Hook
57from airflow .models import BaseOperator , SkipMixin
68from airflow .utils .decorators import apply_defaults
9+
710from stripe_plugin .hooks .stripe_hook import StripeHook
8- from tempfile import NamedTemporaryFile
911
1012
1113class StripeToS3Operator (BaseOperator , SkipMixin ):
1214 """
13- Make a query against Stripe and write the resulting data to s3
15+ Stripe to S3 Operator
16+
17+ :param stripe_conn_id: Name of the Airflow connection that has
18+ your Stripe username, password and user_key
19+ :type stripe_conn_id: String
20+ :param stripe_object: Name of the Stripe object. Currently
21+ supported objects include:
22+ - BalanceTransaction
23+ - Charge
24+ - Coupon
25+ - Customer
26+ - Dispute
27+ - Event
28+ - FileUpload
29+ - Invoice
30+ - InvoiceItem
31+ - Payout
32+ - Order
33+ - OrderReturn
34+ - Plan
35+ - Product
36+ - Refund
37+ - SKU
38+ - Subscription
39+ :type stripe_object: String
40+ :param stripe_args *(optional)* Extra stripe arguments
41+ :type stripe_args: Dictionary
42+ :param s3_conn_id: Name of the S3 connection id
43+ :type s3_conn_id: String
44+ :param s3_bucket: name of the destination S3 bucket
45+ :type s3_bucket String
46+ :param s3_key: name of the destination file from bucket
47+ :type s3_key: String
48+ :param fields: *(optional)* list of fields that you want
49+ to get from the object.
50+ If *None*, then this will get all fields
51+ for the object
52+ :type fields: List
53+ :param replication_key_value: *(optional)* value of the replication key,
54+ if needed. The operator will import only
55+ results with the id grater than the value of
56+ this param.
57+ :type replication_key_value: String
1458 """
59+
1560 template_field = ('s3_key' , )
1661
1762 @apply_defaults
@@ -27,57 +72,17 @@ def __init__(self,
2772 * args ,
2873 ** kwargs
2974 ):
30- """
31- Initialize the operator
32- :param stripe_conn_id: name of the Airflow connection that has
33- your Stripe username, password and user_key
34- :param stripe_object: name of the Stripe object we are
35- fetching data from
36- :param stripe_args *(optional)* dictionary with extra stripe
37- arguments
38- :param s3_conn_id: name of the Airflow connection that has
39- your Amazon S3 conection params
40- :param s3_bucket: name of the destination S3 bucketcd
41- :param s3_key: name of the destination file from bucket
42- :param fields: *(optional)* list of fields that you want
43- to get from the object.
44- If *None*, then this will get all fields
45- for the object
46- :param replication_key_value: *(optional)* value of the replication key,
47- if needed. The operator will import only
48- results with the id grater than the value of
49- this param.
50- """
5175
5276 super ().__init__ (* args , ** kwargs )
5377
5478 self .stripe_conn_id = stripe_conn_id
5579 self .stripe_object = stripe_object
5680 self .stripe_args = stripe_args
57-
5881 self .s3_conn_id = s3_conn_id
5982 self .s3_bucket = s3_bucket
6083 self .s3_key = s3_key
61-
6284 self .fields = fields
6385 self .replication_key_value = replication_key_value
64- self ._kwargs = kwargs
65-
66- def filter_fields (self , result ):
67- """
68- Filter the fields from an resulting object.
69-
70- This will return a object only with fields given
71- as parameter in the constructor.
72-
73- All fields are returned when "fields" param is None.
74- """
75- if not self .fields :
76- return result
77- obj = {}
78- for field in self .fields :
79- obj [field ] = result [field ]
80- return obj
8186
8287 def execute (self , context ):
8388 """
@@ -86,47 +91,37 @@ def execute(self, context):
8691 and write it to a file.
8792 """
8893 logging .info ("Prepping to gather data from Stripe" )
89- hook = StripeHook (
90- conn_id = self .stripe_conn_id
91- )
92-
93- # attempt to connect to Stripe
94- # if this process fails, it will raise an error and die right here
95- # we could wrap it
96- hook .get_conn ()
97-
98- logging .info (
99- "Making request for"
100- " {0} object" .format (self .stripe_object )
101- )
102-
103- results = hook .run_query (
104- self .stripe_object ,
105- self .replication_key_value ,
106- ** self .stripe_args )
107-
108- if len (results ) == 0 or results is None :
109- logging .info ("No records pulled from Stripe." )
110- downstream_tasks = context ['task' ].get_flat_relatives (
111- upstream = False )
112- logging .info ('Skipping downstream tasks...' )
113- logging .debug ("Downstream task_ids %s" , downstream_tasks )
114-
115- if downstream_tasks :
116- self .skip (context ['dag_run' ],
117- context ['ti' ].execution_date ,
118- downstream_tasks )
119- return True
120-
121- else :
122- # Write the results to a temporary file and save that file to s3.
123- with NamedTemporaryFile ("w" ) as tmp :
124- for result in results :
125- filtered_result = self .filter_fields (result )
126- tmp .write (json .dumps (filtered_result ) + '\n ' )
127-
128- tmp .flush ()
12994
95+ hook = StripeHook (conn_id = self .stripe_conn_id )
96+
97+ logging .info ("Making request for {0} object" .format (self .stripe_object ))
98+
99+ results = hook .run_query (self .stripe_object ,
100+ self .replication_key_value ,
101+ ** self .stripe_args )
102+
103+ # Write the results to a temporary file and save that file to s3.
104+ with NamedTemporaryFile ("w" ) as tmp :
105+ for result in results :
106+ filtered_result = self .filter_fields (result )
107+ tmp .write (json .dumps (filtered_result ) + '\n ' )
108+
109+ tmp .flush ()
110+
111+ if os .stat (tmp .name ).st_size == 0 :
112+ logging .info ("No records pulled from Stripe." )
113+ downstream_tasks = context ['task' ].get_flat_relatives (
114+ upstream = False )
115+ logging .info ('Skipping downstream tasks...' )
116+ logging .debug ("Downstream task_ids %s" , downstream_tasks )
117+
118+ if downstream_tasks :
119+ self .skip (context ['dag_run' ],
120+ context ['ti' ].execution_date ,
121+ downstream_tasks )
122+ return True
123+
124+ else :
130125 dest_s3 = S3Hook (s3_conn_id = self .s3_conn_id )
131126 dest_s3 .load_file (
132127 filename = tmp .name ,
@@ -137,3 +132,19 @@ def execute(self, context):
137132 )
138133 dest_s3 .connection .close ()
139134 tmp .close ()
135+
136+ def filter_fields (self , result ):
137+ """
138+ Filter the fields from an resulting object.
139+
140+ This will return a object only with fields given
141+ as parameter in the constructor.
142+
143+ All fields are returned when "fields" param is None.
144+ """
145+ if not self .fields :
146+ return result
147+ obj = {}
148+ for field in self .fields :
149+ obj [field ] = result [field ]
150+ return obj
0 commit comments