Skip to content

Commit e6c8692

Browse files
authored
Remove special handling for only PK column tables (#2929)
* Remove special handling for only PK column tables * Fix spotless * Refactor maps and update UTs
1 parent 677835c commit e6c8692

File tree

3 files changed

+359
-36
lines changed

3 files changed

+359
-36
lines changed

v2/spanner-to-sourcedb/src/main/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGenerator.java

Lines changed: 7 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -121,44 +121,23 @@ public DMLGeneratorResponse getDMLStatement(DMLGeneratorRequest dmlGeneratorRequ
121121
}
122122

123123
private static DMLGeneratorResponse getUpsertStatement(
124-
String tableName,
125-
List<String> primaryKeys,
126-
Map<String, String> columnNameValues,
127-
Map<String, String> pkcolumnNameValues) {
124+
String tableName, Map<String, String> allColumnNameValues) {
128125

129126
String allColumns = "";
130127
String allValues = "";
131128
String updateValues = "";
132129

133-
for (Map.Entry<String, String> entry : pkcolumnNameValues.entrySet()) {
134-
String colName = entry.getKey();
135-
String colValue = entry.getValue();
136-
137-
allColumns += "`" + colName + "`,";
138-
allValues += colValue + ",";
139-
}
140-
141-
if (columnNameValues.size() == 0) { // if there are only PKs
142-
// trim the last ','
143-
allColumns = allColumns.substring(0, allColumns.length() - 1);
144-
allValues = allValues.substring(0, allValues.length() - 1);
145-
146-
String returnVal =
147-
"INSERT INTO `" + tableName + "`(" + allColumns + ")" + " VALUES (" + allValues + ") ";
148-
return new DMLGeneratorResponse(returnVal);
149-
}
150130
int index = 0;
151131

152-
for (Map.Entry<String, String> entry : columnNameValues.entrySet()) {
132+
for (Map.Entry<String, String> entry : allColumnNameValues.entrySet()) {
153133
String colName = entry.getKey();
154134
String colValue = entry.getValue();
155135
allColumns += "`" + colName + "`";
156136
allValues += colValue;
157-
if (!primaryKeys.contains(colName)) {
158-
updateValues += " `" + colName + "` = " + colValue;
159-
}
137+
updateValues += " `" + colName + "` = " + colValue;
160138

161-
if (index + 1 < columnNameValues.size()) {
139+
// Add comma if not the last item in this loop
140+
if (index + 1 < allColumnNameValues.size()) {
162141
allColumns += ",";
163142
allValues += ",";
164143
updateValues += ",";
@@ -214,8 +193,8 @@ private static DMLGeneratorResponse generateUpsertStatement(
214193
dmlGeneratorRequest.getKeyValuesJson(),
215194
dmlGeneratorRequest.getSourceDbTimezoneOffset(),
216195
dmlGeneratorRequest.getCustomTransformationResponse());
217-
return getUpsertStatement(
218-
sourceTable.name(), sourceTable.primaryKeyColumns(), columnNameValues, pkcolumnNameValues);
196+
columnNameValues.putAll(pkcolumnNameValues);
197+
return getUpsertStatement(sourceTable.name(), columnNameValues);
219198
}
220199

221200
private static Map<String, String> getColumnValues(

v2/spanner-to-sourcedb/src/test/java/com/google/cloud/teleport/v2/templates/dbutils/dml/MySQLDMLGeneratorTest.java

Lines changed: 65 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
package com.google.cloud.teleport.v2.templates.dbutils.dml;
1717

1818
import static com.google.common.truth.Truth.assertThat;
19+
import static org.junit.Assert.assertEquals;
1920
import static org.junit.Assert.assertThrows;
2021
import static org.junit.Assert.assertTrue;
2122

@@ -34,6 +35,7 @@
3435
import java.io.InputStream;
3536
import java.nio.channels.Channels;
3637
import java.nio.charset.StandardCharsets;
38+
import java.util.Arrays;
3739
import java.util.HashMap;
3840
import java.util.Map;
3941
import org.apache.beam.sdk.io.FileSystems;
@@ -256,6 +258,42 @@ public void primaryKeyNotPresentInSourceSchema() {
256258
assertTrue(sql.isEmpty());
257259
}
258260

261+
@Test
262+
public void tableOnlyContainsPrimaryKeyColumns() {
263+
String sessionFile = "src/test/resources/onlyPKColumnsSession.json";
264+
Ddl ddl = SchemaUtils.buildSpannerDdlFromSessionFile(sessionFile);
265+
SourceSchema sourceSchema = SchemaUtils.buildSourceSchemaFromSessionFile(sessionFile);
266+
ISchemaMapper schemaMapper = new SessionBasedMapper(sessionFile, ddl);
267+
268+
String tableName = "resource_access";
269+
String newValuesString = "{\"user_id\":\"101\",\"group_id\":\"5\",\"resource_id\":\"99\"}";
270+
JSONObject newValuesJson = new JSONObject(newValuesString);
271+
// the keys and the newValues are the same because all the columns are part of the key
272+
JSONObject keyValuesJson = new JSONObject(newValuesString);
273+
String modType = "INSERT";
274+
275+
/*The expected sql is:
276+
INSERT INTO `resource_access`(`user_id`,`group_id`,`resource_id`) VALUES (101,5,99) ON DUPLICATE KEY UPDATE `user_id` = 101, `group_id` = 5, `resource_id` = 99
277+
*/
278+
MySQLDMLGenerator mySQLDMLGenerator = new MySQLDMLGenerator();
279+
DMLGeneratorResponse dmlGeneratorResponse =
280+
mySQLDMLGenerator.getDMLStatement(
281+
new DMLGeneratorRequest.Builder(
282+
modType, tableName, newValuesJson, keyValuesJson, "+00:00")
283+
.setSchemaMapper(schemaMapper)
284+
.setDdl(ddl)
285+
.setSourceSchema(sourceSchema)
286+
.build());
287+
String sql = dmlGeneratorResponse.getDmlStatement();
288+
assertThat(sql.contains("ON DUPLICATE KEY UPDATE"));
289+
assertTrue(sql.contains("`user_id` = 101"));
290+
assertTrue(sql.contains("`group_id` = 5"));
291+
assertTrue(sql.contains("`resource_id` = 99"));
292+
assertEquals(2, countInSQL(sql, "user_id"));
293+
assertEquals(2, countInSQL(sql, "group_id"));
294+
assertEquals(2, countInSQL(sql, "resource_id"));
295+
}
296+
259297
@Test
260298
public void timezoneOffsetMismatch() {
261299
String sessionFile = "src/test/resources/timeZoneSession.json";
@@ -965,10 +1003,13 @@ public void testSpannerKeyIsNull() {
9651003
.setSourceSchema(sourceSchema)
9661004
.build());
9671005
String sql = dmlGeneratorResponse.getDmlStatement();
968-
969-
assertTrue(
970-
sql.contains(
971-
"INSERT INTO `Singers`(`SingerId`,`FirstName`,`LastName`) VALUES (NULL,'kk','ll')"));
1006+
assertThat(sql.contains("ON DUPLICATE KEY UPDATE"));
1007+
assertTrue(sql.contains("`FirstName` = 'kk'"));
1008+
assertTrue(sql.contains("`SingerId` = NULL"));
1009+
assertTrue(sql.contains("`LastName` = 'll'"));
1010+
assertEquals(2, countInSQL(sql, "FirstName"));
1011+
assertEquals(2, countInSQL(sql, "SingerId"));
1012+
assertEquals(2, countInSQL(sql, "LastName"));
9721013
}
9731014

9741015
@Test
@@ -994,9 +1035,13 @@ public void testKeyInNewValuesJson() {
9941035
.setSourceSchema(sourceSchema)
9951036
.build());
9961037
String sql = dmlGeneratorResponse.getDmlStatement();
997-
assertTrue(
998-
sql.contains(
999-
"INSERT INTO `Singers`(`SingerId`,`FirstName`,`LastName`) VALUES (NULL,'kk','ll')"));
1038+
assertThat(sql.contains("ON DUPLICATE KEY UPDATE"));
1039+
assertTrue(sql.contains("`FirstName` = 'kk'"));
1040+
assertTrue(sql.contains("`SingerId` = NULL"));
1041+
assertTrue(sql.contains("`LastName` = 'll'"));
1042+
assertEquals(2, countInSQL(sql, "FirstName"));
1043+
assertEquals(2, countInSQL(sql, "SingerId"));
1044+
assertEquals(2, countInSQL(sql, "LastName"));
10001045
}
10011046

10021047
@Test
@@ -1249,8 +1294,11 @@ public void customTransformationMatch() {
12491294
.build());
12501295
String sql = dmlGeneratorResponse.getDmlStatement();
12511296

1297+
assertThat(sql.contains("ON DUPLICATE KEY UPDATE"));
12521298
assertTrue(sql.contains("`FullName` = 'kk ll'"));
1253-
assertTrue(sql.contains("VALUES (1,'kk ll')"));
1299+
assertTrue(sql.contains("`SingerId` = 1"));
1300+
assertEquals(2, countInSQL(sql, "FullName"));
1301+
assertEquals(2, countInSQL(sql, "SingerId"));
12541302
}
12551303

12561304
@Test
@@ -1264,4 +1312,13 @@ public void testConvertBase64ToXHex() {
12641312
IllegalArgumentException.class,
12651313
() -> MySQLDMLGenerator.convertBase64ToHex("####GOOGLE####"));
12661314
}
1315+
1316+
public long countInSQL(String sql, String targetWord) {
1317+
if (sql == null || sql.isEmpty()) {
1318+
return 0;
1319+
}
1320+
return Arrays.stream(sql.split("\\W+"))
1321+
.filter(word -> word.equalsIgnoreCase(targetWord))
1322+
.count();
1323+
}
12671324
}

0 commit comments

Comments
 (0)