2525import io .cdap .plugin .db .TransactionIsolationLevel ;
2626import io .cdap .plugin .util .DBUtils ;
2727import org .apache .hadoop .conf .Configuration ;
28+ import org .apache .hadoop .mapreduce .JobContext ;
29+ import org .apache .hadoop .mapreduce .OutputCommitter ;
2830import org .apache .hadoop .mapreduce .RecordWriter ;
2931import org .apache .hadoop .mapreduce .TaskAttemptContext ;
3032import org .apache .hadoop .mapreduce .lib .db .DBConfiguration ;
4143import java .sql .PreparedStatement ;
4244import java .sql .SQLException ;
4345import java .sql .Statement ;
46+ import java .util .HashMap ;
4447import java .util .Map ;
4548import java .util .Properties ;
4649
5659public class ETLDBOutputFormat <K extends DBWritable , V > extends DBOutputFormat <K , V > {
5760 // Batch size before submitting a batch to the SQL engine. If set to 0, no batches will be submitted until commit.
5861 public static final String COMMIT_BATCH_SIZE = "io.cdap.plugin.db.output.commit.batch.size" ;
62+ public static final String STAGE_NAME = "io.cdap.plugin.db.output.stage_name" ;
5963 public static final int DEFAULT_COMMIT_BATCH_SIZE = 1000 ;
6064 private static final Character ESCAPE_CHAR = '"' ;
61-
65+ private static final String CONNECTION_MAP_ID_REGEX = "%s_%s" ;
6266 private static final Logger LOG = LoggerFactory .getLogger (ETLDBOutputFormat .class );
6367
6468 private Configuration conf ;
6569 private Driver driver ;
6670 private JDBCDriverShim driverShim ;
6771
72+ private static Map <String , Connection > connectionMap = new HashMap <>();
73+
74+ @ Override
75+ public OutputCommitter getOutputCommitter (TaskAttemptContext context )
76+ throws IOException , InterruptedException {
77+ return new OutputCommitter () {
78+ @ Override
79+ public void setupJob (JobContext jobContext ) throws IOException {
80+ // do nothing
81+ }
82+
83+ @ Override
84+ public void setupTask (TaskAttemptContext taskContext ) throws IOException {
85+ // do nothing
86+ }
87+
88+ @ Override
89+ public boolean needsTaskCommit (TaskAttemptContext taskContext ) throws IOException {
90+ return true ;
91+ }
92+
93+ @ Override
94+ public void commitTask (TaskAttemptContext taskContext ) throws IOException {
95+ conf = context .getConfiguration ();
96+ String stageName = conf .get (STAGE_NAME );
97+ String connectionId = String .format (CONNECTION_MAP_ID_REGEX , context .getTaskAttemptID ().toString (), stageName );
98+ Connection connection ;
99+ if ((connection = connectionMap .remove (connectionId )) != null ) {
100+ try {
101+ connection .commit ();
102+ } catch (SQLException e ) {
103+ try {
104+ connection .rollback ();
105+ } catch (SQLException ex ) {
106+ LOG .warn (StringUtils .stringifyException (ex ));
107+ }
108+ throw new IOException (e );
109+ } finally {
110+ try {
111+ connection .close ();
112+ LOG .debug ("Connection Closed after committing the task with taskAttemptId {}" , connectionId );
113+ } catch (SQLException ex ) {
114+ LOG .warn (StringUtils .stringifyException (ex ));
115+ }
116+ }
117+ try {
118+ DriverManager .deregisterDriver (driverShim );
119+ } catch (SQLException e ) {
120+ throw new IOException (e );
121+ }
122+ }
123+ }
124+
125+ @ Override
126+ public void abortTask (TaskAttemptContext taskContext ) throws IOException {
127+ conf = context .getConfiguration ();
128+ String stageName = conf .get (STAGE_NAME );
129+ String connectionId = String .format (CONNECTION_MAP_ID_REGEX , context .getTaskAttemptID ().toString (), stageName );
130+ Connection connection ;
131+ if ((connection = connectionMap .remove (connectionId )) != null ) {
132+ try {
133+ connection .rollback ();
134+ } catch (SQLException e ) {
135+ throw new IOException (e );
136+ } finally {
137+ try {
138+ connection .close ();
139+ LOG .debug ("Connection Closed after rollback the task with taskAttemptId {}" , connectionId );
140+ } catch (SQLException ex ) {
141+ LOG .warn (StringUtils .stringifyException (ex ));
142+ }
143+ }
144+ try {
145+ DriverManager .deregisterDriver (driverShim );
146+ } catch (SQLException e ) {
147+ throw new IOException (e );
148+ }
149+ }
150+ }
151+ };
152+ }
153+
68154 @ Override
69155 public RecordWriter <K , V > getRecordWriter (TaskAttemptContext context ) throws IOException {
70156 conf = context .getConfiguration ();
@@ -81,6 +167,11 @@ public RecordWriter<K, V> getRecordWriter(TaskAttemptContext context) throws IOE
81167
82168 try {
83169 Connection connection = getConnection (conf );
170+ String stageName = conf .get (STAGE_NAME );
171+ // When several sinks run in the same task they share the same task attempt ID, so the stage name is appended to keep the connection ID unique.
172+ String connectionId = String .format (CONNECTION_MAP_ID_REGEX , context .getTaskAttemptID ().toString (), stageName );
173+ connectionMap .put (connectionId , connection );
174+ LOG .debug ("Connection Added to the map with connectionId : {}" , connectionId );
84175 PreparedStatement statement = connection .prepareStatement (constructQueryOnOperation (tableName , fieldNames ,
85176 operationName , listKeys ));
86177 return new DBRecordWriter (connection , statement ) {
@@ -98,28 +189,15 @@ public void close(TaskAttemptContext context) throws IOException {
98189 if (!emptyData ) {
99190 getStatement ().executeBatch ();
100191 }
101- getConnection ().commit ();
102192 } catch (SQLException e ) {
103- try {
104- getConnection ().rollback ();
105- } catch (SQLException ex ) {
106- LOG .warn (StringUtils .stringifyException (ex ));
107- }
108193 throw new IOException (e );
109194 } finally {
110195 try {
111196 getStatement ().close ();
112- getConnection ().close ();
113197 } catch (SQLException ex ) {
114198 throw new IOException (ex );
115199 }
116200 }
117-
118- try {
119- DriverManager .deregisterDriver (driverShim );
120- } catch (SQLException e ) {
121- throw new IOException (e );
122- }
123201 }
124202
125203 @ Override
0 commit comments