1+ from tempfile import NamedTemporaryFile
2+ import logging
3+ import json
4+
5+ from airflow .utils .decorators import apply_defaults
16from airflow .models import BaseOperator
7+ from airflow .hooks .S3_hook import S3Hook
8+
9+ from airflow .contrib .hooks .salesforce_hook import SalesforceHook
10+
211
312class SalesforceBulkQueryToS3Operator (BaseOperator ):
413 """
@@ -30,7 +39,6 @@ def __init__(self, sf_conn_id, soql, object_type, #sf config
3039
3140 def execute (self , context ):
3241 sf_conn = SalesforceHook (self .sf_conn_id ).get_conn ()
33- s3_conn = S3Hook (self .s3_conn_id )
3442
3543 logging .info (self .soql )
3644 query_results = sf_conn .bulk .__getattr__ (self .object ).query (self .soql )
@@ -40,4 +48,172 @@ def execute(self, context):
4048 query_results = [json .dumps (result , ensure_ascii = False ) for result in query_results ]
4149 query_results = '\n ' .join (query_results )
4250
43- s3 .load_string (query_results , self .s3_key , bucket_name = self .s3_bucket , replace = True )
51+ s3 .load_string (query_results , self .s3_key , bucket_name = self .s3_bucket , replace = True )
52+
53+
class SalesforceToS3Operator(BaseOperator):
    """
    Make a query against Salesforce and write the resulting data to S3.

    Records are fetched for a Salesforce object (all fields by default),
    written to a temporary file in the requested format, and uploaded to
    the configured S3 bucket/key.
    """
    template_fields = ("s3_key",
                       "query")

    @apply_defaults
    def __init__(
        self,
        sf_conn_id,
        sf_obj,
        s3_conn_id,
        s3_bucket,
        s3_key,
        sf_fields=None,
        fmt="csv",
        query=None,
        relationship_object=None,
        record_time_added=False,
        coerce_to_timestamp=False,
        *args,
        **kwargs
    ):
        """
        Initialize the operator.

        :param sf_conn_id: Name of the Airflow connection that has
                           the following information:
                           - username
                           - password
                           - security_token
        :param sf_obj: Name of the relevant Salesforce object
        :param s3_conn_id: The destination s3 connection id.
        :type s3_conn_id: string
        :param s3_bucket: The destination s3 bucket.
        :type s3_bucket: string
        :param s3_key: The destination s3 key.
        :type s3_key: string
        :param sf_fields: *(optional)* list of fields that you want
                          to get from the object.
                          If *None*, then this will get all fields
                          for the object
        :param fmt: *(optional)* format that the s3_key of the
                    data should be in. Possible values include:
                    - csv
                    - json
                    - ndjson
                    *Default: csv*
        :param query: *(optional)* A specific query to run for
                      the given object. This will override
                      default query creation.
                      *Default: None*
        :param relationship_object: *(optional)* Some queries require
                                    relationship objects to work, and
                                    these are not the same names as
                                    the SF object. Specify that
                                    relationship object here.
                                    *Default: None*
        :param record_time_added: *(optional)* True if you want to add a
                                  Unix timestamp field to the resulting data
                                  that marks when the data was
                                  fetched from Salesforce.
                                  *Default: False*.
        :param coerce_to_timestamp: *(optional)* True if you want to convert
                                    all fields with dates and datetimes
                                    into Unix timestamp (UTC).
                                    *Default: False*.
        """
        super(SalesforceToS3Operator, self).__init__(*args, **kwargs)

        self.sf_conn_id = sf_conn_id
        self.object = sf_obj
        self.fields = sf_fields
        self.s3_conn_id = s3_conn_id
        self.s3_bucket = s3_bucket
        self.s3_key = s3_key
        self.fmt = fmt.lower()
        self.query = query
        self.relationship_object = relationship_object
        self.record_time_added = record_time_added
        self.coerce_to_timestamp = coerce_to_timestamp

    def special_query(self, query, sf_hook, relationship_object=None):
        """
        Run a caller-supplied SOQL query through the given hook.

        :param query: SOQL query string; must be truthy.
        :param sf_hook: SalesforceHook (or compatible) used to sign in
                        and issue the query.
        :param relationship_object: *(optional)* name of a nested
                                    relationship object; when given, the
                                    nested records are hoisted into the
                                    top-level ``records`` list.
        :raises ValueError: if ``query`` is empty/None.
        """
        if not query:
            raise ValueError("Query is None. Cannot query nothing")

        sf_hook.sign_in()

        results = sf_hook.make_query(query)
        if relationship_object:
            # Relationship queries nest the rows we actually want under
            # the relationship object's own "records" key — flatten them.
            records = []
            for r in results['records']:
                if r.get(relationship_object, None):
                    records.extend(r[relationship_object]['records'])
            results['records'] = records

        return results

    def execute(self, context):
        """
        Execute the operator.

        This will get all the data for a particular Salesforce model,
        write it to a temporary file, and upload that file to S3.
        """
        logging.info("Prepping to gather data from Salesforce")

        # Open a named temporary file to store output until the S3 upload;
        # the `with` block guarantees cleanup even on failure.
        with NamedTemporaryFile("w") as tmp:

            # Load the SalesforceHook
            hook = SalesforceHook(conn_id=self.sf_conn_id, output=tmp.name)

            # Attempt to login to Salesforce.
            # BUG FIX: the original used a bare `except:` that logged at
            # DEBUG and continued, so a failed login surfaced later as a
            # confusing unrelated error. Log and re-raise instead.
            try:
                hook.sign_in()
            except Exception:
                logging.error('Unable to login.')
                raise

            # Get object from Salesforce.
            # If fields were not defined, all fields are pulled.
            if not self.fields:
                self.fields = hook.get_available_fields(self.object)

            logging.info(
                "Making request for "
                "{0} fields from {1}".format(len(self.fields), self.object)
            )

            if self.query:
                query = self.special_query(
                    self.query,
                    hook,
                    relationship_object=self.relationship_object
                )
            else:
                query = hook.get_object_from_salesforce(self.object,
                                                        self.fields)

            # Output the records from the query to a file;
            # the list of records is stored under the "records" key.
            logging.info("Writing query results to: {0}".format(tmp.name))

            hook.write_object_to_file(query['records'],
                                      filename=tmp.name,
                                      fmt=self.fmt,
                                      coerce_to_timestamp=self.coerce_to_timestamp,
                                      record_time_added=self.record_time_added)

            # Flush buffered writes so the upload sees the complete file.
            tmp.flush()

            dest_s3 = S3Hook(s3_conn_id=self.s3_conn_id)

            # BUG FIX: the original uploaded with key=self.output, but no
            # `output` attribute is ever set on this operator (it would
            # raise AttributeError). The destination key is self.s3_key,
            # which is also one of the template_fields.
            dest_s3.load_file(
                filename=tmp.name,
                key=self.s3_key,
                bucket_name=self.s3_bucket,
                replace=True
            )

            dest_s3.connection.close()

            # NOTE: the original also called tmp.close() here; that is
            # redundant inside the `with` block, which closes the file.

        logging.info("Query finished!")
0 commit comments