Skip to content

Commit 343cbca

Browse files
committed
fix: Update sqlExecute method to return structured result maps
- Modified the `sqlExecute` method to return a map containing success status, affected rows, and error messages for both DDL and DML operations. - Updated documentation to reflect the new return structure. - Adjusted unit tests to validate the new response format and ensure proper handling of success and error cases. Signed-off-by: Edmund Miller <[email protected]>
1 parent 3598b14 commit 343cbca

File tree

3 files changed

+88
-60
lines changed

3 files changed

+88
-60
lines changed

plugins/nf-sqldb/src/main/nextflow/sql/ChannelSqlExtension.groovy

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -145,21 +145,21 @@ class ChannelSqlExtension extends PluginExtensionPoint {
145145

146146
/**
147147
* Execute a SQL statement that does not return a result set (DDL/DML statements)
148-
* For DML statements (INSERT, UPDATE, DELETE), it returns the number of affected rows
149-
* For DDL statements (CREATE, ALTER, DROP), it returns null
148+
* For DML statements (INSERT, UPDATE, DELETE), it returns a result map with success status and number of affected rows
149+
* For DDL statements (CREATE, ALTER, DROP), it returns a result map with success status
150150
*
151151
* @param params A map containing 'db' (database alias) and 'statement' (SQL string to execute)
152-
* @return The number of rows affected by the SQL statement for DML operations, null for DDL operations
152+
* @return A map containing 'success' (boolean), 'result' (rows affected or null) and optionally 'error' (message)
153153
*/
154154
@Function
155-
Integer sqlExecute(Map params) {
155+
Map sqlExecute(Map params) {
156156
CheckHelper.checkParams('sqlExecute', params, EXECUTE_PARAMS)
157157

158158
final String dbName = params.db as String ?: 'default'
159159
final String statement = params.statement as String
160160

161161
if (!statement)
162-
throw new IllegalArgumentException("Missing required parameter 'statement'")
162+
return [success: false, error: "Missing required parameter 'statement'"]
163163

164164
final sqlConfig = new SqlConfig((Map) session.config.navigate('sql.db'))
165165
final SqlDataSource dataSource = sqlConfig.getDataSource(dbName)
@@ -171,28 +171,27 @@ class ChannelSqlExtension extends PluginExtensionPoint {
171171
msg += " - Did you mean: ${choices.get(0)}?"
172172
else if (choices)
173173
msg += " - Did you mean any of these?\n" + choices.collect { " $it" }.join('\n') + '\n'
174-
throw new IllegalArgumentException(msg)
174+
return [success: false, error: msg]
175175
}
176176

177177
try (Connection conn = groovy.sql.Sql.newInstance(dataSource.toMap()).getConnection()) {
178178
try (Statement stm = conn.createStatement()) {
179179
String normalizedStatement = normalizeStatement(statement)
180180

181-
// For DDL statements (CREATE, ALTER, DROP, etc.), execute() returns true if the first result is a ResultSet
182-
// For DML statements (INSERT, UPDATE, DELETE), executeUpdate() returns the number of rows affected
183181
boolean isDDL = normalizedStatement.trim().toLowerCase().matches("^(create|alter|drop|truncate).*")
184182

185183
if (isDDL) {
186184
stm.execute(normalizedStatement)
187-
return null
185+
return [success: true, result: null]
188186
} else {
189-
return stm.executeUpdate(normalizedStatement)
187+
Integer rowsAffected = stm.executeUpdate(normalizedStatement)
188+
return [success: true, result: rowsAffected]
190189
}
191190
}
192191
}
193192
catch (Exception e) {
194193
log.error("Error executing SQL statement: ${e.message}", e)
195-
throw e
194+
return [success: false, error: e.message]
196195
}
197196
}
198197

@@ -204,7 +203,7 @@ class ChannelSqlExtension extends PluginExtensionPoint {
204203
*/
205204
private static String normalizeStatement(String statement) {
206205
if (!statement)
207-
throw new IllegalArgumentException("Missing SQL statement")
206+
return null
208207
def result = statement.trim()
209208
if (!result.endsWith(';'))
210209
result += ';'

plugins/nf-sqldb/src/test/nextflow/sql/SqlExecutionTest.groovy

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class SqlExecutionTest extends Specification {
4040
Global.session = null
4141
}
4242

43-
def 'should execute DDL statements successfully and return null'() {
43+
def 'should execute DDL statements successfully and return success map'() {
4444
given:
4545
def JDBC_URL = 'jdbc:h2:mem:test_ddl_' + Random.newInstance().nextInt(1_000_000)
4646
def sql = Sql.newInstance(JDBC_URL, 'sa', null)
@@ -58,29 +58,32 @@ class SqlExecutionTest extends Specification {
5858
statement: 'CREATE TABLE test_table(id INT PRIMARY KEY, name VARCHAR(255))'
5959
])
6060

61-
then: 'Table should be created and result should be null'
61+
then: 'Table should be created and result should indicate success'
6262
sql.rows('SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = \'TEST_TABLE\'').size() > 0
63-
createResult == null
63+
createResult.success == true
64+
createResult.result == null
6465

6566
when: 'Altering the table'
6667
def alterResult = sqlExtension.sqlExecute([
6768
db: 'test',
6869
statement: 'ALTER TABLE test_table ADD COLUMN description VARCHAR(255)'
6970
])
7071

71-
then: 'Column should be added and result should be null'
72+
then: 'Column should be added and result should indicate success'
7273
sql.rows('SELECT COLUMN_NAME FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_NAME = \'TEST_TABLE\' AND COLUMN_NAME = \'DESCRIPTION\'').size() > 0
73-
alterResult == null
74+
alterResult.success == true
75+
alterResult.result == null
7476

7577
when: 'Dropping the table'
7678
def dropResult = sqlExtension.sqlExecute([
7779
db: 'test',
7880
statement: 'DROP TABLE test_table'
7981
])
8082

81-
then: 'Table should be dropped and result should be null'
83+
then: 'Table should be dropped and result should indicate success'
8284
sql.rows('SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = \'TEST_TABLE\'').size() == 0
83-
dropResult == null
85+
dropResult.success == true
86+
dropResult.result == null
8487
}
8588

8689
def 'should execute DML statements successfully and return affected row count'() {
@@ -104,30 +107,33 @@ class SqlExecutionTest extends Specification {
104107
statement: 'INSERT INTO test_dml (id, name, value) VALUES (1, \'item1\', 100)'
105108
])
106109

107-
then: 'Row should be inserted and result should be 1'
110+
then: 'Row should be inserted and result should indicate success with 1 row affected'
108111
sql.rows('SELECT * FROM test_dml').size() == 1
109112
sql.firstRow('SELECT * FROM test_dml WHERE id = 1').name == 'item1'
110-
insertResult == 1
113+
insertResult.success == true
114+
insertResult.result == 1
111115

112116
when: 'Updating data'
113117
def updateResult = sqlExtension.sqlExecute([
114118
db: 'test',
115119
statement: 'UPDATE test_dml SET value = 200 WHERE id = 1'
116120
])
117121

118-
then: 'Row should be updated and result should be 1'
122+
then: 'Row should be updated and result should indicate success with 1 row affected'
119123
sql.firstRow('SELECT value FROM test_dml WHERE id = 1').value == 200
120-
updateResult == 1
124+
updateResult.success == true
125+
updateResult.result == 1
121126

122127
when: 'Deleting data'
123128
def deleteResult = sqlExtension.sqlExecute([
124129
db: 'test',
125130
statement: 'DELETE FROM test_dml WHERE id = 1'
126131
])
127132

128-
then: 'Row should be deleted and result should be 1'
133+
then: 'Row should be deleted and result should indicate success with 1 row affected'
129134
sql.rows('SELECT * FROM test_dml').size() == 0
130-
deleteResult == 1
135+
deleteResult.success == true
136+
deleteResult.result == 1
131137
}
132138

133139
def 'should return correct affected row count for multiple row operations'() {
@@ -149,40 +155,42 @@ class SqlExecutionTest extends Specification {
149155
sqlExtension.init(session)
150156

151157
when: 'Inserting data'
152-
def insertCount = sqlExtension.sqlExecute([
158+
def insertResult = sqlExtension.sqlExecute([
153159
db: 'test',
154160
statement: 'INSERT INTO test_update (id, name, value) VALUES (4, \'item4\', 100)'
155161
])
156162

157-
then: 'Should return 1 affected row'
158-
insertCount == 1
163+
then: 'Should return success with 1 affected row'
164+
insertResult.success == true
165+
insertResult.result == 1
159166
sql.rows('SELECT * FROM test_update').size() == 4
160167

161168
when: 'Updating multiple rows'
162-
def updateCount = sqlExtension.sqlExecute([
169+
def updateResult = sqlExtension.sqlExecute([
163170
db: 'test',
164171
statement: 'UPDATE test_update SET value = 200 WHERE value = 100'
165172
])
166173

167-
then: 'Should return 4 affected rows'
168-
updateCount == 4
174+
then: 'Should return success with 4 affected rows'
175+
updateResult.success == true
176+
updateResult.result == 4
169177
sql.rows('SELECT * FROM test_update WHERE value = 200').size() == 4
170178

171179
when: 'Deleting multiple rows'
172-
def deleteCount = sqlExtension.sqlExecute([
180+
def deleteResult = sqlExtension.sqlExecute([
173181
db: 'test',
174182
statement: 'DELETE FROM test_update WHERE value = 200'
175183
])
176184

177-
then: 'Should return 4 affected rows'
178-
deleteCount == 4
185+
then: 'Should return success with 4 affected rows'
186+
deleteResult.success == true
187+
deleteResult.result == 4
179188
sql.rows('SELECT * FROM test_update').size() == 0
180189
}
181190

182191
def 'should handle invalid SQL correctly'() {
183192
given:
184193
def JDBC_URL = 'jdbc:h2:mem:test_error_' + Random.newInstance().nextInt(1_000_000)
185-
def sql = Sql.newInstance(JDBC_URL, 'sa', null)
186194

187195
and:
188196
def session = Mock(Session) {
@@ -192,22 +200,24 @@ class SqlExecutionTest extends Specification {
192200
sqlExtension.init(session)
193201

194202
when: 'Executing invalid SQL'
195-
sqlExtension.sqlExecute([
203+
def invalidResult = sqlExtension.sqlExecute([
196204
db: 'test',
197205
statement: 'INVALID SQL STATEMENT'
198206
])
199207

200-
then: 'Should throw an exception'
201-
thrown(Exception)
208+
then: 'Should return failure with error message'
209+
invalidResult.success == false
210+
invalidResult.error != null
202211

203212
when: 'Executing query with invalid table name'
204-
sqlExtension.sqlExecute([
213+
def noTableResult = sqlExtension.sqlExecute([
205214
db: 'test',
206215
statement: 'SELECT * FROM non_existent_table'
207216
])
208217

209-
then: 'Should throw an exception'
210-
thrown(Exception)
218+
then: 'Should return failure with error message'
219+
noTableResult.success == false
220+
noTableResult.error != null
211221
}
212222

213223
def 'should handle invalid database configuration correctly'() {
@@ -219,30 +229,36 @@ class SqlExecutionTest extends Specification {
219229
sqlExtension.init(session)
220230

221231
when: 'Using non-existent database alias'
222-
sqlExtension.sqlExecute([
232+
def nonExistentDbResult = sqlExtension.sqlExecute([
223233
db: 'non_existent_db',
224234
statement: 'SELECT 1'
225235
])
226236

227-
then: 'Should throw an IllegalArgumentException'
228-
thrown(IllegalArgumentException)
237+
then: 'Should return failure with error message'
238+
nonExistentDbResult.success == false
239+
nonExistentDbResult.error != null
240+
nonExistentDbResult.error.contains('Unknown db name')
229241

230242
when: 'Missing statement parameter'
231-
sqlExtension.sqlExecute([
243+
def missingStatementResult = sqlExtension.sqlExecute([
232244
db: 'test'
233245
])
234246

235-
then: 'Should throw an IllegalArgumentException'
236-
thrown(IllegalArgumentException)
247+
then: 'Should return failure with error message'
248+
missingStatementResult.success == false
249+
missingStatementResult.error != null
250+
missingStatementResult.error.contains('Missing required parameter')
237251

238252
when: 'Empty statement parameter'
239-
sqlExtension.sqlExecute([
253+
def emptyStatementResult = sqlExtension.sqlExecute([
240254
db: 'test',
241255
statement: ''
242256
])
243257

244-
then: 'Should throw an IllegalArgumentException'
245-
thrown(IllegalArgumentException)
258+
then: 'Should return failure with error message'
259+
emptyStatementResult.success == false
260+
emptyStatementResult.error != null
261+
emptyStatementResult.error.contains('Missing required parameter')
246262
}
247263

248264
def 'should handle statement normalization correctly'() {
@@ -258,23 +274,25 @@ class SqlExecutionTest extends Specification {
258274
sqlExtension.init(session)
259275

260276
when: 'Executing statement without semicolon'
261-
def result = sqlExtension.sqlExecute([
277+
def createResult = sqlExtension.sqlExecute([
262278
db: 'test',
263279
statement: 'CREATE TABLE test_norm(id INT PRIMARY KEY)'
264280
])
265281

266-
then: 'Statement should be executed successfully and result should be null'
282+
then: 'Statement should be executed successfully'
267283
sql.rows('SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = \'TEST_NORM\'').size() > 0
268-
result == null
284+
createResult.success == true
285+
createResult.result == null
269286

270287
when: 'Executing statement with trailing whitespace'
271288
def dropResult = sqlExtension.sqlExecute([
272289
db: 'test',
273290
statement: 'DROP TABLE test_norm '
274291
])
275292

276-
then: 'Statement should be executed successfully and result should be null'
293+
then: 'Statement should be executed successfully'
277294
sql.rows('SELECT 1 FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = \'TEST_NORM\'').size() == 0
278-
dropResult == null
295+
dropResult.success == true
296+
dropResult.result == null
279297
}
280298
}

plugins/nf-sqldb/src/testResources/testDir/test_sql_db.nf

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ nextflow.enable.dsl=2
33
include { fromQuery; sqlInsert; sqlExecute } from 'plugin/nf-sqldb'
44

55
workflow {
6-
// Setup: create table (DDL operation returns null)
6+
// Setup: create table (DDL operation)
77
def createResult = sqlExecute(
88
db: 'foo',
99
statement: '''
@@ -14,7 +14,13 @@ workflow {
1414
)
1515
'''
1616
)
17-
println "Create result: $createResult" // null
17+
println "Create table success: ${createResult.success}" // Should be true
18+
19+
// Handle potential failure
20+
if (!createResult.success) {
21+
println "Failed to create table: ${createResult.error}"
22+
return
23+
}
1824

1925
// Insert data using sqlInsert
2026
Channel
@@ -29,10 +35,15 @@ workflow {
2935
fromQuery('SELECT * FROM sample_table', db: 'foo')
3036
.view()
3137

32-
// Update data using sqlExecute (DML operation returns affected row count)
33-
def updated = sqlExecute(
38+
// Update data using sqlExecute (DML operation returns affected row count in result field)
39+
def updateResult = sqlExecute(
3440
db: 'foo',
3541
statement: "UPDATE sample_table SET value = 30.5 WHERE name = 'beta'"
3642
)
37-
println "Updated $updated row(s)"
43+
44+
if (updateResult.success) {
45+
println "Updated ${updateResult.result} row(s)"
46+
} else {
47+
println "Update failed: ${updateResult.error}"
48+
}
3849
}

0 commit comments

Comments
 (0)