diff --git a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchema.java b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchema.java index e851e29c23fb..666cde876eeb 100644 --- a/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchema.java +++ b/sdks/java/io/google-cloud-platform/src/main/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchema.java @@ -22,6 +22,7 @@ import com.google.cloud.spanner.Type; import java.io.Serializable; import java.util.List; +import java.util.Map; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableListMultimap; @@ -161,47 +162,49 @@ static Column create(String name, String spannerType, Dialect dialect) { public abstract Type getType(); + private static final Map<String, Type> GOOGLE_STANDARD_SQL_TYPE_MAP = + ImmutableMap.<String, Type>builder() + .put("BOOL", Type.bool()) + .put("INT64", Type.int64()) + .put("FLOAT32", Type.float32()) + .put("FLOAT64", Type.float64()) + .put("UUID", Type.string()) + .put("TOKENLIST", Type.bytes()) + .put("TIMESTAMP", Type.timestamp()) + .put("DATE", Type.date()) + .put("NUMERIC", Type.numeric()) + .put("JSON", Type.json()) + .build(); + private static final Map<String, Type> POSTGRES_TYPE_MAP = + ImmutableMap.<String, Type>builder() + .put("BOOLEAN", Type.bool()) + .put("BIGINT", Type.int64()) + .put("REAL", Type.float32()) + .put("DOUBLE PRECISION", Type.float64()) + .put("TEXT", Type.string()) + .put("BYTEA", Type.bytes()) + .put("TIMESTAMP WITH TIME ZONE", Type.timestamp()) + .put("DATE", Type.date()) + .put("SPANNER.COMMIT_TIMESTAMP", Type.timestamp()) + .put("SPANNER.TOKENLIST", Type.bytes()) + .put("UUID", Type.string()) + .build(); + private static Type parseSpannerType(String spannerType, Dialect dialect) {
String originalSpannerType = spannerType; spannerType = spannerType.toUpperCase(); switch (dialect) { case GOOGLE_STANDARD_SQL: - if ("BOOL".equals(spannerType)) { - return Type.bool(); - } - if ("INT64".equals(spannerType)) { - return Type.int64(); - } - if ("FLOAT32".equals(spannerType)) { - return Type.float32(); - } - if ("FLOAT64".equals(spannerType)) { - return Type.float64(); + Type type = GOOGLE_STANDARD_SQL_TYPE_MAP.get(spannerType); + if (type != null) { + return type; } if (spannerType.startsWith("STRING")) { return Type.string(); } - if ("UUID".equals(spannerType)) { - return Type.string(); - } if (spannerType.startsWith("BYTES")) { return Type.bytes(); } - if ("TOKENLIST".equals(spannerType)) { - return Type.bytes(); - } - if ("TIMESTAMP".equals(spannerType)) { - return Type.timestamp(); - } - if ("DATE".equals(spannerType)) { - return Type.date(); - } - if ("NUMERIC".equals(spannerType)) { - return Type.numeric(); - } - if ("JSON".equals(spannerType)) { - return Type.json(); - } if (spannerType.startsWith("ARRAY")) { // Substring "ARRAY" String spannerArrayType = @@ -230,42 +233,19 @@ private static Type parseSpannerType(String spannerType, Dialect dialect) { Type itemType = parseSpannerType(spannerArrayType, dialect); return Type.array(itemType); } - if ("BOOLEAN".equals(spannerType)) { - return Type.bool(); - } - if ("BIGINT".equals(spannerType)) { - return Type.int64(); - } - if ("REAL".equals(spannerType)) { - return Type.float32(); + type = POSTGRES_TYPE_MAP.get(spannerType); + if (type != null) { + return type; } - if ("DOUBLE PRECISION".equals(spannerType)) { - return Type.float64(); - } - if (spannerType.startsWith("CHARACTER VARYING") || "TEXT".equals(spannerType)) { + if (spannerType.startsWith("CHARACTER VARYING")) { return Type.string(); } - if ("BYTEA".equals(spannerType)) { - return Type.bytes(); - } - if ("TIMESTAMP WITH TIME ZONE".equals(spannerType)) { - return Type.timestamp(); - } - if ("DATE".equals(spannerType)) { - return 
Type.date(); - } if (spannerType.startsWith("NUMERIC")) { return Type.pgNumeric(); } - if ("SPANNER.COMMIT_TIMESTAMP".equals(spannerType)) { - return Type.timestamp(); - } if (spannerType.startsWith("JSONB")) { return Type.pgJsonb(); } - if ("UUID".equals(spannerType)) { - return Type.string(); - } throw new IllegalArgumentException("Unknown spanner type " + spannerType); default: throw new IllegalArgumentException("Unrecognized dialect: " + dialect.name()); diff --git a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTest.java b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTest.java index fefe7dc1ef85..b82a1d4fbddd 100644 --- a/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTest.java +++ b/sdks/java/io/google-cloud-platform/src/test/java/org/apache/beam/sdk/io/gcp/spanner/SpannerSchemaTest.java @@ -86,14 +86,16 @@ public void testSinglePgTable() throws Exception { .addColumn("test", "numericVal", "numeric") .addColumn("test", "commitTime", "spanner.commit_timestamp") .addColumn("test", "jsonbCol", "jsonb") + .addColumn("test", "tokens", "spanner.tokenlist") .addColumn("test", "uuidCol", "uuid") .build(); assertEquals(1, schema.getTables().size()); - assertEquals(6, schema.getColumns("test").size()); + assertEquals(7, schema.getColumns("test").size()); assertEquals(1, schema.getKeyParts("test").size()); assertEquals(Type.timestamp(), schema.getColumns("test").get(3).getType()); - assertEquals(Type.string(), schema.getColumns("test").get(5).getType()); + assertEquals(Type.bytes(), schema.getColumns("test").get(5).getType()); + assertEquals(Type.string(), schema.getColumns("test").get(6).getType()); } @Test